diff --git a/authentik/outposts/channels.py b/authentik/outposts/channels.py index 8b3f978ac..f0b656a47 100644 --- a/authentik/outposts/channels.py +++ b/authentik/outposts/channels.py @@ -27,6 +27,9 @@ class WebsocketMessageInstruction(IntEnum): # Message sent by us to trigger an Update TRIGGER_UPDATE = 2 + # Provider specific message + PROVIDER_SPECIFIC = 3 + @dataclass(slots=True) class WebsocketMessage: @@ -131,3 +134,14 @@ class OutpostConsumer(AuthJsonConsumer): self.send_json( asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)) ) + + def event_provider_specific(self, event): + """Event handler which can be called by provider-specific + implementations to send specific messages to the outpost""" + self.send_json( + asdict( + WebsocketMessage( + instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event + ) + ) + ) diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index 227127352..ddb0d5352 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -5,7 +5,6 @@ from socket import gethostname from typing import Any, Optional from urllib.parse import urlparse -import yaml from asgiref.sync import async_to_sync from channels.layers import get_channel_layer from django.core.cache import cache @@ -16,6 +15,7 @@ from docker.constants import DEFAULT_UNIX_SOCKET from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION from structlog.stdlib import get_logger +from yaml import safe_load from authentik.events.monitored_tasks import ( MonitoredTask, @@ -279,7 +279,7 @@ def outpost_connection_discovery(self: MonitoredTask): with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: KubernetesServiceConnection.objects.create( name=kubeconfig_local_name, - kubeconfig=yaml.safe_load(_kubeconfig), + kubeconfig=safe_load(_kubeconfig), ) unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path socket = Path(unix_socket_path) diff --git a/authentik/providers/proxy/apps.py b/authentik/providers/proxy/apps.py index 5e49fe181..4e1a9a883 100644 --- a/authentik/providers/proxy/apps.py +++ b/authentik/providers/proxy/apps.py @@ -9,3 +9,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig): label = "authentik_providers_proxy" verbose_name = "authentik Providers.Proxy" default = True + + def reconcile_load_providers_proxy_signals(self): + """Load proxy signals""" + self.import_module("authentik.providers.proxy.signals") diff --git a/authentik/providers/proxy/signals.py b/authentik/providers/proxy/signals.py new file mode 100644 index 000000000..3e199d3c3 --- /dev/null +++ b/authentik/providers/proxy/signals.py @@ -0,0 +1,20 @@ +"""Proxy provider signals""" +from django.contrib.auth.signals import user_logged_out +from django.db.models.signals import pre_delete +from django.dispatch import receiver +from django.http import HttpRequest + +from authentik.core.models import AuthenticatedSession, User +from authentik.providers.proxy.tasks import proxy_on_logout + + +@receiver(user_logged_out) +def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_): + """Catch logout by direct logout and forward to proxy providers""" + proxy_on_logout.delay(request.session.session_key) + + +@receiver(pre_delete, sender=AuthenticatedSession) +def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): + """Catch logout by expiring sessions being deleted""" + proxy_on_logout.delay(instance.session_key) diff --git a/authentik/providers/proxy/tasks.py b/authentik/providers/proxy/tasks.py index a5a4dc45f..630b0d186 100644 --- a/authentik/providers/proxy/tasks.py +++ b/authentik/providers/proxy/tasks.py @@ -1,6 +1,9 @@ """proxy provider tasks""" +from asgiref.sync import async_to_sync +from channels.layers import get_channel_layer from django.db import DatabaseError, InternalError, ProgrammingError +from authentik.outposts.models import Outpost, OutpostState, OutpostType from authentik.providers.proxy.models import ProxyProvider from authentik.root.celery import CELERY_APP @@ -13,3 +16,20 @@ def proxy_set_defaults(): for provider in ProxyProvider.objects.all(): provider.set_oauth_defaults() provider.save() + + +@CELERY_APP.task() +def proxy_on_logout(session_id: str): + """Update outpost instances connected to a single outpost""" + layer = get_channel_layer() + for outpost in Outpost.objects.filter(type=OutpostType.PROXY): + for state in OutpostState.for_outpost(outpost): + for channel in state.channel_ids: + async_to_sync(layer.send)( + channel, + { + "type": "event.provider.specific", + "sub_type": "logout", + "session_id": session_id, + }, + ) diff --git a/blueprints/system/providers-proxy.yaml b/blueprints/system/providers-proxy.yaml index 1214d157d..0086645a8 100644 --- a/blueprints/system/providers-proxy.yaml +++ b/blueprints/system/providers-proxy.yaml @@ -15,6 +15,7 @@ entries: # This mapping is used by the authentik proxy. It passes extra user attributes, # which are used for example for the HTTP-Basic Authentication mapping. return { + "sid": request.http_request.session.session_key, "ak_proxy": { "user_attributes": request.user.group_attributes(request), "is_superuser": request.user.is_superuser, diff --git a/internal/outpost/ak/api.go b/internal/outpost/ak/api.go index cfdb41dd0..f6003c02f 100644 --- a/internal/outpost/ak/api.go +++ b/internal/outpost/ak/api.go @@ -22,6 +22,8 @@ import ( log "github.com/sirupsen/logrus" ) +type WSHandler func(ctx context.Context, args map[string]interface{}) + const ConfigLogLevel = "log_level" // APIController main controller which connects to the authentik api via http and ws @@ -42,6 +44,7 @@ type APIController struct { lastWsReconnect time.Time wsIsReconnecting bool wsBackoffMultiplier int + wsHandlers []WSHandler refreshHandlers []func() instanceUUID uuid.UUID @@ -106,6 +109,7 @@ func NewAPIController(akURL url.URL, token string) *APIController { reloadOffset: time.Duration(rand.Intn(10)) * time.Second, instanceUUID: uuid.New(), Outpost: outpost, + wsHandlers: []WSHandler{}, wsBackoffMultiplier: 1, refreshHandlers: make([]func(), 0), } @@ -156,6 +160,10 @@ func (a *APIController) AddRefreshHandler(handler func()) { a.refreshHandlers = append(a.refreshHandlers, handler) } +func (a *APIController) AddWSHandler(handler WSHandler) { + a.wsHandlers = append(a.wsHandlers, handler) +} + func (a *APIController) OnRefresh() error { // Because we don't know the outpost UUID, we simply do a list and pick the first // The service account this token belongs to should only have access to a single outpost diff --git a/internal/outpost/ak/api_ws.go b/internal/outpost/ak/api_ws.go index 681b26fa4..24c5099f4 100644 --- a/internal/outpost/ak/api_ws.go +++ b/internal/outpost/ak/api_ws.go @@ -1,6 +1,7 @@ package ak import ( + "context" "crypto/tls" "fmt" "net/http" @@ -145,6 +146,10 @@ func (ac *APIController) startWSHandler() { "build": constants.BUILD("tagged"), }).SetToCurrentTime() } + } else if wsMsg.Instruction == WebsocketInstructionProviderSpecific { + for _, h := range ac.wsHandlers { + h(context.Background(), wsMsg.Args) + } } } } diff --git a/internal/outpost/ak/api_ws_msg.go b/internal/outpost/ak/api_ws_msg.go index f1f2e3aa8..cedecb93d 100644 --- a/internal/outpost/ak/api_ws_msg.go +++ b/internal/outpost/ak/api_ws_msg.go @@ -9,6 +9,8 @@ const ( WebsocketInstructionHello websocketInstruction = 1 // WebsocketInstructionTriggerUpdate Code received to trigger a config update WebsocketInstructionTriggerUpdate websocketInstruction = 2 + // WebsocketInstructionProviderSpecific Code received to trigger some provider specific function + WebsocketInstructionProviderSpecific websocketInstruction = 3 ) type websocketMessage struct { diff --git a/internal/outpost/proxyv2/application/application.go b/internal/outpost/proxyv2/application/application.go index 657bcbec7..eae4c6774 100644 --- a/internal/outpost/proxyv2/application/application.go +++ b/internal/outpost/proxyv2/application/application.go @@ -280,7 +280,9 @@ func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) { "id_token_hint": []string{cc.RawToken}, } redirect += "?" + uv.Encode() - err = a.Logout(r.Context(), cc.Sub) + err = a.Logout(r.Context(), func(c Claims) bool { + return c.Sub == cc.Sub + }) if err != nil { a.log.WithError(err).Warning("failed to logout of other sessions") } diff --git a/internal/outpost/proxyv2/application/claims.go b/internal/outpost/proxyv2/application/claims.go index bd34e1309..32f4d26eb 100644 --- a/internal/outpost/proxyv2/application/claims.go +++ b/internal/outpost/proxyv2/application/claims.go @@ -11,10 +11,11 @@ type Claims struct { Exp int `json:"exp"` Email string `json:"email"` Verified bool `json:"email_verified"` - Proxy *ProxyClaims `json:"ak_proxy"` Name string `json:"name"` PreferredUsername string `json:"preferred_username"` Groups []string `json:"groups"` + Sid string `json:"sid"` + Proxy *ProxyClaims `json:"ak_proxy"` RawToken string } diff --git a/internal/outpost/proxyv2/application/session.go b/internal/outpost/proxyv2/application/session.go index 739b23e84..55d2bbb46 100644 --- a/internal/outpost/proxyv2/application/session.go +++ b/internal/outpost/proxyv2/application/session.go @@ -88,7 +88,7 @@ func (a *Application) getAllCodecs() []securecookie.Codec { return cs } -func (a *Application) Logout(ctx context.Context, sub string) error { +func (a *Application) Logout(ctx context.Context, filter func(c Claims) bool) error { if _, ok := a.sessions.(*sessions.FilesystemStore); ok { files, err := os.ReadDir(os.TempDir()) if err != nil { @@ -118,7 +118,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error { continue } claims := s.Values[constants.SessionClaims].(Claims) - if claims.Sub == sub { + if filter(claims) { a.log.WithField("path", fullPath).Trace("deleting session") err := os.Remove(fullPath) if err != nil { @@ -153,7 +153,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error { continue } claims := c.(Claims) - if claims.Sub == sub { + if filter(claims) { a.log.WithField("key", key).Trace("deleting session") _, err := client.Del(ctx, key).Result() if err != nil { diff --git a/internal/outpost/proxyv2/proxyv2.go b/internal/outpost/proxyv2/proxyv2.go index 154f79e34..70364957f 100644 --- a/internal/outpost/proxyv2/proxyv2.go +++ b/internal/outpost/proxyv2/proxyv2.go @@ -65,6 +65,7 @@ func NewProxyServer(ac *ak.APIController) *ProxyServer { globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic) globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing)) rootMux.PathPrefix("/").HandlerFunc(s.Handle) + ac.AddWSHandler(s.handleWSMessage) return s } diff --git a/internal/outpost/proxyv2/ws.go b/internal/outpost/proxyv2/ws.go new file mode 100644 index 000000000..b75ba50fd --- /dev/null +++ b/internal/outpost/proxyv2/ws.go @@ -0,0 +1,49 @@ +package proxyv2 + +import ( + "context" + + "github.com/mitchellh/mapstructure" + "goauthentik.io/internal/outpost/proxyv2/application" +) + +type WSProviderSubType string + +const ( + WSProviderSubTypeLogout WSProviderSubType = "logout" +) + +type WSProviderMsg struct { + SubType WSProviderSubType `mapstructure:"sub_type"` + SessionID string `mapstructure:"session_id"` +} + +func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) { + msg := &WSProviderMsg{} + err := mapstructure.Decode(args, &msg) + if err != nil { + return nil, err + } + return msg, nil +} + +func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) { + msg, err := ParseWSProvider(args) + if err != nil { + ps.log.WithError(err).Warning("invalid provider-specific ws message") + return + } + switch msg.SubType { + case WSProviderSubTypeLogout: + for _, p := range ps.apps { + err := p.Logout(ctx, func(c application.Claims) bool { + return c.Sid == msg.SessionID + }) + if err != nil { + ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout") + } + } + default: + ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type") + } +}