stages/prompt: use stage.get_pending_user() to fallback to the correct user

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-12-30 20:38:15 +01:00
parent b548ccca6e
commit 9d422918b3
No known key found for this signature in database
2 changed files with 39 additions and 19 deletions

View File

@ -7,14 +7,13 @@ from django.db.models.query import QuerySet
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.http.request import QueryDict from django.http.request import QueryDict
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from guardian.shortcuts import get_anonymous_user
from rest_framework.fields import BooleanField, CharField, ChoiceField, IntegerField, empty from rest_framework.fields import BooleanField, CharField, ChoiceField, IntegerField, empty
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User from authentik.core.models import User
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan from authentik.flows.planner import FlowPlan
from authentik.flows.stage import ChallengeStageView from authentik.flows.stage import ChallengeStageView
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyBindingModel, PolicyEngineMode from authentik.policies.models import PolicyBinding, PolicyBindingModel, PolicyEngineMode
@ -47,21 +46,23 @@ class PromptChallengeResponse(ChallengeResponse):
"""Validate response, fields are dynamically created based """Validate response, fields are dynamically created based
on the stage""" on the stage"""
stage_instance: PromptStage
component = CharField(default="ak-stage-prompt") component = CharField(default="ak-stage-prompt")
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
stage: PromptStage = kwargs.pop("stage", None) stage: PromptStage = kwargs.pop("stage_instance", None)
plan: FlowPlan = kwargs.pop("plan", None) plan: FlowPlan = kwargs.pop("plan", None)
request: HttpRequest = kwargs.pop("request", None) request: HttpRequest = kwargs.pop("request", None)
user: User = kwargs.pop("user", None) user: User = kwargs.pop("user", None)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.stage = stage self.stage_instance = stage
self.plan = plan self.plan = plan
self.request = request self.request = request
if not self.stage: if not self.stage_instance:
return return
# list() is called so we only load the fields once # list() is called so we only load the fields once
fields = list(self.stage.fields.all()) fields = list(self.stage_instance.fields.all())
for field in fields: for field in fields:
field: Prompt field: Prompt
current = field.get_placeholder( current = field.get_placeholder(
@ -97,7 +98,7 @@ class PromptChallengeResponse(ChallengeResponse):
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
# Check if we have any static or hidden fields, and ensure they # Check if we have any static or hidden fields, and ensure they
# still have the same value # still have the same value
static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter( static_hidden_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
type__in=[FieldTypes.HIDDEN, FieldTypes.STATIC, FieldTypes.TEXT_READ_ONLY] type__in=[FieldTypes.HIDDEN, FieldTypes.STATIC, FieldTypes.TEXT_READ_ONLY]
) )
for static_hidden in static_hidden_fields: for static_hidden in static_hidden_fields:
@ -109,12 +110,17 @@ class PromptChallengeResponse(ChallengeResponse):
attrs[static_hidden.field_key] = default attrs[static_hidden.field_key] = default
# Check if we have two password fields, and make sure they are the same # Check if we have two password fields, and make sure they are the same
password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD) password_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
type=FieldTypes.PASSWORD
)
if password_fields.exists() and password_fields.count() == 2: if password_fields.exists() and password_fields.count() == 2:
self._validate_password_fields(*[field.field_key for field in password_fields]) self._validate_password_fields(*[field.field_key for field in password_fields])
user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user()) engine = ListPolicyEngine(
engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request) self.stage_instance.validation_policies.all(),
self.stage.get_pending_user(),
self.request,
)
engine.mode = PolicyEngineMode.MODE_ALL engine.mode = PolicyEngineMode.MODE_ALL
engine.request.context[PLAN_CONTEXT_PROMPT] = attrs engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
engine.use_cache = False engine.use_cache = False
@ -194,7 +200,8 @@ class PromptStageView(ChallengeStageView):
instance=None, instance=None,
data=data, data=data,
request=self.request, request=self.request,
stage=self.executor.current_stage, stage_instance=self.executor.current_stage,
stage=self,
plan=self.executor.plan, plan=self.executor.plan,
user=self.get_pending_user(), user=self.get_pending_user(),
) )

View File

@ -10,11 +10,15 @@ from authentik.flows.markers import StageMarker
from authentik.flows.models import FlowStageBinding from authentik.flows.models import FlowStageBinding
from authentik.flows.planner import FlowPlan from authentik.flows.planner import FlowPlan
from authentik.flows.tests import FlowTestCase from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import SESSION_KEY_PLAN from authentik.flows.views.executor import SESSION_KEY_PLAN, FlowExecutorView
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.policies.expression.models import ExpressionPolicy from authentik.policies.expression.models import ExpressionPolicy
from authentik.stages.prompt.models import FieldTypes, InlineFileField, Prompt, PromptStage from authentik.stages.prompt.models import FieldTypes, InlineFileField, Prompt, PromptStage
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT, PromptChallengeResponse from authentik.stages.prompt.stage import (
PLAN_CONTEXT_PROMPT,
PromptChallengeResponse,
PromptStageView,
)
class TestPromptStage(FlowTestCase): class TestPromptStage(FlowTestCase):
@ -106,6 +110,11 @@ class TestPromptStage(FlowTestCase):
self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2)
self.request = RequestFactory().get("/")
self.request.user = create_test_admin_user()
self.flow_executor = FlowExecutorView(request=self.request)
self.stage_view = PromptStageView(self.flow_executor, request=self.request)
def test_inline_file_field(self): def test_inline_file_field(self):
"""test InlineFileField""" """test InlineFileField"""
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
@ -148,7 +157,11 @@ class TestPromptStage(FlowTestCase):
self.stage.validation_policies.set([expr_policy]) self.stage.validation_policies.set([expr_policy])
self.stage.save() self.stage.save()
challenge_response = PromptChallengeResponse( challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data None,
stage_instance=self.stage,
plan=plan,
data=self.prompt_data,
stage=self.stage_view,
) )
self.assertEqual(challenge_response.is_valid(), True) self.assertEqual(challenge_response.is_valid(), True)
@ -160,7 +173,7 @@ class TestPromptStage(FlowTestCase):
self.stage.validation_policies.set([expr_policy]) self.stage.validation_policies.set([expr_policy])
self.stage.save() self.stage.save()
challenge_response = PromptChallengeResponse( challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
) )
self.assertEqual(challenge_response.is_valid(), False) self.assertEqual(challenge_response.is_valid(), False)
@ -180,7 +193,7 @@ class TestPromptStage(FlowTestCase):
self.stage.validation_policies.set([expr_policy]) self.stage.validation_policies.set([expr_policy])
self.stage.save() self.stage.save()
challenge_response = PromptChallengeResponse( challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
) )
self.assertEqual(challenge_response.is_valid(), True) self.assertEqual(challenge_response.is_valid(), True)
@ -208,7 +221,7 @@ class TestPromptStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
self.prompt_data["password2_prompt"] = "qwerqwerqr" self.prompt_data["password2_prompt"] = "qwerqwerqr"
challenge_response = PromptChallengeResponse( challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
) )
self.assertEqual(challenge_response.is_valid(), False) self.assertEqual(challenge_response.is_valid(), False)
self.assertEqual( self.assertEqual(
@ -222,7 +235,7 @@ class TestPromptStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
self.prompt_data["username_prompt"] = user.username self.prompt_data["username_prompt"] = user.username
challenge_response = PromptChallengeResponse( challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
) )
self.assertEqual(challenge_response.is_valid(), False) self.assertEqual(challenge_response.is_valid(), False)
self.assertEqual( self.assertEqual(
@ -237,7 +250,7 @@ class TestPromptStage(FlowTestCase):
self.prompt_data["hidden_prompt"] = "foo" self.prompt_data["hidden_prompt"] = "foo"
self.prompt_data["static_prompt"] = "foo" self.prompt_data["static_prompt"] = "foo"
challenge_response = PromptChallengeResponse( challenge_response = PromptChallengeResponse(
None, stage=self.stage, plan=plan, data=self.prompt_data None, stage_instance=self.stage, plan=plan, data=self.prompt_data, stage=self.stage_view
) )
self.assertEqual(challenge_response.is_valid(), True) self.assertEqual(challenge_response.is_valid(), True)
self.assertNotEqual(challenge_response.validated_data["hidden_prompt"], "foo") self.assertNotEqual(challenge_response.validated_data["hidden_prompt"], "foo")