root: fix session middleware for websocket connections (#4909)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-03-12 16:47:19 +01:00 committed by GitHub
parent 10b7d78825
commit 94f22cffba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 6 deletions

View File

@ -0,0 +1,34 @@
"""ASGI middleware"""
from channels.db import database_sync_to_async
from channels.sessions import InstanceSessionWrapper as UpstreamInstanceSessionWrapper
from channels.sessions import SessionMiddleware as UpstreamSessionMiddleware
from authentik.root.middleware import SessionMiddleware as HTTPSessionMiddleware
class InstanceSessionWrapper(UpstreamInstanceSessionWrapper):
"""InstanceSessionWrapper which calls the django middleware to decode
the session key"""
async def resolve_session(self):
raw_session = self.scope["cookies"].get(self.cookie_name)
session_key = HTTPSessionMiddleware.decode_session_key(raw_session)
self.scope["session"]._wrapped = await database_sync_to_async(self.session_store)(
session_key
)
class SessionMiddleware(UpstreamSessionMiddleware):
"""ASGI SessionMiddleware which uses the modified InstanceSessionWrapper
wrapper to decode the session key"""
async def __call__(self, scope, receive, send):
"""
Instantiate a session wrapper for this scope, resolve the session and
call the inner application.
"""
wrapper = InstanceSessionWrapper(scope, send)
await wrapper.resolve_session()
return await self.inner(wrapper.scope, receive, wrapper.send)

View File

@ -39,16 +39,22 @@ class SessionMiddleware(UpstreamSessionMiddleware):
return True
return False
def process_request(self, request):
session_jwt = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
@staticmethod
def decode_session_key(key: str) -> str:
"""Decode raw session cookie, and parse JWT"""
# We need to support the standard django format of just a session key
# for testing setups, where the session is directly set
session_key = session_jwt if settings.TEST else None
session_key = key if settings.TEST else None
try:
session_payload = decode(session_jwt, SIGNING_HASH, algorithms=["HS256"])
session_payload = decode(key, SIGNING_HASH, algorithms=["HS256"])
session_key = session_payload["sid"]
except (KeyError, PyJWTError):
pass
return session_key
def process_request(self, request):
raw_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
session_key = SessionMiddleware.decode_session_key(raw_session)
request.session = self.SessionStore(session_key)
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:

View File

@ -1,11 +1,15 @@
"""root Websocket URLS"""
from channels.auth import AuthMiddlewareStack
from channels.auth import AuthMiddleware
from channels.sessions import CookieMiddleware
from django.urls import path
from authentik.outposts.channels import OutpostConsumer
from authentik.root.asgi_middleware import SessionMiddleware
from authentik.root.messages.consumer import MessageConsumer
websocket_urlpatterns = [
path("ws/outpost/<uuid:pk>/", OutpostConsumer.as_asgi()),
path("ws/client/", AuthMiddlewareStack(MessageConsumer.as_asgi())),
path(
"ws/client/", CookieMiddleware(SessionMiddleware(AuthMiddleware(MessageConsumer.as_asgi())))
),
]