From 2c6d82593ea26d4da704b3f7db9f3e6eac076586 Mon Sep 17 00:00:00 2001 From: Jens L Date: Tue, 31 May 2022 21:53:23 +0200 Subject: [PATCH] root: cleanup session keys to use common format (#3003) cleanup session keys to use common format Signed-off-by: Jens Langhammer --- authentik/core/api/users.py | 11 ++-- authentik/core/middleware.py | 10 ++-- authentik/core/views/impersonate.py | 19 ++++--- authentik/events/models.py | 13 +++-- authentik/flows/stage.py | 3 ++ authentik/flows/views/executor.py | 16 +++--- authentik/providers/oauth2/views/authorize.py | 6 +-- .../saml/tests/test_auth_n_request.py | 4 +- authentik/providers/saml/views/flows.py | 5 +- authentik/sources/oauth/clients/oauth2.py | 6 +-- authentik/sources/oauth/types/twitter.py | 6 +-- authentik/sources/plex/plex.py | 2 - authentik/sources/saml/processors/request.py | 4 +- authentik/sources/saml/processors/response.py | 6 +-- authentik/stages/authenticator_duo/stage.py | 8 ++- authentik/stages/authenticator_sms/stage.py | 16 +++--- authentik/stages/authenticator_sms/tests.py | 4 +- .../authenticator_validate/challenge.py | 13 ++--- .../stages/authenticator_validate/stage.py | 50 +++++++++++-------- .../tests/test_stage.py | 4 +- .../stages/authenticator_webauthn/stage.py | 11 ++-- authentik/stages/password/stage.py | 12 ++--- authentik/stages/user_write/stage.py | 4 +- 23 files changed, 132 insertions(+), 101 deletions(-) diff --git a/authentik/core/api/users.py b/authentik/core/api/users.py index 4f42788c1..921f6e0f8 100644 --- a/authentik/core/api/users.py +++ b/authentik/core/api/users.py @@ -43,7 +43,10 @@ from authentik.api.decorators import permission_required from authentik.core.api.groups import GroupSerializer from authentik.core.api.used_by import UsedByMixin from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict -from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER +from authentik.core.middleware import ( + SESSION_KEY_IMPERSONATE_ORIGINAL_USER, + SESSION_KEY_IMPERSONATE_USER, +) from authentik.core.models import ( USER_ATTRIBUTE_SA, USER_ATTRIBUTE_TOKEN_EXPIRING, @@ -336,9 +339,9 @@ class UserViewSet(UsedByMixin, ModelViewSet): serializer = SessionUserSerializer( data={"user": UserSelfSerializer(instance=request.user, context=context).data} ) - if SESSION_IMPERSONATE_USER in request._request.session: + if SESSION_KEY_IMPERSONATE_USER in request._request.session: serializer.initial_data["original"] = UserSelfSerializer( - instance=request._request.session[SESSION_IMPERSONATE_ORIGINAL_USER], + instance=request._request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER], context=context, ).data self.request.session.save() @@ -368,7 +371,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): except (ValidationError, IntegrityError) as exc: LOGGER.debug("Failed to set password", exc=exc) return Response(status=400) - if user.pk == request.user.pk and SESSION_IMPERSONATE_USER not in self.request.session: + if user.pk == request.user.pk and SESSION_KEY_IMPERSONATE_USER not in self.request.session: LOGGER.debug("Updating session hash after password change") update_session_auth_hash(self.request, user) return Response(status=204) diff --git a/authentik/core/middleware.py b/authentik/core/middleware.py index 11834dbed..e4d915f80 100644 --- a/authentik/core/middleware.py +++ b/authentik/core/middleware.py @@ -7,8 +7,8 @@ from uuid import uuid4 from django.http import HttpRequest, HttpResponse from sentry_sdk.api import set_tag -SESSION_IMPERSONATE_USER = "authentik_impersonate_user" -SESSION_IMPERSONATE_ORIGINAL_USER = "authentik_impersonate_original_user" +SESSION_KEY_IMPERSONATE_USER = "authentik/impersonate/user" +SESSION_KEY_IMPERSONATE_ORIGINAL_USER = "authentik/impersonate/original_user" LOCAL = local() RESPONSE_HEADER_ID = "X-authentik-id" KEY_AUTH_VIA = "auth_via" @@ -25,10 +25,10 @@ class ImpersonateMiddleware: def __call__(self, request: HttpRequest) -> HttpResponse: # No permission checks are done here, they need to be checked before - # SESSION_IMPERSONATE_USER is set. + # SESSION_KEY_IMPERSONATE_USER is set. - if SESSION_IMPERSONATE_USER in request.session: - request.user = request.session[SESSION_IMPERSONATE_USER] + if SESSION_KEY_IMPERSONATE_USER in request.session: + request.user = request.session[SESSION_KEY_IMPERSONATE_USER] # Ensure that the user is active, otherwise nothing will work request.user.is_active = True diff --git a/authentik/core/views/impersonate.py b/authentik/core/views/impersonate.py index 6a0dcaff3..c19a47f62 100644 --- a/authentik/core/views/impersonate.py +++ b/authentik/core/views/impersonate.py @@ -5,7 +5,10 @@ from django.shortcuts import get_object_or_404, redirect from django.views import View from structlog.stdlib import get_logger -from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER +from authentik.core.middleware import ( + SESSION_KEY_IMPERSONATE_ORIGINAL_USER, + SESSION_KEY_IMPERSONATE_USER, +) from authentik.core.models import User from authentik.events.models import Event, EventAction from authentik.lib.config import CONFIG @@ -27,8 +30,8 @@ class ImpersonateInitView(View): user_to_be = get_object_or_404(User, pk=user_id) - request.session[SESSION_IMPERSONATE_ORIGINAL_USER] = request.user - request.session[SESSION_IMPERSONATE_USER] = user_to_be + request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER] = request.user + request.session[SESSION_KEY_IMPERSONATE_USER] = user_to_be Event.new(EventAction.IMPERSONATION_STARTED).from_http(request, user_to_be) @@ -41,16 +44,16 @@ class ImpersonateEndView(View): def get(self, request: HttpRequest) -> HttpResponse: """End Impersonation handler""" if ( - SESSION_IMPERSONATE_USER not in request.session - or SESSION_IMPERSONATE_ORIGINAL_USER not in request.session + SESSION_KEY_IMPERSONATE_USER not in request.session + or SESSION_KEY_IMPERSONATE_ORIGINAL_USER not in request.session ): LOGGER.debug("Can't end impersonation", user=request.user) return redirect("authentik_core:if-user") - original_user = request.session[SESSION_IMPERSONATE_ORIGINAL_USER] + original_user = request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER] - del request.session[SESSION_IMPERSONATE_USER] - del request.session[SESSION_IMPERSONATE_ORIGINAL_USER] + del request.session[SESSION_KEY_IMPERSONATE_USER] + del request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER] Event.new(EventAction.IMPERSONATION_ENDED).from_http(request, original_user) diff --git a/authentik/events/models.py b/authentik/events/models.py index 2c77ee129..d1bdc178f 100644 --- a/authentik/events/models.py +++ b/authentik/events/models.py @@ -23,7 +23,10 @@ from requests import RequestException from structlog.stdlib import get_logger from authentik import __version__ -from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER +from authentik.core.middleware import ( + SESSION_KEY_IMPERSONATE_ORIGINAL_USER, + SESSION_KEY_IMPERSONATE_USER, +) from authentik.core.models import ExpiringModel, Group, PropertyMapping, User from authentik.events.geo import GEOIP_READER from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict @@ -233,15 +236,15 @@ class Event(ExpiringModel): if hasattr(request, "user"): original_user = None if hasattr(request, "session"): - original_user = request.session.get(SESSION_IMPERSONATE_ORIGINAL_USER, None) + original_user = request.session.get(SESSION_KEY_IMPERSONATE_ORIGINAL_USER, None) self.user = get_user(request.user, original_user) if user: self.user = get_user(user) # Check if we're currently impersonating, and add that user if hasattr(request, "session"): - if SESSION_IMPERSONATE_ORIGINAL_USER in request.session: - self.user = get_user(request.session[SESSION_IMPERSONATE_ORIGINAL_USER]) - self.user["on_behalf_of"] = get_user(request.session[SESSION_IMPERSONATE_USER]) + if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session: + self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]) + self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER]) # User 255.255.255.255 as fallback if IP cannot be determined self.client_ip = get_client_ip(request) # Apply GeoIP Data, when enabled diff --git a/authentik/flows/stage.py b/authentik/flows/stage.py index 8e33e3503..970c4792c 100644 --- a/authentik/flows/stage.py +++ b/authentik/flows/stage.py @@ -60,6 +60,9 @@ class StageView(View): return self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] return self.request.user + def cleanup(self): + """Cleanup session""" + class ChallengeStageView(StageView): """Stage view which response with a challenge""" diff --git a/authentik/flows/views/executor.py b/authentik/flows/views/executor.py index f5881b5fc..ea08b2a39 100644 --- a/authentik/flows/views/executor.py +++ b/authentik/flows/views/executor.py @@ -49,7 +49,7 @@ from authentik.flows.planner import ( FlowPlan, FlowPlanner, ) -from authentik.flows.stage import AccessDeniedChallengeView +from authentik.flows.stage import AccessDeniedChallengeView, StageView from authentik.lib.sentry import SentryIgnoredException from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.reflection import all_subclasses, class_to_path @@ -59,11 +59,11 @@ from authentik.tenants.models import Tenant LOGGER = get_logger() # Argument used to redirect user after login NEXT_ARG_NAME = "next" -SESSION_KEY_PLAN = "authentik_flows_plan" -SESSION_KEY_APPLICATION_PRE = "authentik_flows_application_pre" -SESSION_KEY_GET = "authentik_flows_get" -SESSION_KEY_POST = "authentik_flows_post" -SESSION_KEY_HISTORY = "authentik_flows_history" +SESSION_KEY_PLAN = "authentik/flows/plan" +SESSION_KEY_APPLICATION_PRE = "authentik/flows/application_pre" +SESSION_KEY_GET = "authentik/flows/get" +SESSION_KEY_POST = "authentik/flows/post" +SESSION_KEY_HISTORY = "authentik/flows/history" QS_KEY_TOKEN = "flow_token" # nosec @@ -380,6 +380,8 @@ class FlowExecutorView(APIView): "f(exec): Stage ok", stage_class=class_to_path(self.current_stage_view.__class__), ) + if isinstance(self.current_stage_view, StageView): + self.current_stage_view.cleanup() self.request.session.get(SESSION_KEY_HISTORY, []).append(deepcopy(self.plan)) self.plan.pop() self.request.session[SESSION_KEY_PLAN] = self.plan @@ -416,6 +418,8 @@ class FlowExecutorView(APIView): SESSION_KEY_APPLICATION_PRE, SESSION_KEY_PLAN, SESSION_KEY_GET, + # We might need the initial POST payloads for later requests + # SESSION_KEY_POST, # We don't delete the history on purpose, as a user might # still be inspecting it. # It's only deleted on a fresh executions diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index fa527359a..634b1abc2 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -69,7 +69,7 @@ from authentik.stages.user_login.stage import USER_LOGIN_AUTHENTICATED LOGGER = get_logger() PLAN_CONTEXT_PARAMS = "params" -SESSION_NEEDS_LOGIN = "authentik_oauth2_needs_login" +SESSION_KEY_NEEDS_LOGIN = "authentik/providers/oauth2/needs_login" ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSENT, PROMPT_LOGIN} @@ -326,13 +326,13 @@ class AuthorizationFlowInitView(PolicyAccessView): # If prompt=login, we need to re-authenticate the user regardless if ( PROMPT_LOGIN in self.params.prompt - and SESSION_NEEDS_LOGIN not in self.request.session + and SESSION_KEY_NEEDS_LOGIN not in self.request.session # To prevent the user from having to double login when prompt is set to login # and the user has just signed it. This session variable is set in the UserLoginStage # and is (quite hackily) removed from the session in applications's API's List method and USER_LOGIN_AUTHENTICATED not in self.request.session ): - self.request.session[SESSION_NEEDS_LOGIN] = True + self.request.session[SESSION_KEY_NEEDS_LOGIN] = True return self.handle_no_permission() # Regardless, we start the planner and return to it planner = FlowPlanner(self.provider.authorization_flow) diff --git a/authentik/providers/saml/tests/test_auth_n_request.py b/authentik/providers/saml/tests/test_auth_n_request.py index 92d457b61..0d8250707 100644 --- a/authentik/providers/saml/tests/test_auth_n_request.py +++ b/authentik/providers/saml/tests/test_auth_n_request.py @@ -19,7 +19,7 @@ from authentik.sources.saml.processors.constants import ( SAML_NAME_ID_FORMAT_EMAIL, SAML_NAME_ID_FORMAT_UNSPECIFIED, ) -from authentik.sources.saml.processors.request import SESSION_REQUEST_ID, RequestProcessor +from authentik.sources.saml.processors.request import SESSION_KEY_REQUEST_ID, RequestProcessor from authentik.sources.saml.processors.response import ResponseProcessor POST_REQUEST = ( @@ -142,7 +142,7 @@ class TestAuthNRequest(TestCase): request = request_proc.build_auth_n() # change the request ID - http_request.session[SESSION_REQUEST_ID] = "test" + http_request.session[SESSION_KEY_REQUEST_ID] = "test" http_request.session.save() # To get an assertion we need a parsed request (parsed by provider) diff --git a/authentik/providers/saml/views/flows.py b/authentik/providers/saml/views/flows.py index bf489fcd4..b07261318 100644 --- a/authentik/providers/saml/views/flows.py +++ b/authentik/providers/saml/views/flows.py @@ -34,7 +34,7 @@ REQUEST_KEY_SAML_SIG_ALG = "SigAlg" REQUEST_KEY_SAML_RESPONSE = "SAMLResponse" REQUEST_KEY_RELAY_STATE = "RelayState" -SESSION_KEY_AUTH_N_REQUEST = "authn_request" +SESSION_KEY_AUTH_N_REQUEST = "authentik/providers/saml/authn_request" # This View doesn't have a URL on purpose, as its called by the FlowExecutor class SAMLFlowFinalView(ChallengeStageView): @@ -106,3 +106,6 @@ class SAMLFlowFinalView(ChallengeStageView): def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: # We'll never get here since the challenge redirects to the SP return HttpResponseBadRequest() + + def cleanup(self): + self.request.session.pop(SESSION_KEY_AUTH_N_REQUEST, None) diff --git a/authentik/sources/oauth/clients/oauth2.py b/authentik/sources/oauth/clients/oauth2.py index 2604bab88..d0af2e876 100644 --- a/authentik/sources/oauth/clients/oauth2.py +++ b/authentik/sources/oauth/clients/oauth2.py @@ -11,7 +11,7 @@ from structlog.stdlib import get_logger from authentik.sources.oauth.clients.base import BaseOAuthClient LOGGER = get_logger() -SESSION_OAUTH_PKCE = "oauth_pkce" +SESSION_KEY_OAUTH_PKCE = "authentik/sources/oauth/pkce" class OAuth2Client(BaseOAuthClient): @@ -70,8 +70,8 @@ class OAuth2Client(BaseOAuthClient): "code": code, "grant_type": "authorization_code", } - if SESSION_OAUTH_PKCE in self.request.session: - args["code_verifier"] = self.request.session[SESSION_OAUTH_PKCE] + if SESSION_KEY_OAUTH_PKCE in self.request.session: + args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE] try: access_token_url = self.source.type.access_token_url or "" if self.source.type.urls_customizable and self.source.access_token_url: diff --git a/authentik/sources/oauth/types/twitter.py b/authentik/sources/oauth/types/twitter.py index ae2271bdf..7097690ef 100644 --- a/authentik/sources/oauth/types/twitter.py +++ b/authentik/sources/oauth/types/twitter.py @@ -2,7 +2,7 @@ from typing import Any from authentik.lib.generators import generate_id -from authentik.sources.oauth.clients.oauth2 import SESSION_OAUTH_PKCE +from authentik.sources.oauth.clients.oauth2 import SESSION_KEY_OAUTH_PKCE from authentik.sources.oauth.types.azure_ad import AzureADClient from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback @@ -13,10 +13,10 @@ class TwitterOAuthRedirect(OAuthRedirect): """Twitter OAuth2 Redirect""" def get_additional_parameters(self, source): # pragma: no cover - self.request.session[SESSION_OAUTH_PKCE] = generate_id() + self.request.session[SESSION_KEY_OAUTH_PKCE] = generate_id() return { "scope": ["users.read", "tweet.read"], - "code_challenge": self.request.session[SESSION_OAUTH_PKCE], + "code_challenge": self.request.session[SESSION_KEY_OAUTH_PKCE], "code_challenge_method": "plain", } diff --git a/authentik/sources/plex/plex.py b/authentik/sources/plex/plex.py index 29b7c622a..69f50aaf4 100644 --- a/authentik/sources/plex/plex.py +++ b/authentik/sources/plex/plex.py @@ -11,8 +11,6 @@ from authentik.lib.utils.http import get_http_session from authentik.sources.plex.models import PlexSource, PlexSourceConnection LOGGER = get_logger() -SESSION_ID_KEY = "PLEX_ID" -SESSION_CODE_KEY = "PLEX_CODE" class PlexAuth: diff --git a/authentik/sources/saml/processors/request.py b/authentik/sources/saml/processors/request.py index f7fe9cc10..c7a81029c 100644 --- a/authentik/sources/saml/processors/request.py +++ b/authentik/sources/saml/processors/request.py @@ -19,7 +19,7 @@ from authentik.sources.saml.processors.constants import ( SIGN_ALGORITHM_TRANSFORM_MAP, ) -SESSION_REQUEST_ID = "authentik_source_saml_request_id" +SESSION_KEY_REQUEST_ID = "authentik/sources/saml/request_id" class RequestProcessor: @@ -38,7 +38,7 @@ class RequestProcessor: self.http_request = request self.relay_state = relay_state self.request_id = get_random_id() - self.http_request.session[SESSION_REQUEST_ID] = self.request_id + self.http_request.session[SESSION_KEY_REQUEST_ID] = self.request_id self.issue_instant = get_time_string() def get_issuer(self) -> Element: diff --git a/authentik/sources/saml/processors/response.py b/authentik/sources/saml/processors/response.py index 8859bbc61..08aa6c655 100644 --- a/authentik/sources/saml/processors/response.py +++ b/authentik/sources/saml/processors/response.py @@ -45,7 +45,7 @@ from authentik.sources.saml.processors.constants import ( SAML_NAME_ID_FORMAT_WINDOWS, SAML_NAME_ID_FORMAT_X509, ) -from authentik.sources.saml.processors.request import SESSION_REQUEST_ID +from authentik.sources.saml.processors.request import SESSION_KEY_REQUEST_ID from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT from authentik.stages.user_login.stage import BACKEND_INBUILT @@ -119,11 +119,11 @@ class ResponseProcessor: seen_ids.append(self._root.attrib["ID"]) cache.set(CACHE_SEEN_REQUEST_ID % self._source.pk, seen_ids) return - if SESSION_REQUEST_ID not in request.session or "InResponseTo" not in self._root.attrib: + if SESSION_KEY_REQUEST_ID not in request.session or "InResponseTo" not in self._root.attrib: raise MismatchedRequestID( "Missing InResponseTo and IdP-initiated Logins are not allowed" ) - if request.session[SESSION_REQUEST_ID] != self._root.attrib["InResponseTo"]: + if request.session[SESSION_KEY_REQUEST_ID] != self._root.attrib["InResponseTo"]: raise MismatchedRequestID("Mismatched request ID") def _handle_name_id_transient(self, request: HttpRequest) -> HttpResponse: diff --git a/authentik/stages/authenticator_duo/stage.py b/authentik/stages/authenticator_duo/stage.py index a0f890e03..2c6d863da 100644 --- a/authentik/stages/authenticator_duo/stage.py +++ b/authentik/stages/authenticator_duo/stage.py @@ -18,8 +18,8 @@ from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, Duo LOGGER = get_logger() -SESSION_KEY_DUO_USER_ID = "authentik_stages_authenticator_duo_user_id" -SESSION_KEY_DUO_ACTIVATION_CODE = "authentik_stages_authenticator_duo_activation_code" +SESSION_KEY_DUO_USER_ID = "authentik/stages/authenticator_duo/user_id" +SESSION_KEY_DUO_ACTIVATION_CODE = "authentik/stages/authenticator_duo/activation_code" class AuthenticatorDuoChallenge(WithUserInfoChallenge): @@ -95,3 +95,7 @@ class AuthenticatorDuoStageView(ChallengeStageView): else: return self.executor.stage_invalid("Device with Credential ID already exists.") return self.executor.stage_ok() + + def cleanup(self): + self.request.session.pop(SESSION_KEY_DUO_USER_ID) + self.request.session.pop(SESSION_KEY_DUO_ACTIVATION_CODE) diff --git a/authentik/stages/authenticator_sms/stage.py b/authentik/stages/authenticator_sms/stage.py index 8e358066f..581522ce6 100644 --- a/authentik/stages/authenticator_sms/stage.py +++ b/authentik/stages/authenticator_sms/stage.py @@ -20,7 +20,7 @@ from authentik.stages.authenticator_sms.models import AuthenticatorSMSStage, SMS from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT LOGGER = get_logger() -SESSION_SMS_DEVICE = "sms_device" +SESSION_KEY_SMS_DEVICE = "authentik/stages/authenticator_sms/sms_device" class AuthenticatorSMSChallenge(WithUserInfoChallenge): @@ -66,9 +66,9 @@ class AuthenticatorSMSStageView(ChallengeStageView): if "phone" in context.get(PLAN_CONTEXT_PROMPT, {}): LOGGER.debug("got phone number from plan context") return context.get(PLAN_CONTEXT_PROMPT, {}).get("phone") - if SESSION_SMS_DEVICE in self.request.session: + if SESSION_KEY_SMS_DEVICE in self.request.session: LOGGER.debug("got phone number from device in session") - device: SMSDevice = self.request.session[SESSION_SMS_DEVICE] + device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE] if device.phone_number == "": return None return device.phone_number @@ -84,7 +84,7 @@ class AuthenticatorSMSStageView(ChallengeStageView): def get_response_instance(self, data: QueryDict) -> ChallengeResponse: response = super().get_response_instance(data) - response.device = self.request.session[SESSION_SMS_DEVICE] + response.device = self.request.session[SESSION_KEY_SMS_DEVICE] return response def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: @@ -100,19 +100,19 @@ class AuthenticatorSMSStageView(ChallengeStageView): stage: AuthenticatorSMSStage = self.executor.current_stage - if SESSION_SMS_DEVICE not in self.request.session: + if SESSION_KEY_SMS_DEVICE not in self.request.session: device = SMSDevice(user=user, confirmed=False, stage=stage, name="SMS Device") device.generate_token(commit=False) if phone_number := self._has_phone_number(): device.phone_number = phone_number - self.request.session[SESSION_SMS_DEVICE] = device + self.request.session[SESSION_KEY_SMS_DEVICE] = device return super().get(request, *args, **kwargs) def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: """SMS Token is validated by challenge""" - device: SMSDevice = self.request.session[SESSION_SMS_DEVICE] + device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE] if not device.confirmed: return self.challenge_invalid(response) device.save() - del self.request.session[SESSION_SMS_DEVICE] + del self.request.session[SESSION_KEY_SMS_DEVICE] return self.executor.stage_ok() diff --git a/authentik/stages/authenticator_sms/tests.py b/authentik/stages/authenticator_sms/tests.py index afa6c829a..80ff2aaae 100644 --- a/authentik/stages/authenticator_sms/tests.py +++ b/authentik/stages/authenticator_sms/tests.py @@ -8,7 +8,7 @@ from authentik.core.models import User from authentik.flows.challenge import ChallengeTypes from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding from authentik.stages.authenticator_sms.models import AuthenticatorSMSStage, SMSProviders -from authentik.stages.authenticator_sms.stage import SESSION_SMS_DEVICE +from authentik.stages.authenticator_sms.stage import SESSION_KEY_SMS_DEVICE class AuthenticatorSMSStageTests(APITestCase): @@ -85,7 +85,7 @@ class AuthenticatorSMSStageTests(APITestCase): data={ "component": "ak-stage-authenticator-sms", "phone_number": "foo", - "code": int(self.client.session[SESSION_SMS_DEVICE].token), + "code": int(self.client.session[SESSION_KEY_SMS_DEVICE].token), }, ) self.assertEqual(response.status_code, 200) diff --git a/authentik/stages/authenticator_validate/challenge.py b/authentik/stages/authenticator_validate/challenge.py index cf25f326a..fd6b8b41a 100644 --- a/authentik/stages/authenticator_validate/challenge.py +++ b/authentik/stages/authenticator_validate/challenge.py @@ -22,6 +22,7 @@ from authentik.lib.utils.http import get_client_ip from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice from authentik.stages.authenticator_sms.models import SMSDevice from authentik.stages.authenticator_webauthn.models import WebAuthnDevice +from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id LOGGER = get_logger() @@ -43,23 +44,23 @@ def get_challenge_for_device(request: HttpRequest, device: Device) -> dict: return {} -def get_webauthn_challenge_userless(request: HttpRequest) -> dict: +def get_webauthn_challenge_without_user(request: HttpRequest) -> dict: """Same as `get_webauthn_challenge`, but allows any client device. We can then later check who the device belongs to.""" - request.session.pop("challenge", None) + request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) authentication_options = generate_authentication_options( rp_id=get_rp_id(request), allow_credentials=[], ) - request.session["challenge"] = authentication_options.challenge + request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge return loads(options_to_json(authentication_options)) def get_webauthn_challenge(request: HttpRequest, device: Optional[WebAuthnDevice] = None) -> dict: """Send the client a challenge that we'll check later""" - request.session.pop("challenge", None) + request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) allowed_credentials = [] @@ -74,7 +75,7 @@ def get_webauthn_challenge(request: HttpRequest, device: Optional[WebAuthnDevice allow_credentials=allowed_credentials, ) - request.session["challenge"] = authentication_options.challenge + request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge return loads(options_to_json(authentication_options)) @@ -103,7 +104,7 @@ def validate_challenge_code(code: str, request: HttpRequest, user: User) -> Devi # pylint: disable=unused-argument def validate_challenge_webauthn(data: dict, request: HttpRequest, user: User) -> Device: """Validate WebAuthn Challenge""" - challenge = request.session.get("challenge") + challenge = request.session.get(SESSION_KEY_WEBAUTHN_CHALLENGE) credential_id = data.get("id") device = WebAuthnDevice.objects.filter(credential_id=credential_id).first() diff --git a/authentik/stages/authenticator_validate/stage.py b/authentik/stages/authenticator_validate/stage.py index 6dab05b8d..8abcb5d7f 100644 --- a/authentik/stages/authenticator_validate/stage.py +++ b/authentik/stages/authenticator_validate/stage.py @@ -26,7 +26,7 @@ from authentik.stages.authenticator_sms.models import SMSDevice from authentik.stages.authenticator_validate.challenge import ( DeviceChallenge, get_challenge_for_device, - get_webauthn_challenge_userless, + get_webauthn_challenge_without_user, select_challenge, validate_challenge_code, validate_challenge_duo, @@ -38,9 +38,10 @@ from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_ME LOGGER = get_logger() COOKIE_NAME_MFA = "authentik_mfa" -SESSION_STAGES = "goauthentik.io/stages/authenticator_validate/stages" -SESSION_SELECTED_STAGE = "goauthentik.io/stages/authenticator_validate/selected_stage" -SESSION_DEVICE_CHALLENGES = "goauthentik.io/stages/authenticator_validate/device_challenges" + +SESSION_KEY_STAGES = "authentik/stages/authenticator_validate/stages" +SESSION_KEY_SELECTED_STAGE = "authentik/stages/authenticator_validate/selected_stage" +SESSION_KEY_DEVICE_CHALLENGES = "authentik/stages/authenticator_validate/device_challenges" class SelectableStageSerializer(PassiveSerializer): @@ -75,7 +76,7 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse): def _challenge_allowed(self, classes: list): device_challenges: list[dict] = self.stage.request.session.get( - SESSION_DEVICE_CHALLENGES, [] + SESSION_KEY_DEVICE_CHALLENGES, [] ) if not any(x["device_class"] in classes for x in device_challenges): raise ValidationError("No compatible device class allowed") @@ -107,7 +108,7 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse): """Check which challenge the user has selected. Actual logic only used for SMS stage.""" # First check if the challenge is valid allowed = False - for device_challenge in self.stage.request.session.get(SESSION_DEVICE_CHALLENGES, []): + for device_challenge in self.stage.request.session.get(SESSION_KEY_DEVICE_CHALLENGES, []): if device_challenge.get("device_class", "") == challenge.get( "device_class", "" ) and device_challenge.get("device_uid", "") == challenge.get("device_uid", ""): @@ -125,11 +126,11 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse): def validate_selected_stage(self, stage_pk: str) -> str: """Check that the selected stage is valid""" - stages = self.stage.request.session.get(SESSION_STAGES, []) + stages = self.stage.request.session.get(SESSION_KEY_STAGES, []) if not any(str(stage.pk) == stage_pk for stage in stages): raise ValidationError("Selected stage is invalid") LOGGER.debug("Setting selected stage to ", stage=stage_pk) - self.stage.request.session[SESSION_SELECTED_STAGE] = stage_pk + self.stage.request.session[SESSION_KEY_SELECTED_STAGE] = stage_pk return stage_pk def validate(self, attrs: dict): @@ -153,7 +154,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): LOGGER.debug("Got devices for user", devices=user_devices) # static and totp are only shown once - # since their challenges are device-independant + # since their challenges are device-independent seen_classes = [] stage: AuthenticatorValidateStage = self.executor.current_stage @@ -168,7 +169,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): continue allowed_devices.append(device) # Ensure only one challenge per device class - # WebAuthn does another device loop to find all webuahtn devices + # WebAuthn does another device loop to find all WebAuthn devices if device_class in seen_classes: continue if device_class not in seen_classes: @@ -188,13 +189,13 @@ class AuthenticatorValidateStageView(ChallengeStageView): self.check_mfa_cookie(allowed_devices) return challenges - def get_userless_webauthn_challenge(self) -> list[dict]: + def get_webauthn_challenge_without_user(self) -> list[dict]: """Get a WebAuthn challenge when no pending user is set.""" challenge = DeviceChallenge( data={ "device_class": DeviceClasses.WEBAUTHN, "device_uid": -1, - "challenge": get_webauthn_challenge_userless(self.request), + "challenge": get_webauthn_challenge_without_user(self.request), } ) challenge.is_valid() @@ -217,12 +218,12 @@ class AuthenticatorValidateStageView(ChallengeStageView): return self.executor.stage_ok() # Passwordless auth, with just webauthn if DeviceClasses.WEBAUTHN in stage.device_classes: - LOGGER.debug("Userless flow, getting generic webauthn challenge") - challenges = self.get_userless_webauthn_challenge() + LOGGER.debug("Flow without user, getting generic webauthn challenge") + challenges = self.get_webauthn_challenge_without_user() else: LOGGER.debug("No pending user, continuing") return self.executor.stage_ok() - self.request.session[SESSION_DEVICE_CHALLENGES] = challenges + self.request.session[SESSION_KEY_DEVICE_CHALLENGES] = challenges # No allowed devices if len(challenges) < 1: @@ -255,23 +256,23 @@ class AuthenticatorValidateStageView(ChallengeStageView): if stage.configuration_stages.count() == 1: next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk) LOGGER.debug("Single stage configured, auto-selecting", stage=next_stage) - self.request.session[SESSION_SELECTED_STAGE] = next_stage - # Because that normal insetion only happens on post, we directly inject it here and + self.request.session[SESSION_KEY_SELECTED_STAGE] = next_stage + # Because that normal execution only happens on post, we directly inject it here and # return it self.executor.plan.insert_stage(next_stage) return self.executor.stage_ok() stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses() - self.request.session[SESSION_STAGES] = stages + self.request.session[SESSION_KEY_STAGES] = stages return super().get(self.request, *args, **kwargs) def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: res = super().post(request, *args, **kwargs) if ( - SESSION_SELECTED_STAGE in self.request.session + SESSION_KEY_SELECTED_STAGE in self.request.session and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE ): LOGGER.debug("Got selected stage in session, running that") - stage_pk = self.request.session.get(SESSION_SELECTED_STAGE) + stage_pk = self.request.session.get(SESSION_KEY_SELECTED_STAGE) # Because the foreign key to stage.configuration_stage points to # a base stage class, we need to do another lookup stage = Stage.objects.get_subclass(pk=stage_pk) @@ -282,8 +283,8 @@ class AuthenticatorValidateStageView(ChallengeStageView): return res def get_challenge(self) -> AuthenticatorValidationChallenge: - challenges = self.request.session.get(SESSION_DEVICE_CHALLENGES, []) - stages = self.request.session.get(SESSION_STAGES, []) + challenges = self.request.session.get(SESSION_KEY_DEVICE_CHALLENGES, []) + stages = self.request.session.get(SESSION_KEY_STAGES, []) stage_challenges = [] for stage in stages: serializer = SelectableStageSerializer( @@ -385,3 +386,8 @@ class AuthenticatorValidateStageView(ChallengeStageView): ) ) return self.set_valid_mfa_cookie(response.device) + + def cleanup(self): + self.request.session.pop(SESSION_KEY_STAGES, None) + self.request.session.pop(SESSION_KEY_SELECTED_STAGE, None) + self.request.session.pop(SESSION_KEY_DEVICE_CHALLENGES, None) diff --git a/authentik/stages/authenticator_validate/tests/test_stage.py b/authentik/stages/authenticator_validate/tests/test_stage.py index 88b9068b9..0001acdb4 100644 --- a/authentik/stages/authenticator_validate/tests/test_stage.py +++ b/authentik/stages/authenticator_validate/tests/test_stage.py @@ -13,7 +13,7 @@ from authentik.lib.tests.utils import dummy_get_response from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageSerializer from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage from authentik.stages.authenticator_validate.stage import ( - SESSION_DEVICE_CHALLENGES, + SESSION_KEY_DEVICE_CHALLENGES, AuthenticatorValidationChallengeResponse, ) from authentik.stages.identification.models import IdentificationStage, UserFields @@ -83,7 +83,7 @@ class AuthenticatorValidateStageTests(FlowTestCase): middleware = SessionMiddleware(dummy_get_response) middleware.process_request(request) - request.session[SESSION_DEVICE_CHALLENGES] = [ + request.session[SESSION_KEY_DEVICE_CHALLENGES] = [ { "device_class": "static", "device_uid": "1", diff --git a/authentik/stages/authenticator_webauthn/stage.py b/authentik/stages/authenticator_webauthn/stage.py index beae565c9..59ed3f32a 100644 --- a/authentik/stages/authenticator_webauthn/stage.py +++ b/authentik/stages/authenticator_webauthn/stage.py @@ -30,7 +30,7 @@ from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id LOGGER = get_logger() -SESSION_KEY_WEBAUTHN_AUTHENTICATED = "authentik_stages_authenticator_webauthn_authenticated" +SESSION_KEY_WEBAUTHN_CHALLENGE = "authentik/stages/authenticator_webauthn/challenge" class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge): @@ -51,7 +51,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse): def validate_response(self, response: dict) -> dict: """Validate webauthn challenge response""" - challenge = self.request.session["challenge"] + challenge = self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] try: registration: VerifiedRegistration = verify_registration_response( @@ -80,7 +80,7 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): def get_challenge(self, *args, **kwargs) -> Challenge: # clear session variables prior to starting a new registration - self.request.session.pop("challenge", None) + self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) stage: AuthenticateWebAuthnStage = self.executor.current_stage user = self.get_pending_user() @@ -103,7 +103,7 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): ), ) - self.request.session["challenge"] = registration_options.challenge + self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = registration_options.challenge self.request.session.save() return AuthenticatorWebAuthnChallenge( data={ @@ -143,3 +143,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): else: return self.executor.stage_invalid("Device with Credential ID already exists.") return self.executor.stage_ok() + + def cleanup(self): + self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE) diff --git a/authentik/stages/password/stage.py b/authentik/stages/password/stage.py index c353102c5..508595a70 100644 --- a/authentik/stages/password/stage.py +++ b/authentik/stages/password/stage.py @@ -30,7 +30,7 @@ LOGGER = get_logger() PLAN_CONTEXT_AUTHENTICATION_BACKEND = "user_backend" PLAN_CONTEXT_METHOD = "auth_method" PLAN_CONTEXT_METHOD_ARGS = "auth_method_args" -SESSION_INVALID_TRIES = "user_invalid_tries" +SESSION_KEY_INVALID_TRIES = "authentik/stages/password/user_invalid_tries" def authenticate(request: HttpRequest, backends: list[str], **credentials: Any) -> Optional[User]: @@ -100,16 +100,16 @@ class PasswordStageView(ChallengeStageView): return challenge def challenge_invalid(self, response: PasswordChallengeResponse) -> HttpResponse: - if SESSION_INVALID_TRIES not in self.request.session: - self.request.session[SESSION_INVALID_TRIES] = 0 - self.request.session[SESSION_INVALID_TRIES] += 1 + if SESSION_KEY_INVALID_TRIES not in self.request.session: + self.request.session[SESSION_KEY_INVALID_TRIES] = 0 + self.request.session[SESSION_KEY_INVALID_TRIES] += 1 current_stage: PasswordStage = self.executor.current_stage if ( - self.request.session[SESSION_INVALID_TRIES] + self.request.session[SESSION_KEY_INVALID_TRIES] > current_stage.failed_attempts_before_cancel ): LOGGER.debug("User has exceeded maximum tries") - del self.request.session[SESSION_INVALID_TRIES] + del self.request.session[SESSION_KEY_INVALID_TRIES] return self.executor.stage_invalid() return super().challenge_invalid(response) diff --git a/authentik/stages/user_write/stage.py b/authentik/stages/user_write/stage.py index d083beb41..63f22b133 100644 --- a/authentik/stages/user_write/stage.py +++ b/authentik/stages/user_write/stage.py @@ -9,7 +9,7 @@ from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ from structlog.stdlib import get_logger -from authentik.core.middleware import SESSION_IMPERSONATE_USER +from authentik.core.middleware import SESSION_KEY_IMPERSONATE_USER from authentik.core.models import USER_ATTRIBUTE_SOURCES, User, UserSourceConnection from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER @@ -117,7 +117,7 @@ class UserWriteStageView(StageView): if ( any("password" in x for x in data.keys()) and self.request.user.pk == user.pk - and SESSION_IMPERSONATE_USER not in self.request.session + and SESSION_KEY_IMPERSONATE_USER not in self.request.session ): should_update_session = True self.update_user(user)