root: fix request_id not being logged for actual asgi requests

This commit is contained in:
Jens Langhammer 2021-02-16 19:14:08 +01:00
parent 8bd147b205
commit 61604adf9a
2 changed files with 16 additions and 5 deletions

View file

@ -9,6 +9,7 @@ from django.http import HttpRequest, HttpResponse
SESSION_IMPERSONATE_USER = "authentik_impersonate_user"
SESSION_IMPERSONATE_ORIGINAL_USER = "authentik_impersonate_original_user"
LOCAL = local()
RESPONSE_HEADER_ID = "X-authentik-id"
class ImpersonateMiddleware:
@ -43,7 +44,7 @@ class RequestIDMiddleware:
setattr(request, "request_id", request_id)
LOCAL.authentik = {"request_id": request_id}
response = self.get_response(request)
response["X-authentik-id"] = request.request_id
response[RESPONSE_HEADER_ID] = request.request_id
del LOCAL.authentik["request_id"]
return response

View file

@ -18,6 +18,8 @@ from django.core.asgi import get_asgi_application
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from structlog.stdlib import get_logger
from authentik.core.middleware import RESPONSE_HEADER_ID
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
defuse_stdlib()
@ -67,6 +69,7 @@ class ASGILogger:
status_code: int
start: float
content_length: int
request_id: str
def __init__(self, app: ASGIApp):
self.app = app
@ -75,23 +78,29 @@ class ASGILogger:
self.scope = scope
self.content_length = 0
self.headers = dict(scope.get("headers", []))
self.request_id = ""
async def send_hooked(message: Message) -> None:
"""Hooked send method, which records status code and content-length, and for the final
requests logs it"""
headers = dict(message.get("headers", []))
if "status" in message:
self.status_code = message["status"]
if b"Content-Length" in headers:
self.content_length += int(headers.get(b"Content-Length", b"0"))
if message["type"] == "http.response.start":
response_headers = dict(message["headers"])
self.request_id = response_headers.get(
RESPONSE_HEADER_ID.encode(), b""
).decode()
if message["type"] == "http.response.body" and not message.get(
"more_body", None
"more_body", True
):
runtime = int((time() - self.start) * 1000)
self.log(runtime)
self.log(runtime, request_id=self.request_id)
await send(message)
self.start = time()
@ -111,7 +120,7 @@ class ASGILogger:
# Check if header has multiple values, and use the first one
return client_ip.split(", ")[0]
def log(self, runtime: float):
def log(self, runtime: float, **kwargs):
"""Outpot access logs in a structured format"""
host = self._get_ip()
query_string = ""
@ -125,6 +134,7 @@ class ASGILogger:
status=self.status_code,
size=self.content_length / 1000 if self.content_length > 0 else 0,
runtime=runtime,
**kwargs,
)