diff --git a/authentik/root/asgi_middleware.py b/authentik/root/asgi_middleware.py new file mode 100644 index 000000000..33c5195e2 --- /dev/null +++ b/authentik/root/asgi_middleware.py @@ -0,0 +1,34 @@ +"""ASGI middleware""" +from channels.db import database_sync_to_async +from channels.sessions import InstanceSessionWrapper as UpstreamInstanceSessionWrapper +from channels.sessions import SessionMiddleware as UpstreamSessionMiddleware + +from authentik.root.middleware import SessionMiddleware as HTTPSessionMiddleware + + +class InstanceSessionWrapper(UpstreamInstanceSessionWrapper): + """InstanceSessionWrapper which calls the django middleware to decode + the session key""" + + async def resolve_session(self): + raw_session = self.scope["cookies"].get(self.cookie_name) + session_key = HTTPSessionMiddleware.decode_session_key(raw_session) + self.scope["session"]._wrapped = await database_sync_to_async(self.session_store)( + session_key + ) + + +class SessionMiddleware(UpstreamSessionMiddleware): + """ASGI SessionMiddleware which uses the modified InstanceSessionWrapper + wrapper to decode the session key""" + + async def __call__(self, scope, receive, send): + """ + Instantiate a session wrapper for this scope, resolve the session and + call the inner application. + """ + wrapper = InstanceSessionWrapper(scope, send) + + await wrapper.resolve_session() + + return await self.inner(wrapper.scope, receive, wrapper.send) diff --git a/authentik/root/middleware.py b/authentik/root/middleware.py index 2e9133181..444a45070 100644 --- a/authentik/root/middleware.py +++ b/authentik/root/middleware.py @@ -39,16 +39,22 @@ class SessionMiddleware(UpstreamSessionMiddleware): return True return False - def process_request(self, request): - session_jwt = request.COOKIES.get(settings.SESSION_COOKIE_NAME) + @staticmethod + def decode_session_key(key: str) -> str: + """Decode raw session cookie, and parse JWT""" # We need to support the standard django format of just a session key # for testing setups, where the session is directly set - session_key = session_jwt if settings.TEST else None + session_key = key if settings.TEST else None try: - session_payload = decode(session_jwt, SIGNING_HASH, algorithms=["HS256"]) + session_payload = decode(key, SIGNING_HASH, algorithms=["HS256"]) session_key = session_payload["sid"] except (KeyError, PyJWTError): pass + return session_key + + def process_request(self, request): + raw_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME) + session_key = SessionMiddleware.decode_session_key(raw_session) request.session = self.SessionStore(session_key) def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse: diff --git a/authentik/root/websocket.py b/authentik/root/websocket.py index d53b52a12..d7591ff96 100644 --- a/authentik/root/websocket.py +++ b/authentik/root/websocket.py @@ -1,11 +1,15 @@ """root Websocket URLS""" -from channels.auth import AuthMiddlewareStack +from channels.auth import AuthMiddleware +from channels.sessions import CookieMiddleware from django.urls import path from authentik.outposts.channels import OutpostConsumer +from authentik.root.asgi_middleware import SessionMiddleware from authentik.root.messages.consumer import MessageConsumer websocket_urlpatterns = [ path("ws/outpost//", OutpostConsumer.as_asgi()), - path("ws/client/", AuthMiddlewareStack(MessageConsumer.as_asgi())), + path( + "ws/client/", CookieMiddleware(SessionMiddleware(AuthMiddleware(MessageConsumer.as_asgi()))) + ), ]