providers/proxy: improve SLO by backchannel logging out sessions (#7099)
* outposts: add support for provider-specific websocket messages Signed-off-by: Jens Langhammer <jens@goauthentik.io> * providers/proxy: add custom signal on logout to logout in provider Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
f60b65c25f
commit
4db365c947
|
@ -27,6 +27,9 @@ class WebsocketMessageInstruction(IntEnum):
|
||||||
# Message sent by us to trigger an Update
|
# Message sent by us to trigger an Update
|
||||||
TRIGGER_UPDATE = 2
|
TRIGGER_UPDATE = 2
|
||||||
|
|
||||||
|
# Provider specific message
|
||||||
|
PROVIDER_SPECIFIC = 3
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class WebsocketMessage:
|
class WebsocketMessage:
|
||||||
|
@ -131,3 +134,14 @@ class OutpostConsumer(AuthJsonConsumer):
|
||||||
self.send_json(
|
self.send_json(
|
||||||
asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
|
asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def event_provider_specific(self, event):
|
||||||
|
"""Event handler which can be called by provider-specific
|
||||||
|
implementations to send specific messages to the outpost"""
|
||||||
|
self.send_json(
|
||||||
|
asdict(
|
||||||
|
WebsocketMessage(
|
||||||
|
instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
|
@ -5,7 +5,6 @@ from socket import gethostname
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import yaml
|
|
||||||
from asgiref.sync import async_to_sync
|
from asgiref.sync import async_to_sync
|
||||||
from channels.layers import get_channel_layer
|
from channels.layers import get_channel_layer
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
@ -16,6 +15,7 @@ from docker.constants import DEFAULT_UNIX_SOCKET
|
||||||
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME
|
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME
|
||||||
from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
|
from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
from yaml import safe_load
|
||||||
|
|
||||||
from authentik.events.monitored_tasks import (
|
from authentik.events.monitored_tasks import (
|
||||||
MonitoredTask,
|
MonitoredTask,
|
||||||
|
@ -279,7 +279,7 @@ def outpost_connection_discovery(self: MonitoredTask):
|
||||||
with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
|
with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
|
||||||
KubernetesServiceConnection.objects.create(
|
KubernetesServiceConnection.objects.create(
|
||||||
name=kubeconfig_local_name,
|
name=kubeconfig_local_name,
|
||||||
kubeconfig=yaml.safe_load(_kubeconfig),
|
kubeconfig=safe_load(_kubeconfig),
|
||||||
)
|
)
|
||||||
unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path
|
unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path
|
||||||
socket = Path(unix_socket_path)
|
socket = Path(unix_socket_path)
|
||||||
|
|
|
@ -9,3 +9,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig):
|
||||||
label = "authentik_providers_proxy"
|
label = "authentik_providers_proxy"
|
||||||
verbose_name = "authentik Providers.Proxy"
|
verbose_name = "authentik Providers.Proxy"
|
||||||
default = True
|
default = True
|
||||||
|
|
||||||
|
def reconcile_load_providers_proxy_signals(self):
|
||||||
|
"""Load proxy signals"""
|
||||||
|
self.import_module("authentik.providers.proxy.signals")
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
"""Proxy provider signals"""
|
||||||
|
from django.contrib.auth.signals import user_logged_out
|
||||||
|
from django.db.models.signals import pre_delete
|
||||||
|
from django.dispatch import receiver
|
||||||
|
from django.http import HttpRequest
|
||||||
|
|
||||||
|
from authentik.core.models import AuthenticatedSession, User
|
||||||
|
from authentik.providers.proxy.tasks import proxy_on_logout
|
||||||
|
|
||||||
|
|
||||||
|
@receiver(user_logged_out)
|
||||||
|
def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_):
|
||||||
|
"""Catch logout by direct logout and forward to proxy providers"""
|
||||||
|
proxy_on_logout.delay(request.session.session_key)
|
||||||
|
|
||||||
|
|
||||||
|
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||||
|
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
|
||||||
|
"""Catch logout by expiring sessions being deleted"""
|
||||||
|
proxy_on_logout.delay(instance.session_key)
|
|
@ -1,6 +1,9 @@
|
||||||
"""proxy provider tasks"""
|
"""proxy provider tasks"""
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
from channels.layers import get_channel_layer
|
||||||
from django.db import DatabaseError, InternalError, ProgrammingError
|
from django.db import DatabaseError, InternalError, ProgrammingError
|
||||||
|
|
||||||
|
from authentik.outposts.models import Outpost, OutpostState, OutpostType
|
||||||
from authentik.providers.proxy.models import ProxyProvider
|
from authentik.providers.proxy.models import ProxyProvider
|
||||||
from authentik.root.celery import CELERY_APP
|
from authentik.root.celery import CELERY_APP
|
||||||
|
|
||||||
|
@ -13,3 +16,20 @@ def proxy_set_defaults():
|
||||||
for provider in ProxyProvider.objects.all():
|
for provider in ProxyProvider.objects.all():
|
||||||
provider.set_oauth_defaults()
|
provider.set_oauth_defaults()
|
||||||
provider.save()
|
provider.save()
|
||||||
|
|
||||||
|
|
||||||
|
@CELERY_APP.task()
|
||||||
|
def proxy_on_logout(session_id: str):
|
||||||
|
"""Update outpost instances connected to a single outpost"""
|
||||||
|
layer = get_channel_layer()
|
||||||
|
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
|
||||||
|
for state in OutpostState.for_outpost(outpost):
|
||||||
|
for channel in state.channel_ids:
|
||||||
|
async_to_sync(layer.send)(
|
||||||
|
channel,
|
||||||
|
{
|
||||||
|
"type": "event.provider.specific",
|
||||||
|
"sub_type": "logout",
|
||||||
|
"session_id": session_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
|
@ -15,6 +15,7 @@ entries:
|
||||||
# This mapping is used by the authentik proxy. It passes extra user attributes,
|
# This mapping is used by the authentik proxy. It passes extra user attributes,
|
||||||
# which are used for example for the HTTP-Basic Authentication mapping.
|
# which are used for example for the HTTP-Basic Authentication mapping.
|
||||||
return {
|
return {
|
||||||
|
"sid": request.http_request.session.session_key,
|
||||||
"ak_proxy": {
|
"ak_proxy": {
|
||||||
"user_attributes": request.user.group_attributes(request),
|
"user_attributes": request.user.group_attributes(request),
|
||||||
"is_superuser": request.user.is_superuser,
|
"is_superuser": request.user.is_superuser,
|
||||||
|
|
|
@ -22,6 +22,8 @@ import (
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type WSHandler func(ctx context.Context, args map[string]interface{})
|
||||||
|
|
||||||
const ConfigLogLevel = "log_level"
|
const ConfigLogLevel = "log_level"
|
||||||
|
|
||||||
// APIController main controller which connects to the authentik api via http and ws
|
// APIController main controller which connects to the authentik api via http and ws
|
||||||
|
@ -42,6 +44,7 @@ type APIController struct {
|
||||||
lastWsReconnect time.Time
|
lastWsReconnect time.Time
|
||||||
wsIsReconnecting bool
|
wsIsReconnecting bool
|
||||||
wsBackoffMultiplier int
|
wsBackoffMultiplier int
|
||||||
|
wsHandlers []WSHandler
|
||||||
refreshHandlers []func()
|
refreshHandlers []func()
|
||||||
|
|
||||||
instanceUUID uuid.UUID
|
instanceUUID uuid.UUID
|
||||||
|
@ -106,6 +109,7 @@ func NewAPIController(akURL url.URL, token string) *APIController {
|
||||||
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
|
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
|
||||||
instanceUUID: uuid.New(),
|
instanceUUID: uuid.New(),
|
||||||
Outpost: outpost,
|
Outpost: outpost,
|
||||||
|
wsHandlers: []WSHandler{},
|
||||||
wsBackoffMultiplier: 1,
|
wsBackoffMultiplier: 1,
|
||||||
refreshHandlers: make([]func(), 0),
|
refreshHandlers: make([]func(), 0),
|
||||||
}
|
}
|
||||||
|
@ -156,6 +160,10 @@ func (a *APIController) AddRefreshHandler(handler func()) {
|
||||||
a.refreshHandlers = append(a.refreshHandlers, handler)
|
a.refreshHandlers = append(a.refreshHandlers, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *APIController) AddWSHandler(handler WSHandler) {
|
||||||
|
a.wsHandlers = append(a.wsHandlers, handler)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *APIController) OnRefresh() error {
|
func (a *APIController) OnRefresh() error {
|
||||||
// Because we don't know the outpost UUID, we simply do a list and pick the first
|
// Because we don't know the outpost UUID, we simply do a list and pick the first
|
||||||
// The service account this token belongs to should only have access to a single outpost
|
// The service account this token belongs to should only have access to a single outpost
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package ak
|
package ak
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -145,6 +146,10 @@ func (ac *APIController) startWSHandler() {
|
||||||
"build": constants.BUILD("tagged"),
|
"build": constants.BUILD("tagged"),
|
||||||
}).SetToCurrentTime()
|
}).SetToCurrentTime()
|
||||||
}
|
}
|
||||||
|
} else if wsMsg.Instruction == WebsocketInstructionProviderSpecific {
|
||||||
|
for _, h := range ac.wsHandlers {
|
||||||
|
h(context.Background(), wsMsg.Args)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,8 @@ const (
|
||||||
WebsocketInstructionHello websocketInstruction = 1
|
WebsocketInstructionHello websocketInstruction = 1
|
||||||
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
|
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
|
||||||
WebsocketInstructionTriggerUpdate websocketInstruction = 2
|
WebsocketInstructionTriggerUpdate websocketInstruction = 2
|
||||||
|
// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function
|
||||||
|
WebsocketInstructionProviderSpecific websocketInstruction = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
type websocketMessage struct {
|
type websocketMessage struct {
|
||||||
|
|
|
@ -280,7 +280,9 @@ func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) {
|
||||||
"id_token_hint": []string{cc.RawToken},
|
"id_token_hint": []string{cc.RawToken},
|
||||||
}
|
}
|
||||||
redirect += "?" + uv.Encode()
|
redirect += "?" + uv.Encode()
|
||||||
err = a.Logout(r.Context(), cc.Sub)
|
err = a.Logout(r.Context(), func(c Claims) bool {
|
||||||
|
return c.Sub == cc.Sub
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.WithError(err).Warning("failed to logout of other sessions")
|
a.log.WithError(err).Warning("failed to logout of other sessions")
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,10 +11,11 @@ type Claims struct {
|
||||||
Exp int `json:"exp"`
|
Exp int `json:"exp"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Verified bool `json:"email_verified"`
|
Verified bool `json:"email_verified"`
|
||||||
Proxy *ProxyClaims `json:"ak_proxy"`
|
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
PreferredUsername string `json:"preferred_username"`
|
PreferredUsername string `json:"preferred_username"`
|
||||||
Groups []string `json:"groups"`
|
Groups []string `json:"groups"`
|
||||||
|
Sid string `json:"sid"`
|
||||||
|
Proxy *ProxyClaims `json:"ak_proxy"`
|
||||||
|
|
||||||
RawToken string
|
RawToken string
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,7 +88,7 @@ func (a *Application) getAllCodecs() []securecookie.Codec {
|
||||||
return cs
|
return cs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Application) Logout(ctx context.Context, sub string) error {
|
func (a *Application) Logout(ctx context.Context, filter func(c Claims) bool) error {
|
||||||
if _, ok := a.sessions.(*sessions.FilesystemStore); ok {
|
if _, ok := a.sessions.(*sessions.FilesystemStore); ok {
|
||||||
files, err := os.ReadDir(os.TempDir())
|
files, err := os.ReadDir(os.TempDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -118,7 +118,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
claims := s.Values[constants.SessionClaims].(Claims)
|
claims := s.Values[constants.SessionClaims].(Claims)
|
||||||
if claims.Sub == sub {
|
if filter(claims) {
|
||||||
a.log.WithField("path", fullPath).Trace("deleting session")
|
a.log.WithField("path", fullPath).Trace("deleting session")
|
||||||
err := os.Remove(fullPath)
|
err := os.Remove(fullPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -153,7 +153,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
claims := c.(Claims)
|
claims := c.(Claims)
|
||||||
if claims.Sub == sub {
|
if filter(claims) {
|
||||||
a.log.WithField("key", key).Trace("deleting session")
|
a.log.WithField("key", key).Trace("deleting session")
|
||||||
_, err := client.Del(ctx, key).Result()
|
_, err := client.Del(ctx, key).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -65,6 +65,7 @@ func NewProxyServer(ac *ak.APIController) *ProxyServer {
|
||||||
globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic)
|
globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic)
|
||||||
globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing))
|
globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing))
|
||||||
rootMux.PathPrefix("/").HandlerFunc(s.Handle)
|
rootMux.PathPrefix("/").HandlerFunc(s.Handle)
|
||||||
|
ac.AddWSHandler(s.handleWSMessage)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package proxyv2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/mitchellh/mapstructure"
|
||||||
|
"goauthentik.io/internal/outpost/proxyv2/application"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WSProviderSubType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
WSProviderSubTypeLogout WSProviderSubType = "logout"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WSProviderMsg struct {
|
||||||
|
SubType WSProviderSubType `mapstructure:"sub_type"`
|
||||||
|
SessionID string `mapstructure:"session_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) {
|
||||||
|
msg := &WSProviderMsg{}
|
||||||
|
err := mapstructure.Decode(args, &msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) {
|
||||||
|
msg, err := ParseWSProvider(args)
|
||||||
|
if err != nil {
|
||||||
|
ps.log.WithError(err).Warning("invalid provider-specific ws message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch msg.SubType {
|
||||||
|
case WSProviderSubTypeLogout:
|
||||||
|
for _, p := range ps.apps {
|
||||||
|
err := p.Logout(ctx, func(c application.Claims) bool {
|
||||||
|
return c.Sid == msg.SessionID
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type")
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue