root: fix session middleware for websocket connections (#4909)
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
10b7d78825
commit
94f22cffba
|
@ -0,0 +1,34 @@
|
|||
"""ASGI middleware"""
|
||||
from channels.db import database_sync_to_async
|
||||
from channels.sessions import InstanceSessionWrapper as UpstreamInstanceSessionWrapper
|
||||
from channels.sessions import SessionMiddleware as UpstreamSessionMiddleware
|
||||
|
||||
from authentik.root.middleware import SessionMiddleware as HTTPSessionMiddleware
|
||||
|
||||
|
||||
class InstanceSessionWrapper(UpstreamInstanceSessionWrapper):
|
||||
"""InstanceSessionWrapper which calls the django middleware to decode
|
||||
the session key"""
|
||||
|
||||
async def resolve_session(self):
|
||||
raw_session = self.scope["cookies"].get(self.cookie_name)
|
||||
session_key = HTTPSessionMiddleware.decode_session_key(raw_session)
|
||||
self.scope["session"]._wrapped = await database_sync_to_async(self.session_store)(
|
||||
session_key
|
||||
)
|
||||
|
||||
|
||||
class SessionMiddleware(UpstreamSessionMiddleware):
|
||||
"""ASGI SessionMiddleware which uses the modified InstanceSessionWrapper
|
||||
wrapper to decode the session key"""
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""
|
||||
Instantiate a session wrapper for this scope, resolve the session and
|
||||
call the inner application.
|
||||
"""
|
||||
wrapper = InstanceSessionWrapper(scope, send)
|
||||
|
||||
await wrapper.resolve_session()
|
||||
|
||||
return await self.inner(wrapper.scope, receive, wrapper.send)
|
|
@ -39,16 +39,22 @@ class SessionMiddleware(UpstreamSessionMiddleware):
|
|||
return True
|
||||
return False
|
||||
|
||||
def process_request(self, request):
|
||||
session_jwt = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
|
||||
@staticmethod
|
||||
def decode_session_key(key: str) -> str:
|
||||
"""Decode raw session cookie, and parse JWT"""
|
||||
# We need to support the standard django format of just a session key
|
||||
# for testing setups, where the session is directly set
|
||||
session_key = session_jwt if settings.TEST else None
|
||||
session_key = key if settings.TEST else None
|
||||
try:
|
||||
session_payload = decode(session_jwt, SIGNING_HASH, algorithms=["HS256"])
|
||||
session_payload = decode(key, SIGNING_HASH, algorithms=["HS256"])
|
||||
session_key = session_payload["sid"]
|
||||
except (KeyError, PyJWTError):
|
||||
pass
|
||||
return session_key
|
||||
|
||||
def process_request(self, request):
|
||||
raw_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
|
||||
session_key = SessionMiddleware.decode_session_key(raw_session)
|
||||
request.session = self.SessionStore(session_key)
|
||||
|
||||
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
"""root Websocket URLS"""
|
||||
from channels.auth import AuthMiddlewareStack
|
||||
from channels.auth import AuthMiddleware
|
||||
from channels.sessions import CookieMiddleware
|
||||
from django.urls import path
|
||||
|
||||
from authentik.outposts.channels import OutpostConsumer
|
||||
from authentik.root.asgi_middleware import SessionMiddleware
|
||||
from authentik.root.messages.consumer import MessageConsumer
|
||||
|
||||
websocket_urlpatterns = [
|
||||
path("ws/outpost/<uuid:pk>/", OutpostConsumer.as_asgi()),
|
||||
path("ws/client/", AuthMiddlewareStack(MessageConsumer.as_asgi())),
|
||||
path(
|
||||
"ws/client/", CookieMiddleware(SessionMiddleware(AuthMiddleware(MessageConsumer.as_asgi())))
|
||||
),
|
||||
]
|
||||
|
|
Reference in New Issue