diff --git a/authentik/stages/prompt/stage.py b/authentik/stages/prompt/stage.py index f68329c7a..744446fe1 100644 --- a/authentik/stages/prompt/stage.py +++ b/authentik/stages/prompt/stage.py @@ -7,14 +7,13 @@ from django.db.models.query import QuerySet from django.http import HttpRequest, HttpResponse from django.http.request import QueryDict from django.utils.translation import gettext_lazy as _ -from guardian.shortcuts import get_anonymous_user from rest_framework.fields import BooleanField, CharField, ChoiceField, IntegerField, empty from rest_framework.serializers import ValidationError from authentik.core.api.utils import PassiveSerializer from authentik.core.models import User from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes -from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan +from authentik.flows.planner import FlowPlan from authentik.flows.stage import ChallengeStageView from authentik.policies.engine import PolicyEngine from authentik.policies.models import PolicyBinding, PolicyBindingModel, PolicyEngineMode @@ -47,21 +46,23 @@ class PromptChallengeResponse(ChallengeResponse): """Validate response, fields are dynamically created based on the stage""" + stage_instance: PromptStage + component = CharField(default="ak-stage-prompt") def __init__(self, *args, **kwargs): - stage: PromptStage = kwargs.pop("stage", None) + stage: PromptStage = kwargs.pop("stage_instance", None) plan: FlowPlan = kwargs.pop("plan", None) request: HttpRequest = kwargs.pop("request", None) user: User = kwargs.pop("user", None) super().__init__(*args, **kwargs) - self.stage = stage + self.stage_instance = stage self.plan = plan self.request = request - if not self.stage: + if not self.stage_instance: return # list() is called so we only load the fields once - fields = list(self.stage.fields.all()) + fields = list(self.stage_instance.fields.all()) for field in fields: field: Prompt current = field.get_placeholder( @@ -97,7 +98,7 @@ class PromptChallengeResponse(ChallengeResponse): def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: # Check if we have any static or hidden fields, and ensure they # still have the same value - static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter( + static_hidden_fields: QuerySet[Prompt] = self.stage_instance.fields.filter( type__in=[FieldTypes.HIDDEN, FieldTypes.STATIC, FieldTypes.TEXT_READ_ONLY] ) for static_hidden in static_hidden_fields: @@ -109,12 +110,17 @@ class PromptChallengeResponse(ChallengeResponse): attrs[static_hidden.field_key] = default # Check if we have two password fields, and make sure they are the same - password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD) + password_fields: QuerySet[Prompt] = self.stage_instance.fields.filter( + type=FieldTypes.PASSWORD + ) if password_fields.exists() and password_fields.count() == 2: self._validate_password_fields(*[field.field_key for field in password_fields]) - user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user()) - engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request) + engine = ListPolicyEngine( + self.stage_instance.validation_policies.all(), + self.stage.get_pending_user(), + self.request, + ) engine.mode = PolicyEngineMode.MODE_ALL engine.request.context[PLAN_CONTEXT_PROMPT] = attrs engine.use_cache = False @@ -194,7 +200,8 @@ class PromptStageView(ChallengeStageView): instance=None, data=data, request=self.request, - stage=self.executor.current_stage, + stage_instance=self.executor.current_stage, + stage=self, plan=self.executor.plan, user=self.get_pending_user(), ) diff --git a/authentik/stages/prompt/tests.py b/authentik/stages/prompt/tests.py index ee57f125c..cda5ab8dc 100644 --- a/authentik/stages/prompt/tests.py +++ b/authentik/stages/prompt/tests.py @@ -10,11 +10,15 @@ from authentik.flows.markers import StageMarker from authentik.flows.models import FlowStageBinding from authentik.flows.planner import FlowPlan from authentik.flows.tests import FlowTestCase -from authentik.flows.views.executor import SESSION_KEY_PLAN +from authentik.flows.views.executor import SESSION_KEY_PLAN, FlowExecutorView from authentik.lib.generators import generate_id from authentik.policies.expression.models import ExpressionPolicy from authentik.stages.prompt.models import FieldTypes, InlineFileField, Prompt, PromptStage -from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT, PromptChallengeResponse +from authentik.stages.prompt.stage import ( + PLAN_CONTEXT_PROMPT, + PromptChallengeResponse, + PromptStageView, +) class TestPromptStage(FlowTestCase): @@ -106,6 +110,11 @@ class TestPromptStage(FlowTestCase): self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) + self.request = RequestFactory().get("/") + self.request.user = create_test_admin_user() + self.flow_executor = FlowExecutorView(request=self.request) + self.stage_view = PromptStageView(self.flow_executor, request=self.request) + def test_inline_file_field(self): """test InlineFileField""" with self.assertRaises(ValidationError): @@ -148,7 +157,11 @@ class TestPromptStage(FlowTestCase): self.stage.validation_policies.set([expr_policy]) self.stage.save() challenge_response = PromptChallengeResponse( - None, stage=self.stage, plan=plan, data=self.prompt_data + None, + stage_instance=self.stage, + plan=plan, + data=self.prompt_data, + stage=self.stage_view, ) self.assertEqual(challenge_response.is_valid(), True) @@ -160,7 +173,7 @@ class TestPromptStage(FlowTestCase): self.stage.validation_policies.set([expr_policy]) self.stage.save() challenge_response = PromptChallengeResponse( - None, stage=self.stage, plan=plan, data=self.prompt_data + None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view ) self.assertEqual(challenge_response.is_valid(), False) @@ -180,7 +193,7 @@ class TestPromptStage(FlowTestCase): self.stage.validation_policies.set([expr_policy]) self.stage.save() challenge_response = PromptChallengeResponse( - None, stage=self.stage, plan=plan, data=self.prompt_data + None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view ) self.assertEqual(challenge_response.is_valid(), True) @@ -208,7 +221,7 @@ class TestPromptStage(FlowTestCase): plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) self.prompt_data["password2_prompt"] = "qwerqwerqr" challenge_response = PromptChallengeResponse( - None, stage=self.stage, plan=plan, data=self.prompt_data + None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view ) self.assertEqual(challenge_response.is_valid(), False) self.assertEqual( @@ -222,7 +235,7 @@ class TestPromptStage(FlowTestCase): plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) self.prompt_data["username_prompt"] = user.username challenge_response = PromptChallengeResponse( - None, stage=self.stage, plan=plan, data=self.prompt_data + None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view ) self.assertEqual(challenge_response.is_valid(), False) self.assertEqual( @@ -237,7 +250,7 @@ class TestPromptStage(FlowTestCase): self.prompt_data["hidden_prompt"] = "foo" self.prompt_data["static_prompt"] = "foo" challenge_response = PromptChallengeResponse( - None, stage=self.stage, plan=plan, data=self.prompt_data + None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view ) self.assertEqual(challenge_response.is_valid(), True) self.assertNotEqual(challenge_response.validated_data["hidden_prompt"], "foo")