diff --git a/authentik/stages/invitation/stage.py b/authentik/stages/invitation/stage.py index 5385411d4..d48ee03a9 100644 --- a/authentik/stages/invitation/stage.py +++ b/authentik/stages/invitation/stage.py @@ -1,18 +1,23 @@ """invitation stage logic""" -from copy import deepcopy from typing import Optional +from deepmerge import always_merger from django.http import HttpRequest, HttpResponse +from django.http.response import HttpResponseBadRequest from django.shortcuts import get_object_or_404 +from structlog.stdlib import get_logger +from authentik.flows.models import in_memory_stage from authentik.flows.stage import StageView from authentik.flows.views import SESSION_KEY_GET from authentik.stages.invitation.models import Invitation, InvitationStage from authentik.stages.invitation.signals import invitation_used from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT +LOGGER = get_logger() INVITATION_TOKEN_KEY = "token" # nosec INVITATION_IN_EFFECT = "invitation_in_effect" +INVITATION = "invitation" class InvitationStageView(StageView): @@ -39,9 +44,37 @@ class InvitationStageView(StageView): return self.executor.stage_invalid() invite: Invitation = get_object_or_404(Invitation, pk=token) - self.executor.plan.context[PLAN_CONTEXT_PROMPT] = deepcopy(invite.fixed_data) self.executor.plan.context[INVITATION_IN_EFFECT] = True + self.executor.plan.context[INVITATION] = invite + + context = {} + always_merger.merge( + context, self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}) + ) + always_merger.merge(context, invite.fixed_data) + self.executor.plan.context[PLAN_CONTEXT_PROMPT] = context + invitation_used.send(sender=self, request=request, invitation=invite) if invite.single_use: - invite.delete() + self.executor.plan.append_stage(in_memory_stage(InvitationFinalStageView)) + return self.executor.stage_ok() + + +class InvitationFinalStageView(StageView): + """Final stage which is injected by invitation stage. Deletes + the used invitation.""" + + def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + """Call get as this request may be called with post""" + return self.get(request, *args, **kwargs) + + def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + """Delete invitation if single_use is active""" + invitation: Invitation = self.executor.plan.context.get(INVITATION, None) + if not invitation: + LOGGER.warning("InvitationFinalStageView stage called without invitation") + return HttpResponseBadRequest + if not invitation.single_use: + return self.executor.stage_ok() + invitation.delete() return self.executor.stage_ok() diff --git a/authentik/stages/invitation/tests.py b/authentik/stages/invitation/tests.py index 72e9f93bf..c82e2b449 100644 --- a/authentik/stages/invitation/tests.py +++ b/authentik/stages/invitation/tests.py @@ -156,11 +156,13 @@ class TestUserLoginStage(TestCase): base_url = reverse( "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} ) - response = self.client.get(base_url) + response = self.client.get(base_url, follow=True) session = self.client.session plan: FlowPlan = session[SESSION_KEY_PLAN] - self.assertEqual(plan.context[PLAN_CONTEXT_PROMPT], data) + self.assertEqual( + plan.context[PLAN_CONTEXT_PROMPT], data | plan.context[PLAN_CONTEXT_PROMPT] + ) self.assertEqual(response.status_code, 200) self.assertJSONEqual(