stages/prompt: use stage.get_pending_user() to fallback to the correct user
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
b548ccca6e
commit
9d422918b3
|
@ -7,14 +7,13 @@ from django.db.models.query import QuerySet
|
|||
from django.http import HttpRequest, HttpResponse
|
||||
from django.http.request import QueryDict
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from rest_framework.fields import BooleanField, CharField, ChoiceField, IntegerField, empty
|
||||
from rest_framework.serializers import ValidationError
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.core.models import User
|
||||
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||
from authentik.flows.planner import FlowPlan
|
||||
from authentik.flows.stage import ChallengeStageView
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.models import PolicyBinding, PolicyBindingModel, PolicyEngineMode
|
||||
|
@ -47,21 +46,23 @@ class PromptChallengeResponse(ChallengeResponse):
|
|||
"""Validate response, fields are dynamically created based
|
||||
on the stage"""
|
||||
|
||||
stage_instance: PromptStage
|
||||
|
||||
component = CharField(default="ak-stage-prompt")
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
stage: PromptStage = kwargs.pop("stage", None)
|
||||
stage: PromptStage = kwargs.pop("stage_instance", None)
|
||||
plan: FlowPlan = kwargs.pop("plan", None)
|
||||
request: HttpRequest = kwargs.pop("request", None)
|
||||
user: User = kwargs.pop("user", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.stage = stage
|
||||
self.stage_instance = stage
|
||||
self.plan = plan
|
||||
self.request = request
|
||||
if not self.stage:
|
||||
if not self.stage_instance:
|
||||
return
|
||||
# list() is called so we only load the fields once
|
||||
fields = list(self.stage.fields.all())
|
||||
fields = list(self.stage_instance.fields.all())
|
||||
for field in fields:
|
||||
field: Prompt
|
||||
current = field.get_placeholder(
|
||||
|
@ -97,7 +98,7 @@ class PromptChallengeResponse(ChallengeResponse):
|
|||
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
|
||||
# Check if we have any static or hidden fields, and ensure they
|
||||
# still have the same value
|
||||
static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter(
|
||||
static_hidden_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
|
||||
type__in=[FieldTypes.HIDDEN, FieldTypes.STATIC, FieldTypes.TEXT_READ_ONLY]
|
||||
)
|
||||
for static_hidden in static_hidden_fields:
|
||||
|
@ -109,12 +110,17 @@ class PromptChallengeResponse(ChallengeResponse):
|
|||
attrs[static_hidden.field_key] = default
|
||||
|
||||
# Check if we have two password fields, and make sure they are the same
|
||||
password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD)
|
||||
password_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
|
||||
type=FieldTypes.PASSWORD
|
||||
)
|
||||
if password_fields.exists() and password_fields.count() == 2:
|
||||
self._validate_password_fields(*[field.field_key for field in password_fields])
|
||||
|
||||
user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
|
||||
engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request)
|
||||
engine = ListPolicyEngine(
|
||||
self.stage_instance.validation_policies.all(),
|
||||
self.stage.get_pending_user(),
|
||||
self.request,
|
||||
)
|
||||
engine.mode = PolicyEngineMode.MODE_ALL
|
||||
engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
|
||||
engine.use_cache = False
|
||||
|
@ -194,7 +200,8 @@ class PromptStageView(ChallengeStageView):
|
|||
instance=None,
|
||||
data=data,
|
||||
request=self.request,
|
||||
stage=self.executor.current_stage,
|
||||
stage_instance=self.executor.current_stage,
|
||||
stage=self,
|
||||
plan=self.executor.plan,
|
||||
user=self.get_pending_user(),
|
||||
)
|
||||
|
|
|
@ -10,11 +10,15 @@ from authentik.flows.markers import StageMarker
|
|||
from authentik.flows.models import FlowStageBinding
|
||||
from authentik.flows.planner import FlowPlan
|
||||
from authentik.flows.tests import FlowTestCase
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN, FlowExecutorView
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.expression.models import ExpressionPolicy
|
||||
from authentik.stages.prompt.models import FieldTypes, InlineFileField, Prompt, PromptStage
|
||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT, PromptChallengeResponse
|
||||
from authentik.stages.prompt.stage import (
|
||||
PLAN_CONTEXT_PROMPT,
|
||||
PromptChallengeResponse,
|
||||
PromptStageView,
|
||||
)
|
||||
|
||||
|
||||
class TestPromptStage(FlowTestCase):
|
||||
|
@ -106,6 +110,11 @@ class TestPromptStage(FlowTestCase):
|
|||
|
||||
self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2)
|
||||
|
||||
self.request = RequestFactory().get("/")
|
||||
self.request.user = create_test_admin_user()
|
||||
self.flow_executor = FlowExecutorView(request=self.request)
|
||||
self.stage_view = PromptStageView(self.flow_executor, request=self.request)
|
||||
|
||||
def test_inline_file_field(self):
|
||||
"""test InlineFileField"""
|
||||
with self.assertRaises(ValidationError):
|
||||
|
@ -148,7 +157,11 @@ class TestPromptStage(FlowTestCase):
|
|||
self.stage.validation_policies.set([expr_policy])
|
||||
self.stage.save()
|
||||
challenge_response = PromptChallengeResponse(
|
||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
||||
None,
|
||||
stage_instance=self.stage,
|
||||
plan=plan,
|
||||
data=self.prompt_data,
|
||||
stage=self.stage_view,
|
||||
)
|
||||
self.assertEqual(challenge_response.is_valid(), True)
|
||||
|
||||
|
@ -160,7 +173,7 @@ class TestPromptStage(FlowTestCase):
|
|||
self.stage.validation_policies.set([expr_policy])
|
||||
self.stage.save()
|
||||
challenge_response = PromptChallengeResponse(
|
||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
||||
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||
)
|
||||
self.assertEqual(challenge_response.is_valid(), False)
|
||||
|
||||
|
@ -180,7 +193,7 @@ class TestPromptStage(FlowTestCase):
|
|||
self.stage.validation_policies.set([expr_policy])
|
||||
self.stage.save()
|
||||
challenge_response = PromptChallengeResponse(
|
||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
||||
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||
)
|
||||
self.assertEqual(challenge_response.is_valid(), True)
|
||||
|
||||
|
@ -208,7 +221,7 @@ class TestPromptStage(FlowTestCase):
|
|||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
self.prompt_data["password2_prompt"] = "qwerqwerqr"
|
||||
challenge_response = PromptChallengeResponse(
|
||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
||||
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||
)
|
||||
self.assertEqual(challenge_response.is_valid(), False)
|
||||
self.assertEqual(
|
||||
|
@ -222,7 +235,7 @@ class TestPromptStage(FlowTestCase):
|
|||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
self.prompt_data["username_prompt"] = user.username
|
||||
challenge_response = PromptChallengeResponse(
|
||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
||||
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||
)
|
||||
self.assertEqual(challenge_response.is_valid(), False)
|
||||
self.assertEqual(
|
||||
|
@ -237,7 +250,7 @@ class TestPromptStage(FlowTestCase):
|
|||
self.prompt_data["hidden_prompt"] = "foo"
|
||||
self.prompt_data["static_prompt"] = "foo"
|
||||
challenge_response = PromptChallengeResponse(
|
||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
||||
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||
)
|
||||
self.assertEqual(challenge_response.is_valid(), True)
|
||||
self.assertNotEqual(challenge_response.validated_data["hidden_prompt"], "foo")
|
||||
|
|
Reference in New Issue