root: fix concurrency logging issues

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-03-17 18:20:00 +01:00
parent fdbb9803b5
commit 56260cd23f
1 changed files with 10 additions and 24 deletions

View File

@ -46,24 +46,11 @@ ASGI_IP_HEADERS = (
LOGGER = get_logger("authentik.asgi") LOGGER = get_logger("authentik.asgi")
class ASGILoggerMiddleware:
"""Main ASGI Logger middleware, starts an ASGILogger for each request"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send):
responder = ASGILogger(self.app)
await responder(scope, receive, send)
return
class ASGILogger: class ASGILogger:
"""ASGI Logger, instantiated for each request""" """ASGI Logger, instantiated for each request"""
app: ASGIApp app: ASGIApp
scope: Scope
headers: dict[ByteString, Any] headers: dict[ByteString, Any]
status_code: int status_code: int
@ -75,7 +62,6 @@ class ASGILogger:
self.app = app self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.scope = scope
self.content_length = 0 self.content_length = 0
self.headers = dict(scope.get("headers", [])) self.headers = dict(scope.get("headers", []))
self.request_id = "" self.request_id = ""
@ -100,7 +86,7 @@ class ASGILogger:
"more_body", True "more_body", True
): ):
runtime = int((time() - self.start) * 1000) runtime = int((time() - self.start) * 1000)
self.log(runtime, request_id=self.request_id) self.log(scope, runtime, request_id=self.request_id)
await send(message) await send(message)
self.start = time() self.start = time()
@ -110,27 +96,27 @@ class ASGILogger:
return return
await self.app(scope, receive, send_hooked) await self.app(scope, receive, send_hooked)
def _get_ip(self) -> str: def _get_ip(self, scope: Scope) -> str:
client_ip = None client_ip = None
for header in ASGI_IP_HEADERS: for header in ASGI_IP_HEADERS:
if header in self.headers: if header in self.headers:
client_ip = self.headers[header].decode() client_ip = self.headers[header].decode()
if not client_ip: if not client_ip:
client_ip, _ = self.scope.get("client", ("", 0)) client_ip, _ = scope.get("client", ("", 0))
# Check if header has multiple values, and use the first one # Check if header has multiple values, and use the first one
return client_ip.split(", ")[0] return client_ip.split(", ")[0]
def log(self, runtime: float, **kwargs): def log(self, scope: Scope, runtime: float, **kwargs):
"""Outpot access logs in a structured format""" """Outpot access logs in a structured format"""
host = self._get_ip() host = self._get_ip(scope)
query_string = "" query_string = ""
if self.scope.get("query_string", b"") != b"": if scope.get("query_string", b"") != b"":
query_string = f"?{self.scope.get('query_string').decode()}" query_string = f"?{scope.get('query_string').decode()}"
LOGGER.info( LOGGER.info(
f"{self.scope.get('path', '')}{query_string}", f"{scope.get('path', '')}{query_string}",
host=host, host=host,
method=self.scope.get("method", ""), method=scope.get("method", ""),
scheme=self.scope.get("scheme", ""), scheme=scope.get("scheme", ""),
status=self.status_code, status=self.status_code,
size=self.content_length / 1000 if self.content_length > 0 else 0, size=self.content_length / 1000 if self.content_length > 0 else 0,
runtime=runtime, runtime=runtime,