root: fix session middleware for websocket connections (#4909)
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
10b7d78825
commit
94f22cffba
|
@ -0,0 +1,34 @@
|
||||||
|
"""ASGI middleware"""
|
||||||
|
from channels.db import database_sync_to_async
|
||||||
|
from channels.sessions import InstanceSessionWrapper as UpstreamInstanceSessionWrapper
|
||||||
|
from channels.sessions import SessionMiddleware as UpstreamSessionMiddleware
|
||||||
|
|
||||||
|
from authentik.root.middleware import SessionMiddleware as HTTPSessionMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
class InstanceSessionWrapper(UpstreamInstanceSessionWrapper):
|
||||||
|
"""InstanceSessionWrapper which calls the django middleware to decode
|
||||||
|
the session key"""
|
||||||
|
|
||||||
|
async def resolve_session(self):
|
||||||
|
raw_session = self.scope["cookies"].get(self.cookie_name)
|
||||||
|
session_key = HTTPSessionMiddleware.decode_session_key(raw_session)
|
||||||
|
self.scope["session"]._wrapped = await database_sync_to_async(self.session_store)(
|
||||||
|
session_key
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionMiddleware(UpstreamSessionMiddleware):
|
||||||
|
"""ASGI SessionMiddleware which uses the modified InstanceSessionWrapper
|
||||||
|
wrapper to decode the session key"""
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
"""
|
||||||
|
Instantiate a session wrapper for this scope, resolve the session and
|
||||||
|
call the inner application.
|
||||||
|
"""
|
||||||
|
wrapper = InstanceSessionWrapper(scope, send)
|
||||||
|
|
||||||
|
await wrapper.resolve_session()
|
||||||
|
|
||||||
|
return await self.inner(wrapper.scope, receive, wrapper.send)
|
|
@ -39,16 +39,22 @@ class SessionMiddleware(UpstreamSessionMiddleware):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def process_request(self, request):
|
@staticmethod
|
||||||
session_jwt = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
|
def decode_session_key(key: str) -> str:
|
||||||
|
"""Decode raw session cookie, and parse JWT"""
|
||||||
# We need to support the standard django format of just a session key
|
# We need to support the standard django format of just a session key
|
||||||
# for testing setups, where the session is directly set
|
# for testing setups, where the session is directly set
|
||||||
session_key = session_jwt if settings.TEST else None
|
session_key = key if settings.TEST else None
|
||||||
try:
|
try:
|
||||||
session_payload = decode(session_jwt, SIGNING_HASH, algorithms=["HS256"])
|
session_payload = decode(key, SIGNING_HASH, algorithms=["HS256"])
|
||||||
session_key = session_payload["sid"]
|
session_key = session_payload["sid"]
|
||||||
except (KeyError, PyJWTError):
|
except (KeyError, PyJWTError):
|
||||||
pass
|
pass
|
||||||
|
return session_key
|
||||||
|
|
||||||
|
def process_request(self, request):
|
||||||
|
raw_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
|
||||||
|
session_key = SessionMiddleware.decode_session_key(raw_session)
|
||||||
request.session = self.SessionStore(session_key)
|
request.session = self.SessionStore(session_key)
|
||||||
|
|
||||||
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
||||||
|
|
|
@ -1,11 +1,15 @@
|
||||||
"""root Websocket URLS"""
|
"""root Websocket URLS"""
|
||||||
from channels.auth import AuthMiddlewareStack
|
from channels.auth import AuthMiddleware
|
||||||
|
from channels.sessions import CookieMiddleware
|
||||||
from django.urls import path
|
from django.urls import path
|
||||||
|
|
||||||
from authentik.outposts.channels import OutpostConsumer
|
from authentik.outposts.channels import OutpostConsumer
|
||||||
|
from authentik.root.asgi_middleware import SessionMiddleware
|
||||||
from authentik.root.messages.consumer import MessageConsumer
|
from authentik.root.messages.consumer import MessageConsumer
|
||||||
|
|
||||||
websocket_urlpatterns = [
|
websocket_urlpatterns = [
|
||||||
path("ws/outpost/<uuid:pk>/", OutpostConsumer.as_asgi()),
|
path("ws/outpost/<uuid:pk>/", OutpostConsumer.as_asgi()),
|
||||||
path("ws/client/", AuthMiddlewareStack(MessageConsumer.as_asgi())),
|
path(
|
||||||
|
"ws/client/", CookieMiddleware(SessionMiddleware(AuthMiddleware(MessageConsumer.as_asgi())))
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
Reference in New Issue