diff --git a/authentik/root/asgi.py b/authentik/root/asgi.py index 4fe92ed23..df17143bb 100644 --- a/authentik/root/asgi.py +++ b/authentik/root/asgi.py @@ -8,7 +8,6 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/ """ import typing from time import time -from typing import Any, ByteString import django from asgiref.compatibility import guarantee_single_callable @@ -51,20 +50,15 @@ class ASGILogger: app: ASGIApp - headers: dict[ByteString, Any] - status_code: int start: float - content_length: int - request_id: str def __init__(self, app: ASGIApp): self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - self.content_length = 0 - self.headers = dict(scope.get("headers", [])) - self.request_id = "" + content_length = 0 + request_id = "" async def send_hooked(message: Message) -> None: """Hooked send method, which records status code and content-length, and for the final @@ -74,11 +68,13 @@ class ASGILogger: self.status_code = message["status"] if b"Content-Length" in headers: - self.content_length += int(headers.get(b"Content-Length", b"0")) + nonlocal content_length + content_length += int(headers.get(b"Content-Length", b"0")) if message["type"] == "http.response.start": response_headers = dict(message["headers"]) - self.request_id = response_headers.get( + nonlocal request_id + request_id = response_headers.get( RESPONSE_HEADER_ID.encode(), b"" ).decode() @@ -86,7 +82,7 @@ class ASGILogger: "more_body", True ): runtime = int((time() - self.start) * 1000) - self.log(scope, runtime, request_id=self.request_id) + self.log(scope, runtime, content_length, request_id=request_id) await send(message) self.start = time() @@ -98,15 +94,16 @@ class ASGILogger: def _get_ip(self, scope: Scope) -> str: client_ip = None + headers = dict(scope.get("headers", [])) for header in ASGI_IP_HEADERS: - if header in self.headers: - client_ip = self.headers[header].decode() + if header in headers: + client_ip = headers[header].decode() if not client_ip: client_ip, _ = scope.get("client", ("", 0)) # Check if header has multiple values, and use the first one return client_ip.split(", ")[0] - def log(self, scope: Scope, runtime: float, **kwargs): + def log(self, scope: Scope, content_length: int, runtime: float, **kwargs): """Outpot access logs in a structured format""" host = self._get_ip(scope) query_string = "" @@ -118,7 +115,7 @@ class ASGILogger: method=scope.get("method", ""), scheme=scope.get("scheme", ""), status=self.status_code, - size=self.content_length / 1000 if self.content_length > 0 else 0, + size=content_length / 1000 if content_length > 0 else 0, runtime=runtime, **kwargs, )