From dd65862bf290c982aed4da82bd2b8d06bf1a8a3d Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Tue, 25 Oct 2022 22:46:15 +0200 Subject: [PATCH] core: show success message when authenticating/enrolling after flow is finished Signed-off-by: Jens Langhammer --- authentik/core/sources/flow_manager.py | 100 ++++++++++++++++++++----- 1 file changed, 80 insertions(+), 20 deletions(-) diff --git a/authentik/core/sources/flow_manager.py b/authentik/core/sources/flow_manager.py index d74a8a567..b1c3ccf68 100644 --- a/authentik/core/sources/flow_manager.py +++ b/authentik/core/sources/flow_manager.py @@ -5,7 +5,7 @@ from typing import Any, Optional from django.contrib import messages from django.db import IntegrityError from django.db.models.query_utils import Q -from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest +from django.http import HttpRequest, HttpResponse from django.shortcuts import redirect from django.urls import reverse from django.utils.translation import gettext as _ @@ -23,8 +23,10 @@ from authentik.flows.planner import ( PLAN_CONTEXT_SSO, FlowPlanner, ) +from authentik.flows.stage import StageView from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_GET, SESSION_KEY_PLAN from authentik.lib.utils.urls import redirect_with_qs +from authentik.lib.views import bad_request_message from authentik.policies.denied import AccessDeniedResponse from authentik.policies.utils import delete_none_keys from authentik.stages.password import BACKEND_INBUILT @@ -43,6 +45,34 @@ class Action(Enum): DENY = "deny" +def message_stage(message: str, level: int) -> StageView: + """Show a pre-configured message after the flow is done""" + + class MessageStage(StageView): + """Show a pre-configured message after the flow is done""" + + message: str + level: int + + # pylint: disable=unused-argument + def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + """Show a pre-configured message after the flow is done""" + messages.add_message( + self.request, + self.level, + self.message, + ) + return self.executor.stage_ok() + + def post(self, request: HttpRequest) -> HttpResponse: + """Wrapper for post requests""" + return self.get(request) + + MessageStage.message = message + MessageStage.level = level + return MessageStage + + class SourceFlowManager: """Help sources decide what they should do after authorization. Based on source settings and previous connections, authenticate the user, enroll a new user, link to an existing user @@ -156,10 +186,10 @@ class SourceFlowManager: if connection: if action == Action.LINK: self._logger.debug("Linking existing user") - return self.handle_existing_user_link(connection) + return self.handle_existing_link(connection) if action == Action.AUTH: self._logger.debug("Handling auth user") - return self.handle_auth_user(connection) + return self.handle_auth(connection) if action == Action.ENROLL: self._logger.debug("Handling enrollment of new user") return self.handle_enroll(connection) @@ -199,7 +229,11 @@ class SourceFlowManager: return [] def _handle_login_flow( - self, flow: Flow, connection: UserSourceConnection, **kwargs + self, + flow: Flow, + connection: UserSourceConnection, + stages: Optional[list[StageView]] = None, + **kwargs, ) -> HttpResponse: """Prepare Authentication Plan, redirect user FlowExecutor""" # Ensure redirect is carried through when user was trying to @@ -219,12 +253,18 @@ class SourceFlowManager: ) kwargs.update(self.policy_context) if not flow: - return HttpResponseBadRequest() + return bad_request_message( + self.request, + _("Configured flow does not exist."), + ) # We run the Flow planner here so we can pass the Pending user in the context planner = FlowPlanner(flow) plan = planner.plan(self.request, kwargs) for stage in self.get_stages_to_append(flow): - plan.append_stage(stage=stage) + plan.append_stage(stage) + if stages: + for stage in stages: + plan.append_stage(stage) self.request.session[SESSION_KEY_PLAN] = plan return redirect_with_qs( "authentik_core:if-flow", @@ -233,19 +273,30 @@ class SourceFlowManager: ) # pylint: disable=unused-argument - def handle_auth_user( + def handle_auth( self, connection: UserSourceConnection, ) -> HttpResponse: """Login user and redirect.""" - messages.success( - self.request, - _("Successfully authenticated with %(source)s!" % {"source": self.source.name}), - ) flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user} - return self._handle_login_flow(self.source.authentication_flow, connection, **flow_kwargs) + return self._handle_login_flow( + self.source.authentication_flow, + connection, + stages=[ + in_memory_stage( + message_stage( + messages.SUCCESS, + _( + "Successfully authenticated with %(source)s!" + % {"source": self.source.name} + ), + ) + ) + ], + **flow_kwargs, + ) - def handle_existing_user_link( + def handle_existing_link( self, connection: UserSourceConnection, ) -> HttpResponse: @@ -263,7 +314,7 @@ class SourceFlowManager: ) # When request isn't authenticated we jump straight to auth if not self.request.user.is_authenticated: - return self.handle_auth_user(connection) + return self.handle_auth(connection) return redirect( reverse( "authentik_core:if-user", @@ -276,18 +327,27 @@ class SourceFlowManager: connection: UserSourceConnection, ) -> HttpResponse: """User was not authenticated and previous request was not authenticated.""" - messages.success( - self.request, - _("Successfully authenticated with %(source)s!" % {"source": self.source.name}), - ) - # We run the Flow planner here so we can pass the Pending user in the context if not self.source.enrollment_flow: self._logger.warning("source has no enrollment flow") - return HttpResponseBadRequest() + return bad_request_message( + self.request, + _("Source is not configured for enrollment."), + ) return self._handle_login_flow( self.source.enrollment_flow, connection, + stages=[ + in_memory_stage( + message_stage( + messages.SUCCESS, + _( + "Successfully authenticated with %(source)s!" + % {"source": self.source.name} + ), + ) + ) + ], **{ PLAN_CONTEXT_PROMPT: delete_none_keys(self.enroll_info), PLAN_CONTEXT_USER_PATH: self.source.get_user_path(),