diff --git a/authentik/api/authentication.py b/authentik/api/authentication.py index a069f4681..5f6b7b1f7 100644 --- a/authentik/api/authentication.py +++ b/authentik/api/authentication.py @@ -9,6 +9,7 @@ from rest_framework.exceptions import AuthenticationFailed from rest_framework.request import Request from structlog.stdlib import get_logger +from authentik.core.middleware import KEY_AUTH_VIA, LOCAL from authentik.core.models import Token, TokenIntents, User from authentik.outposts.models import Outpost @@ -44,6 +45,7 @@ def bearer_auth(raw_header: bytes) -> Optional[User]: if not user: raise AuthenticationFailed("Token invalid/expired") return user + LOCAL.authentik[KEY_AUTH_VIA] = "api_token" return tokens.first().user @@ -57,7 +59,7 @@ def token_secret_key(value: str) -> Optional[User]: outposts = Outpost.objects.filter(managed=MANAGED_OUTPOST) if not outposts: return None - LOGGER.info("Authenticating via secret_key") + LOCAL.authentik[KEY_AUTH_VIA] = "secret_key" outpost = outposts.first() return outpost.user diff --git a/authentik/core/middleware.py b/authentik/core/middleware.py index 58a821197..d9d1be46e 100644 --- a/authentik/core/middleware.py +++ b/authentik/core/middleware.py @@ -10,6 +10,9 @@ SESSION_IMPERSONATE_USER = "authentik_impersonate_user" SESSION_IMPERSONATE_ORIGINAL_USER = "authentik_impersonate_original_user" LOCAL = local() RESPONSE_HEADER_ID = "X-authentik-id" +KEY_AUTH_VIA = "auth_via" +KEY_USER = "user" +INTERNAL_HEADER_PREFIX = "X-authentik-internal-" class ImpersonateMiddleware: @@ -50,15 +53,17 @@ class RequestIDMiddleware: } response = self.get_response(request) response[RESPONSE_HEADER_ID] = request.request_id - del LOCAL.authentik["request_id"] - del LOCAL.authentik["host"] + if auth_via := LOCAL.authentik.get(KEY_AUTH_VIA, None): + response[INTERNAL_HEADER_PREFIX + KEY_AUTH_VIA] = auth_via + response[INTERNAL_HEADER_PREFIX + KEY_USER] = request.user.username + for key in list(LOCAL.authentik.keys()): + del LOCAL.authentik[key] return response # pylint: disable=unused-argument -def structlog_add_request_id(logger: Logger, method_name: str, event_dict): +def structlog_add_request_id(logger: Logger, method_name: str, event_dict: dict): """If threadlocal has authentik defined, add request_id to log""" if hasattr(LOCAL, "authentik"): - event_dict["request_id"] = LOCAL.authentik.get("request_id", "") - event_dict["host"] = LOCAL.authentik.get("host", "") + event_dict.update(LOCAL.authentik) return event_dict diff --git a/authentik/root/asgi/logger.py b/authentik/root/asgi/logger.py index 1f1568e48..6b5cade99 100644 --- a/authentik/root/asgi/logger.py +++ b/authentik/root/asgi/logger.py @@ -3,7 +3,7 @@ from time import time from structlog.stdlib import get_logger -from authentik.core.middleware import RESPONSE_HEADER_ID +from authentik.core.middleware import INTERNAL_HEADER_PREFIX, RESPONSE_HEADER_ID from authentik.root.asgi.types import ASGIApp, Message, Receive, Scope, Send ASGI_IP_HEADERS = ( @@ -26,6 +26,8 @@ class ASGILogger: content_length = 0 status_code = 0 request_id = "" + # Copy all headers starting with X-authentik-internal + copied_headers = {} location = "" start = time() @@ -45,9 +47,19 @@ class ASGILogger: if message["type"] == "http.response.start": response_headers = dict(message["headers"]) nonlocal request_id + nonlocal copied_headers nonlocal location request_id = response_headers.get(RESPONSE_HEADER_ID.encode(), b"").decode() location = response_headers.get(b"Location", b"").decode() + # Copy all internal headers to log, and remove them from the final response + for header in list(response_headers.keys()): + if not header.decode().startswith(INTERNAL_HEADER_PREFIX): + continue + copied_headers[ + header.decode().replace(INTERNAL_HEADER_PREFIX, "") + ] = response_headers[header].decode() + del response_headers[header] + message["headers"] = list(response_headers.items()) if message["type"] == "http.response.body" and not message.get("more_body", True): nonlocal start @@ -55,6 +67,7 @@ class ASGILogger: kwargs = {"request_id": request_id} if location != "": kwargs["location"] = location + kwargs.update(copied_headers) self.log(scope, runtime, content_length, status_code, **kwargs) await send(message)