stages/prompt: use stage.get_pending_user() to fallback to the correct user

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-12-30 20:38:15 +01:00
parent b548ccca6e
commit 9d422918b3
No known key found for this signature in database
2 changed files with 39 additions and 19 deletions

View file

@ -7,14 +7,13 @@ from django.db.models.query import QuerySet
from django.http import HttpRequest, HttpResponse
from django.http.request import QueryDict
from django.utils.translation import gettext_lazy as _
from guardian.shortcuts import get_anonymous_user
from rest_framework.fields import BooleanField, CharField, ChoiceField, IntegerField, empty
from rest_framework.serializers import ValidationError
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.planner import FlowPlan
from authentik.flows.stage import ChallengeStageView
from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyBindingModel, PolicyEngineMode
@ -47,21 +46,23 @@ class PromptChallengeResponse(ChallengeResponse):
"""Validate response, fields are dynamically created based
on the stage"""
stage_instance: PromptStage
component = CharField(default="ak-stage-prompt")
def __init__(self, *args, **kwargs):
stage: PromptStage = kwargs.pop("stage", None)
stage: PromptStage = kwargs.pop("stage_instance", None)
plan: FlowPlan = kwargs.pop("plan", None)
request: HttpRequest = kwargs.pop("request", None)
user: User = kwargs.pop("user", None)
super().__init__(*args, **kwargs)
self.stage = stage
self.stage_instance = stage
self.plan = plan
self.request = request
if not self.stage:
if not self.stage_instance:
return
# list() is called so we only load the fields once
fields = list(self.stage.fields.all())
fields = list(self.stage_instance.fields.all())
for field in fields:
field: Prompt
current = field.get_placeholder(
@ -97,7 +98,7 @@ class PromptChallengeResponse(ChallengeResponse):
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
# Check if we have any static or hidden fields, and ensure they
# still have the same value
static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter(
static_hidden_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
type__in=[FieldTypes.HIDDEN, FieldTypes.STATIC, FieldTypes.TEXT_READ_ONLY]
)
for static_hidden in static_hidden_fields:
@ -109,12 +110,17 @@ class PromptChallengeResponse(ChallengeResponse):
attrs[static_hidden.field_key] = default
# Check if we have two password fields, and make sure they are the same
password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD)
password_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
type=FieldTypes.PASSWORD
)
if password_fields.exists() and password_fields.count() == 2:
self._validate_password_fields(*[field.field_key for field in password_fields])
user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request)
engine = ListPolicyEngine(
self.stage_instance.validation_policies.all(),
self.stage.get_pending_user(),
self.request,
)
engine.mode = PolicyEngineMode.MODE_ALL
engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
engine.use_cache = False
@ -194,7 +200,8 @@ class PromptStageView(ChallengeStageView):
instance=None,
data=data,
request=self.request,
stage=self.executor.current_stage,
stage_instance=self.executor.current_stage,
stage=self,
plan=self.executor.plan,
user=self.get_pending_user(),
)

View file

@ -10,11 +10,15 @@ from authentik.flows.markers import StageMarker
from authentik.flows.models import FlowStageBinding
from authentik.flows.planner import FlowPlan
from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.flows.views.executor import SESSION_KEY_PLAN, FlowExecutorView
from authentik.lib.generators import generate_id
from authentik.policies.expression.models import ExpressionPolicy
from authentik.stages.prompt.models import FieldTypes, InlineFileField, Prompt, PromptStage
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT, PromptChallengeResponse
from authentik.stages.prompt.stage import (
PLAN_CONTEXT_PROMPT,
PromptChallengeResponse,
PromptStageView,
)
class TestPromptStage(FlowTestCase):
@ -106,6 +110,11 @@ class TestPromptStage(FlowTestCase):
self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2)
self.request = RequestFactory().get("/")
self.request.user = create_test_admin_user()
self.flow_executor = FlowExecutorView(request=self.request)
self.stage_view = PromptStageView(self.flow_executor, request=self.request)
def test_inline_file_field(self):
"""test InlineFileField"""
with self.assertRaises(ValidationError):
@ -148,7 +157,11 @@ class TestPromptStage(FlowTestCase):
self.stage.validation_policies.set([expr_policy])
self.stage.save()
challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data
None,
stage_instance=self.stage,
plan=plan,
data=self.prompt_data,
stage=self.stage_view,
)
self.assertEqual(challenge_response.is_valid(), True)
@ -160,7 +173,7 @@ class TestPromptStage(FlowTestCase):
self.stage.validation_policies.set([expr_policy])
self.stage.save()
challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
)
self.assertEqual(challenge_response.is_valid(), False)
@ -180,7 +193,7 @@ class TestPromptStage(FlowTestCase):
self.stage.validation_policies.set([expr_policy])
self.stage.save()
challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
)
self.assertEqual(challenge_response.is_valid(), True)
@ -208,7 +221,7 @@ class TestPromptStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
self.prompt_data["password2_prompt"] = "qwerqwerqr"
challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
)
self.assertEqual(challenge_response.is_valid(), False)
self.assertEqual(
@ -222,7 +235,7 @@ class TestPromptStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
self.prompt_data["username_prompt"] = user.username
challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
)
self.assertEqual(challenge_response.is_valid(), False)
self.assertEqual(
@ -237,7 +250,7 @@ class TestPromptStage(FlowTestCase):
self.prompt_data["hidden_prompt"] = "foo"
self.prompt_data["static_prompt"] = "foo"
challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data
None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
)
self.assertEqual(challenge_response.is_valid(), True)
self.assertNotEqual(challenge_response.validated_data["hidden_prompt"], "foo")