diff --git a/authentik/core/urls.py b/authentik/core/urls.py index 8dc412465..b45288bfe 100644 --- a/authentik/core/urls.py +++ b/authentik/core/urls.py @@ -13,6 +13,7 @@ from authentik.core.views.interface import FlowInterfaceView, InterfaceView from authentik.core.views.session import EndSessionView from authentik.root.asgi_middleware import SessionMiddleware from authentik.root.messages.consumer import MessageConsumer +from authentik.root.middleware import ChannelsLoggingMiddleware urlpatterns = [ path( @@ -70,7 +71,10 @@ urlpatterns = [ websocket_urlpatterns = [ path( - "ws/client/", CookieMiddleware(SessionMiddleware(AuthMiddleware(MessageConsumer.as_asgi()))) + "ws/client/", + ChannelsLoggingMiddleware( + CookieMiddleware(SessionMiddleware(AuthMiddleware(MessageConsumer.as_asgi()))) + ), ), ] diff --git a/authentik/outposts/urls.py b/authentik/outposts/urls.py index 696fd7ff6..1e3982c09 100644 --- a/authentik/outposts/urls.py +++ b/authentik/outposts/urls.py @@ -2,7 +2,8 @@ from django.urls import path from authentik.outposts.channels import OutpostConsumer +from authentik.root.middleware import ChannelsLoggingMiddleware websocket_urlpatterns = [ - path("ws/outpost//", OutpostConsumer.as_asgi()), + path("ws/outpost//", ChannelsLoggingMiddleware(OutpostConsumer.as_asgi())), ] diff --git a/authentik/root/middleware.py b/authentik/root/middleware.py index 444a45070..a6b0d9c7e 100644 --- a/authentik/root/middleware.py +++ b/authentik/root/middleware.py @@ -1,6 +1,7 @@ """Dynamically set SameSite depending if the upstream connection is TLS or not""" from hashlib import sha512 from time import time +from timeit import default_timer from typing import Callable from django.conf import settings @@ -130,6 +131,28 @@ class SessionMiddleware(UpstreamSessionMiddleware): return response +class ChannelsLoggingMiddleware: + """Logging middleware for channels""" + + def __init__(self, inner): + self.inner = inner + + async def __call__(self, scope, receive, send): + self.log(scope) + return await self.inner(scope, receive, send) + + def log(self, scope: dict, **kwargs): + """Log request""" + headers = dict(scope.get("headers", {})) + LOGGER.info( + scope["path"], + scheme="ws", + remote=scope.get("client", [""])[0], + user_agent=headers.get(b"user-agent", b"").decode(), + **kwargs, + ) + + class LoggingMiddleware: """Logger middleware""" @@ -139,14 +162,14 @@ class LoggingMiddleware: self.get_response = get_response def __call__(self, request: HttpRequest) -> HttpResponse: - start = time() + start = default_timer() response = self.get_response(request) status_code = response.status_code kwargs = { "request_id": request.request_id, } kwargs.update(getattr(response, "ak_context", {})) - self.log(request, status_code, int((time() - start) * 1000), **kwargs) + self.log(request, status_code, int((default_timer() - start) * 1000), **kwargs) return response def log(self, request: HttpRequest, status_code: int, runtime: int, **kwargs):