providers/proxy: improve SLO by backchannel logging out sessions (#7099)

* outposts: add support for provider-specific websocket messages

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* providers/proxy: add custom signal on logout to logout in provider

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-10-09 01:06:52 +02:00 committed by GitHub
parent f60b65c25f
commit 4db365c947
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 134 additions and 7 deletions

View File

@ -27,6 +27,9 @@ class WebsocketMessageInstruction(IntEnum):
# Message sent by us to trigger an Update # Message sent by us to trigger an Update
TRIGGER_UPDATE = 2 TRIGGER_UPDATE = 2
# Provider specific message
PROVIDER_SPECIFIC = 3
@dataclass(slots=True) @dataclass(slots=True)
class WebsocketMessage: class WebsocketMessage:
@ -131,3 +134,14 @@ class OutpostConsumer(AuthJsonConsumer):
self.send_json( self.send_json(
asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)) asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
) )
def event_provider_specific(self, event):
"""Event handler which can be called by provider-specific
implementations to send specific messages to the outpost"""
self.send_json(
asdict(
WebsocketMessage(
instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
)
)
)

View File

@ -5,7 +5,6 @@ from socket import gethostname
from typing import Any, Optional from typing import Any, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import yaml
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer from channels.layers import get_channel_layer
from django.core.cache import cache from django.core.cache import cache
@ -16,6 +15,7 @@ from docker.constants import DEFAULT_UNIX_SOCKET
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME
from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from yaml import safe_load
from authentik.events.monitored_tasks import ( from authentik.events.monitored_tasks import (
MonitoredTask, MonitoredTask,
@ -279,7 +279,7 @@ def outpost_connection_discovery(self: MonitoredTask):
with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
KubernetesServiceConnection.objects.create( KubernetesServiceConnection.objects.create(
name=kubeconfig_local_name, name=kubeconfig_local_name,
kubeconfig=yaml.safe_load(_kubeconfig), kubeconfig=safe_load(_kubeconfig),
) )
unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path
socket = Path(unix_socket_path) socket = Path(unix_socket_path)

View File

@ -9,3 +9,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig):
label = "authentik_providers_proxy" label = "authentik_providers_proxy"
verbose_name = "authentik Providers.Proxy" verbose_name = "authentik Providers.Proxy"
default = True default = True
def reconcile_load_providers_proxy_signals(self):
"""Load proxy signals"""
self.import_module("authentik.providers.proxy.signals")

View File

@ -0,0 +1,20 @@
"""Proxy provider signals"""
from django.contrib.auth.signals import user_logged_out
from django.db.models.signals import pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession, User
from authentik.providers.proxy.tasks import proxy_on_logout
@receiver(user_logged_out)
def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_):
"""Catch logout by direct logout and forward to proxy providers"""
proxy_on_logout.delay(request.session.session_key)
@receiver(pre_delete, sender=AuthenticatedSession)
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted"""
proxy_on_logout.delay(instance.session_key)

View File

@ -1,6 +1,9 @@
"""proxy provider tasks""" """proxy provider tasks"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
from authentik.outposts.models import Outpost, OutpostState, OutpostType
from authentik.providers.proxy.models import ProxyProvider from authentik.providers.proxy.models import ProxyProvider
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP
@ -13,3 +16,20 @@ def proxy_set_defaults():
for provider in ProxyProvider.objects.all(): for provider in ProxyProvider.objects.all():
provider.set_oauth_defaults() provider.set_oauth_defaults()
provider.save() provider.save()
@CELERY_APP.task()
def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
for state in OutpostState.for_outpost(outpost):
for channel in state.channel_ids:
async_to_sync(layer.send)(
channel,
{
"type": "event.provider.specific",
"sub_type": "logout",
"session_id": session_id,
},
)

View File

@ -15,6 +15,7 @@ entries:
# This mapping is used by the authentik proxy. It passes extra user attributes, # This mapping is used by the authentik proxy. It passes extra user attributes,
# which are used for example for the HTTP-Basic Authentication mapping. # which are used for example for the HTTP-Basic Authentication mapping.
return { return {
"sid": request.http_request.session.session_key,
"ak_proxy": { "ak_proxy": {
"user_attributes": request.user.group_attributes(request), "user_attributes": request.user.group_attributes(request),
"is_superuser": request.user.is_superuser, "is_superuser": request.user.is_superuser,

View File

@ -22,6 +22,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type WSHandler func(ctx context.Context, args map[string]interface{})
const ConfigLogLevel = "log_level" const ConfigLogLevel = "log_level"
// APIController main controller which connects to the authentik api via http and ws // APIController main controller which connects to the authentik api via http and ws
@ -42,6 +44,7 @@ type APIController struct {
lastWsReconnect time.Time lastWsReconnect time.Time
wsIsReconnecting bool wsIsReconnecting bool
wsBackoffMultiplier int wsBackoffMultiplier int
wsHandlers []WSHandler
refreshHandlers []func() refreshHandlers []func()
instanceUUID uuid.UUID instanceUUID uuid.UUID
@ -106,6 +109,7 @@ func NewAPIController(akURL url.URL, token string) *APIController {
reloadOffset: time.Duration(rand.Intn(10)) * time.Second, reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(), instanceUUID: uuid.New(),
Outpost: outpost, Outpost: outpost,
wsHandlers: []WSHandler{},
wsBackoffMultiplier: 1, wsBackoffMultiplier: 1,
refreshHandlers: make([]func(), 0), refreshHandlers: make([]func(), 0),
} }
@ -156,6 +160,10 @@ func (a *APIController) AddRefreshHandler(handler func()) {
a.refreshHandlers = append(a.refreshHandlers, handler) a.refreshHandlers = append(a.refreshHandlers, handler)
} }
func (a *APIController) AddWSHandler(handler WSHandler) {
a.wsHandlers = append(a.wsHandlers, handler)
}
func (a *APIController) OnRefresh() error { func (a *APIController) OnRefresh() error {
// Because we don't know the outpost UUID, we simply do a list and pick the first // Because we don't know the outpost UUID, we simply do a list and pick the first
// The service account this token belongs to should only have access to a single outpost // The service account this token belongs to should only have access to a single outpost

View File

@ -1,6 +1,7 @@
package ak package ak
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net/http" "net/http"
@ -145,6 +146,10 @@ func (ac *APIController) startWSHandler() {
"build": constants.BUILD("tagged"), "build": constants.BUILD("tagged"),
}).SetToCurrentTime() }).SetToCurrentTime()
} }
} else if wsMsg.Instruction == WebsocketInstructionProviderSpecific {
for _, h := range ac.wsHandlers {
h(context.Background(), wsMsg.Args)
}
} }
} }
} }

View File

@ -9,6 +9,8 @@ const (
WebsocketInstructionHello websocketInstruction = 1 WebsocketInstructionHello websocketInstruction = 1
// WebsocketInstructionTriggerUpdate Code received to trigger a config update // WebsocketInstructionTriggerUpdate Code received to trigger a config update
WebsocketInstructionTriggerUpdate websocketInstruction = 2 WebsocketInstructionTriggerUpdate websocketInstruction = 2
// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function
WebsocketInstructionProviderSpecific websocketInstruction = 3
) )
type websocketMessage struct { type websocketMessage struct {

View File

@ -280,7 +280,9 @@ func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) {
"id_token_hint": []string{cc.RawToken}, "id_token_hint": []string{cc.RawToken},
} }
redirect += "?" + uv.Encode() redirect += "?" + uv.Encode()
err = a.Logout(r.Context(), cc.Sub) err = a.Logout(r.Context(), func(c Claims) bool {
return c.Sub == cc.Sub
})
if err != nil { if err != nil {
a.log.WithError(err).Warning("failed to logout of other sessions") a.log.WithError(err).Warning("failed to logout of other sessions")
} }

View File

@ -11,10 +11,11 @@ type Claims struct {
Exp int `json:"exp"` Exp int `json:"exp"`
Email string `json:"email"` Email string `json:"email"`
Verified bool `json:"email_verified"` Verified bool `json:"email_verified"`
Proxy *ProxyClaims `json:"ak_proxy"`
Name string `json:"name"` Name string `json:"name"`
PreferredUsername string `json:"preferred_username"` PreferredUsername string `json:"preferred_username"`
Groups []string `json:"groups"` Groups []string `json:"groups"`
Sid string `json:"sid"`
Proxy *ProxyClaims `json:"ak_proxy"`
RawToken string RawToken string
} }

View File

@ -88,7 +88,7 @@ func (a *Application) getAllCodecs() []securecookie.Codec {
return cs return cs
} }
func (a *Application) Logout(ctx context.Context, sub string) error { func (a *Application) Logout(ctx context.Context, filter func(c Claims) bool) error {
if _, ok := a.sessions.(*sessions.FilesystemStore); ok { if _, ok := a.sessions.(*sessions.FilesystemStore); ok {
files, err := os.ReadDir(os.TempDir()) files, err := os.ReadDir(os.TempDir())
if err != nil { if err != nil {
@ -118,7 +118,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error {
continue continue
} }
claims := s.Values[constants.SessionClaims].(Claims) claims := s.Values[constants.SessionClaims].(Claims)
if claims.Sub == sub { if filter(claims) {
a.log.WithField("path", fullPath).Trace("deleting session") a.log.WithField("path", fullPath).Trace("deleting session")
err := os.Remove(fullPath) err := os.Remove(fullPath)
if err != nil { if err != nil {
@ -153,7 +153,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error {
continue continue
} }
claims := c.(Claims) claims := c.(Claims)
if claims.Sub == sub { if filter(claims) {
a.log.WithField("key", key).Trace("deleting session") a.log.WithField("key", key).Trace("deleting session")
_, err := client.Del(ctx, key).Result() _, err := client.Del(ctx, key).Result()
if err != nil { if err != nil {

View File

@ -65,6 +65,7 @@ func NewProxyServer(ac *ak.APIController) *ProxyServer {
globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic) globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic)
globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing)) globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing))
rootMux.PathPrefix("/").HandlerFunc(s.Handle) rootMux.PathPrefix("/").HandlerFunc(s.Handle)
ac.AddWSHandler(s.handleWSMessage)
return s return s
} }

View File

@ -0,0 +1,49 @@
package proxyv2
import (
"context"
"github.com/mitchellh/mapstructure"
"goauthentik.io/internal/outpost/proxyv2/application"
)
type WSProviderSubType string
const (
WSProviderSubTypeLogout WSProviderSubType = "logout"
)
type WSProviderMsg struct {
SubType WSProviderSubType `mapstructure:"sub_type"`
SessionID string `mapstructure:"session_id"`
}
func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) {
msg := &WSProviderMsg{}
err := mapstructure.Decode(args, &msg)
if err != nil {
return nil, err
}
return msg, nil
}
func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) {
msg, err := ParseWSProvider(args)
if err != nil {
ps.log.WithError(err).Warning("invalid provider-specific ws message")
return
}
switch msg.SubType {
case WSProviderSubTypeLogout:
for _, p := range ps.apps {
err := p.Logout(ctx, func(c application.Claims) bool {
return c.Sid == msg.SessionID
})
if err != nil {
ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout")
}
}
default:
ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type")
}
}