diff --git a/authentik/core/sources/flow_manager.py b/authentik/core/sources/flow_manager.py index 490c25550..bccc4c574 100644 --- a/authentik/core/sources/flow_manager.py +++ b/authentik/core/sources/flow_manager.py @@ -52,6 +52,9 @@ class SourceFlowManager: connection_type: type[UserSourceConnection] = UserSourceConnection + enroll_info: dict[str, Any] + policy_context: dict[str, Any] + def __init__( self, source: Source, @@ -64,6 +67,7 @@ class SourceFlowManager: self.identifier = identifier self.enroll_info = enroll_info self._logger = get_logger().bind(source=source, identifier=identifier) + self.policy_context = {} # pylint: disable=too-many-return-statements def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]: @@ -144,7 +148,7 @@ class SourceFlowManager: except IntegrityError as exc: self._logger.warning("failed to get action", exc=exc) return redirect("/") - self._logger.debug("get_action() says", action=action, connection=connection) + self._logger.debug("get_action", action=action, connection=connection) if connection: if action == Action.LINK: self._logger.debug("Linking existing user") @@ -179,7 +183,9 @@ class SourceFlowManager: ] return [] - def _handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse: + def _handle_login_flow( + self, flow: Flow, connection: UserSourceConnection, **kwargs + ) -> HttpResponse: """Prepare Authentication Plan, redirect user FlowExecutor""" # Ensure redirect is carried through when user was trying to # authorize application @@ -193,8 +199,10 @@ class SourceFlowManager: PLAN_CONTEXT_SSO: True, PLAN_CONTEXT_SOURCE: self.source, PLAN_CONTEXT_REDIRECT: final_redirect, + PLAN_CONTEXT_SOURCES_CONNECTION: connection, } ) + kwargs.update(self.policy_context) if not flow: return HttpResponseBadRequest() # We run the Flow planner here so we can pass the Pending user in the context @@ -220,7 +228,7 @@ class SourceFlowManager: _("Successfully authenticated with %(source)s!" % {"source": self.source.name}), ) flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user} - return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs) + return self._handle_login_flow(self.source.authentication_flow, connection, **flow_kwargs) def handle_existing_user_link( self, @@ -264,8 +272,8 @@ class SourceFlowManager: return HttpResponseBadRequest() return self._handle_login_flow( self.source.enrollment_flow, + connection, **{ PLAN_CONTEXT_PROMPT: delete_none_keys(self.enroll_info), - PLAN_CONTEXT_SOURCES_CONNECTION: connection, }, ) diff --git a/authentik/sources/oauth/views/callback.py b/authentik/sources/oauth/views/callback.py index b745ac25f..b4ddf2948 100644 --- a/authentik/sources/oauth/views/callback.py +++ b/authentik/sources/oauth/views/callback.py @@ -64,6 +64,7 @@ class OAuthCallback(OAuthClientMixin, View): identifier=identifier, enroll_info=enroll_info, ) + sfm.policy_context = {"oauth_userinfo": raw_info} return sfm.get_flow( access_token=token.get("access_token"), )