core: passthrough connection and additional data to FlowManager

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

#2047
This commit is contained in:
Jens Langhammer 2022-01-03 21:31:09 +01:00
parent 4c166dcf52
commit a101d48b5a
2 changed files with 13 additions and 4 deletions

View file

@ -52,6 +52,9 @@ class SourceFlowManager:
connection_type: type[UserSourceConnection] = UserSourceConnection connection_type: type[UserSourceConnection] = UserSourceConnection
enroll_info: dict[str, Any]
policy_context: dict[str, Any]
def __init__( def __init__(
self, self,
source: Source, source: Source,
@ -64,6 +67,7 @@ class SourceFlowManager:
self.identifier = identifier self.identifier = identifier
self.enroll_info = enroll_info self.enroll_info = enroll_info
self._logger = get_logger().bind(source=source, identifier=identifier) self._logger = get_logger().bind(source=source, identifier=identifier)
self.policy_context = {}
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]: def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]:
@ -144,7 +148,7 @@ class SourceFlowManager:
except IntegrityError as exc: except IntegrityError as exc:
self._logger.warning("failed to get action", exc=exc) self._logger.warning("failed to get action", exc=exc)
return redirect("/") return redirect("/")
self._logger.debug("get_action() says", action=action, connection=connection) self._logger.debug("get_action", action=action, connection=connection)
if connection: if connection:
if action == Action.LINK: if action == Action.LINK:
self._logger.debug("Linking existing user") self._logger.debug("Linking existing user")
@ -179,7 +183,9 @@ class SourceFlowManager:
] ]
return [] return []
def _handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse: def _handle_login_flow(
self, flow: Flow, connection: UserSourceConnection, **kwargs
) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor""" """Prepare Authentication Plan, redirect user FlowExecutor"""
# Ensure redirect is carried through when user was trying to # Ensure redirect is carried through when user was trying to
# authorize application # authorize application
@ -193,8 +199,10 @@ class SourceFlowManager:
PLAN_CONTEXT_SSO: True, PLAN_CONTEXT_SSO: True,
PLAN_CONTEXT_SOURCE: self.source, PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_REDIRECT: final_redirect, PLAN_CONTEXT_REDIRECT: final_redirect,
PLAN_CONTEXT_SOURCES_CONNECTION: connection,
} }
) )
kwargs.update(self.policy_context)
if not flow: if not flow:
return HttpResponseBadRequest() return HttpResponseBadRequest()
# We run the Flow planner here so we can pass the Pending user in the context # We run the Flow planner here so we can pass the Pending user in the context
@ -220,7 +228,7 @@ class SourceFlowManager:
_("Successfully authenticated with %(source)s!" % {"source": self.source.name}), _("Successfully authenticated with %(source)s!" % {"source": self.source.name}),
) )
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user} flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user}
return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs) return self._handle_login_flow(self.source.authentication_flow, connection, **flow_kwargs)
def handle_existing_user_link( def handle_existing_user_link(
self, self,
@ -264,8 +272,8 @@ class SourceFlowManager:
return HttpResponseBadRequest() return HttpResponseBadRequest()
return self._handle_login_flow( return self._handle_login_flow(
self.source.enrollment_flow, self.source.enrollment_flow,
connection,
**{ **{
PLAN_CONTEXT_PROMPT: delete_none_keys(self.enroll_info), PLAN_CONTEXT_PROMPT: delete_none_keys(self.enroll_info),
PLAN_CONTEXT_SOURCES_CONNECTION: connection,
}, },
) )

View file

@ -64,6 +64,7 @@ class OAuthCallback(OAuthClientMixin, View):
identifier=identifier, identifier=identifier,
enroll_info=enroll_info, enroll_info=enroll_info,
) )
sfm.policy_context = {"oauth_userinfo": raw_info}
return sfm.get_flow( return sfm.get_flow(
access_token=token.get("access_token"), access_token=token.get("access_token"),
) )