root: fix request_id not being logged for actual asgi requests

This commit is contained in:
Jens Langhammer 2021-02-16 19:14:08 +01:00
parent 8bd147b205
commit 61604adf9a
2 changed files with 16 additions and 5 deletions

View file

@ -9,6 +9,7 @@ from django.http import HttpRequest, HttpResponse
SESSION_IMPERSONATE_USER = "authentik_impersonate_user" SESSION_IMPERSONATE_USER = "authentik_impersonate_user"
SESSION_IMPERSONATE_ORIGINAL_USER = "authentik_impersonate_original_user" SESSION_IMPERSONATE_ORIGINAL_USER = "authentik_impersonate_original_user"
LOCAL = local() LOCAL = local()
RESPONSE_HEADER_ID = "X-authentik-id"
class ImpersonateMiddleware: class ImpersonateMiddleware:
@ -43,7 +44,7 @@ class RequestIDMiddleware:
setattr(request, "request_id", request_id) setattr(request, "request_id", request_id)
LOCAL.authentik = {"request_id": request_id} LOCAL.authentik = {"request_id": request_id}
response = self.get_response(request) response = self.get_response(request)
response["X-authentik-id"] = request.request_id response[RESPONSE_HEADER_ID] = request.request_id
del LOCAL.authentik["request_id"] del LOCAL.authentik["request_id"]
return response return response

View file

@ -18,6 +18,8 @@ from django.core.asgi import get_asgi_application
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.middleware import RESPONSE_HEADER_ID
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py # DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
defuse_stdlib() defuse_stdlib()
@ -67,6 +69,7 @@ class ASGILogger:
status_code: int status_code: int
start: float start: float
content_length: int content_length: int
request_id: str
def __init__(self, app: ASGIApp): def __init__(self, app: ASGIApp):
self.app = app self.app = app
@ -75,23 +78,29 @@ class ASGILogger:
self.scope = scope self.scope = scope
self.content_length = 0 self.content_length = 0
self.headers = dict(scope.get("headers", [])) self.headers = dict(scope.get("headers", []))
self.request_id = ""
async def send_hooked(message: Message) -> None: async def send_hooked(message: Message) -> None:
"""Hooked send method, which records status code and content-length, and for the final """Hooked send method, which records status code and content-length, and for the final
requests logs it""" requests logs it"""
headers = dict(message.get("headers", [])) headers = dict(message.get("headers", []))
if "status" in message: if "status" in message:
self.status_code = message["status"] self.status_code = message["status"]
if b"Content-Length" in headers: if b"Content-Length" in headers:
self.content_length += int(headers.get(b"Content-Length", b"0")) self.content_length += int(headers.get(b"Content-Length", b"0"))
if message["type"] == "http.response.start":
response_headers = dict(message["headers"])
self.request_id = response_headers.get(
RESPONSE_HEADER_ID.encode(), b""
).decode()
if message["type"] == "http.response.body" and not message.get( if message["type"] == "http.response.body" and not message.get(
"more_body", None "more_body", True
): ):
runtime = int((time() - self.start) * 1000) runtime = int((time() - self.start) * 1000)
self.log(runtime) self.log(runtime, request_id=self.request_id)
await send(message) await send(message)
self.start = time() self.start = time()
@ -111,7 +120,7 @@ class ASGILogger:
# Check if header has multiple values, and use the first one # Check if header has multiple values, and use the first one
return client_ip.split(", ")[0] return client_ip.split(", ")[0]
def log(self, runtime: float): def log(self, runtime: float, **kwargs):
"""Outpot access logs in a structured format""" """Outpot access logs in a structured format"""
host = self._get_ip() host = self._get_ip()
query_string = "" query_string = ""
@ -125,6 +134,7 @@ class ASGILogger:
status=self.status_code, status=self.status_code,
size=self.content_length / 1000 if self.content_length > 0 else 0, size=self.content_length / 1000 if self.content_length > 0 else 0,
runtime=runtime, runtime=runtime,
**kwargs,
) )