stages/prompt: use stage.get_pending_user() to fallback to the correct user
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
b548ccca6e
commit
9d422918b3
|
@ -7,14 +7,13 @@ from django.db.models.query import QuerySet
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.http.request import QueryDict
|
from django.http.request import QueryDict
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from guardian.shortcuts import get_anonymous_user
|
|
||||||
from rest_framework.fields import BooleanField, CharField, ChoiceField, IntegerField, empty
|
from rest_framework.fields import BooleanField, CharField, ChoiceField, IntegerField, empty
|
||||||
from rest_framework.serializers import ValidationError
|
from rest_framework.serializers import ValidationError
|
||||||
|
|
||||||
from authentik.core.api.utils import PassiveSerializer
|
from authentik.core.api.utils import PassiveSerializer
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
|
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
from authentik.flows.planner import FlowPlan
|
||||||
from authentik.flows.stage import ChallengeStageView
|
from authentik.flows.stage import ChallengeStageView
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
from authentik.policies.models import PolicyBinding, PolicyBindingModel, PolicyEngineMode
|
from authentik.policies.models import PolicyBinding, PolicyBindingModel, PolicyEngineMode
|
||||||
|
@ -47,21 +46,23 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||||
"""Validate response, fields are dynamically created based
|
"""Validate response, fields are dynamically created based
|
||||||
on the stage"""
|
on the stage"""
|
||||||
|
|
||||||
|
stage_instance: PromptStage
|
||||||
|
|
||||||
component = CharField(default="ak-stage-prompt")
|
component = CharField(default="ak-stage-prompt")
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
stage: PromptStage = kwargs.pop("stage", None)
|
stage: PromptStage = kwargs.pop("stage_instance", None)
|
||||||
plan: FlowPlan = kwargs.pop("plan", None)
|
plan: FlowPlan = kwargs.pop("plan", None)
|
||||||
request: HttpRequest = kwargs.pop("request", None)
|
request: HttpRequest = kwargs.pop("request", None)
|
||||||
user: User = kwargs.pop("user", None)
|
user: User = kwargs.pop("user", None)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.stage = stage
|
self.stage_instance = stage
|
||||||
self.plan = plan
|
self.plan = plan
|
||||||
self.request = request
|
self.request = request
|
||||||
if not self.stage:
|
if not self.stage_instance:
|
||||||
return
|
return
|
||||||
# list() is called so we only load the fields once
|
# list() is called so we only load the fields once
|
||||||
fields = list(self.stage.fields.all())
|
fields = list(self.stage_instance.fields.all())
|
||||||
for field in fields:
|
for field in fields:
|
||||||
field: Prompt
|
field: Prompt
|
||||||
current = field.get_placeholder(
|
current = field.get_placeholder(
|
||||||
|
@ -97,7 +98,7 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||||
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
|
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
|
||||||
# Check if we have any static or hidden fields, and ensure they
|
# Check if we have any static or hidden fields, and ensure they
|
||||||
# still have the same value
|
# still have the same value
|
||||||
static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter(
|
static_hidden_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
|
||||||
type__in=[FieldTypes.HIDDEN, FieldTypes.STATIC, FieldTypes.TEXT_READ_ONLY]
|
type__in=[FieldTypes.HIDDEN, FieldTypes.STATIC, FieldTypes.TEXT_READ_ONLY]
|
||||||
)
|
)
|
||||||
for static_hidden in static_hidden_fields:
|
for static_hidden in static_hidden_fields:
|
||||||
|
@ -109,12 +110,17 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||||
attrs[static_hidden.field_key] = default
|
attrs[static_hidden.field_key] = default
|
||||||
|
|
||||||
# Check if we have two password fields, and make sure they are the same
|
# Check if we have two password fields, and make sure they are the same
|
||||||
password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD)
|
password_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
|
||||||
|
type=FieldTypes.PASSWORD
|
||||||
|
)
|
||||||
if password_fields.exists() and password_fields.count() == 2:
|
if password_fields.exists() and password_fields.count() == 2:
|
||||||
self._validate_password_fields(*[field.field_key for field in password_fields])
|
self._validate_password_fields(*[field.field_key for field in password_fields])
|
||||||
|
|
||||||
user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
|
engine = ListPolicyEngine(
|
||||||
engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request)
|
self.stage_instance.validation_policies.all(),
|
||||||
|
self.stage.get_pending_user(),
|
||||||
|
self.request,
|
||||||
|
)
|
||||||
engine.mode = PolicyEngineMode.MODE_ALL
|
engine.mode = PolicyEngineMode.MODE_ALL
|
||||||
engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
|
engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
|
||||||
engine.use_cache = False
|
engine.use_cache = False
|
||||||
|
@ -194,7 +200,8 @@ class PromptStageView(ChallengeStageView):
|
||||||
instance=None,
|
instance=None,
|
||||||
data=data,
|
data=data,
|
||||||
request=self.request,
|
request=self.request,
|
||||||
stage=self.executor.current_stage,
|
stage_instance=self.executor.current_stage,
|
||||||
|
stage=self,
|
||||||
plan=self.executor.plan,
|
plan=self.executor.plan,
|
||||||
user=self.get_pending_user(),
|
user=self.get_pending_user(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,11 +10,15 @@ from authentik.flows.markers import StageMarker
|
||||||
from authentik.flows.models import FlowStageBinding
|
from authentik.flows.models import FlowStageBinding
|
||||||
from authentik.flows.planner import FlowPlan
|
from authentik.flows.planner import FlowPlan
|
||||||
from authentik.flows.tests import FlowTestCase
|
from authentik.flows.tests import FlowTestCase
|
||||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
from authentik.flows.views.executor import SESSION_KEY_PLAN, FlowExecutorView
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.policies.expression.models import ExpressionPolicy
|
from authentik.policies.expression.models import ExpressionPolicy
|
||||||
from authentik.stages.prompt.models import FieldTypes, InlineFileField, Prompt, PromptStage
|
from authentik.stages.prompt.models import FieldTypes, InlineFileField, Prompt, PromptStage
|
||||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT, PromptChallengeResponse
|
from authentik.stages.prompt.stage import (
|
||||||
|
PLAN_CONTEXT_PROMPT,
|
||||||
|
PromptChallengeResponse,
|
||||||
|
PromptStageView,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestPromptStage(FlowTestCase):
|
class TestPromptStage(FlowTestCase):
|
||||||
|
@ -106,6 +110,11 @@ class TestPromptStage(FlowTestCase):
|
||||||
|
|
||||||
self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2)
|
self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2)
|
||||||
|
|
||||||
|
self.request = RequestFactory().get("/")
|
||||||
|
self.request.user = create_test_admin_user()
|
||||||
|
self.flow_executor = FlowExecutorView(request=self.request)
|
||||||
|
self.stage_view = PromptStageView(self.flow_executor, request=self.request)
|
||||||
|
|
||||||
def test_inline_file_field(self):
|
def test_inline_file_field(self):
|
||||||
"""test InlineFileField"""
|
"""test InlineFileField"""
|
||||||
with self.assertRaises(ValidationError):
|
with self.assertRaises(ValidationError):
|
||||||
|
@ -148,7 +157,11 @@ class TestPromptStage(FlowTestCase):
|
||||||
self.stage.validation_policies.set([expr_policy])
|
self.stage.validation_policies.set([expr_policy])
|
||||||
self.stage.save()
|
self.stage.save()
|
||||||
challenge_response = PromptChallengeResponse(
|
challenge_response = PromptChallengeResponse(
|
||||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
None,
|
||||||
|
stage_instance=self.stage,
|
||||||
|
plan=plan,
|
||||||
|
data=self.prompt_data,
|
||||||
|
stage=self.stage_view,
|
||||||
)
|
)
|
||||||
self.assertEqual(challenge_response.is_valid(), True)
|
self.assertEqual(challenge_response.is_valid(), True)
|
||||||
|
|
||||||
|
@ -160,7 +173,7 @@ class TestPromptStage(FlowTestCase):
|
||||||
self.stage.validation_policies.set([expr_policy])
|
self.stage.validation_policies.set([expr_policy])
|
||||||
self.stage.save()
|
self.stage.save()
|
||||||
challenge_response = PromptChallengeResponse(
|
challenge_response = PromptChallengeResponse(
|
||||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||||
)
|
)
|
||||||
self.assertEqual(challenge_response.is_valid(), False)
|
self.assertEqual(challenge_response.is_valid(), False)
|
||||||
|
|
||||||
|
@ -180,7 +193,7 @@ class TestPromptStage(FlowTestCase):
|
||||||
self.stage.validation_policies.set([expr_policy])
|
self.stage.validation_policies.set([expr_policy])
|
||||||
self.stage.save()
|
self.stage.save()
|
||||||
challenge_response = PromptChallengeResponse(
|
challenge_response = PromptChallengeResponse(
|
||||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||||
)
|
)
|
||||||
self.assertEqual(challenge_response.is_valid(), True)
|
self.assertEqual(challenge_response.is_valid(), True)
|
||||||
|
|
||||||
|
@ -208,7 +221,7 @@ class TestPromptStage(FlowTestCase):
|
||||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||||
self.prompt_data["password2_prompt"] = "qwerqwerqr"
|
self.prompt_data["password2_prompt"] = "qwerqwerqr"
|
||||||
challenge_response = PromptChallengeResponse(
|
challenge_response = PromptChallengeResponse(
|
||||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||||
)
|
)
|
||||||
self.assertEqual(challenge_response.is_valid(), False)
|
self.assertEqual(challenge_response.is_valid(), False)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -222,7 +235,7 @@ class TestPromptStage(FlowTestCase):
|
||||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||||
self.prompt_data["username_prompt"] = user.username
|
self.prompt_data["username_prompt"] = user.username
|
||||||
challenge_response = PromptChallengeResponse(
|
challenge_response = PromptChallengeResponse(
|
||||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||||
)
|
)
|
||||||
self.assertEqual(challenge_response.is_valid(), False)
|
self.assertEqual(challenge_response.is_valid(), False)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -237,7 +250,7 @@ class TestPromptStage(FlowTestCase):
|
||||||
self.prompt_data["hidden_prompt"] = "foo"
|
self.prompt_data["hidden_prompt"] = "foo"
|
||||||
self.prompt_data["static_prompt"] = "foo"
|
self.prompt_data["static_prompt"] = "foo"
|
||||||
challenge_response = PromptChallengeResponse(
|
challenge_response = PromptChallengeResponse(
|
||||||
None, stage=self.stage, plan=plan, data=self.prompt_data
|
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
|
||||||
)
|
)
|
||||||
self.assertEqual(challenge_response.is_valid(), True)
|
self.assertEqual(challenge_response.is_valid(), True)
|
||||||
self.assertNotEqual(challenge_response.validated_data["hidden_prompt"], "foo")
|
self.assertNotEqual(challenge_response.validated_data["hidden_prompt"], "foo")
|
||||||
|
|
Reference in New Issue