stages/prompt: set field default based on placeholder, fix duplicate fields

closes #2572

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-03-23 22:18:03 +01:00
parent 969902f503
commit 74ff9d04dd
3 changed files with 18 additions and 22 deletions

View file

@ -94,7 +94,10 @@ def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:
elif isinstance(value, (HttpRequest, WSGIRequest)): elif isinstance(value, (HttpRequest, WSGIRequest)):
continue continue
elif isinstance(value, type): elif isinstance(value, type):
final_dict[key] = value.__module__ + "." + value.__name__ final_dict[key] = {
"type": value.__name__,
"module": value.__module__,
}
else: else:
final_dict[key] = value final_dict[key] = value
return final_dict return final_dict

View file

@ -55,6 +55,7 @@ class PromptChallengeResponse(ChallengeResponse):
stage: PromptStage = kwargs.pop("stage", None) stage: PromptStage = kwargs.pop("stage", None)
plan: FlowPlan = kwargs.pop("plan", None) plan: FlowPlan = kwargs.pop("plan", None)
request: HttpRequest = kwargs.pop("request", None) request: HttpRequest = kwargs.pop("request", None)
user: User = kwargs.pop("user", None)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.stage = stage self.stage = stage
self.plan = plan self.plan = plan
@ -65,7 +66,9 @@ class PromptChallengeResponse(ChallengeResponse):
fields = list(self.stage.fields.all()) fields = list(self.stage.fields.all())
for field in fields: for field in fields:
field: Prompt field: Prompt
current = plan.context.get(PLAN_CONTEXT_PROMPT, {}).get(field.field_key) current = field.get_placeholder(
plan.context.get(PLAN_CONTEXT_PROMPT, {}), user, self.request
)
self.fields[field.field_key] = field.field(current) self.fields[field.field_key] = field.field(current)
# Special handling for fields with username type # Special handling for fields with username type
# these check for existing users with the same username # these check for existing users with the same username
@ -93,21 +96,7 @@ class PromptChallengeResponse(ChallengeResponse):
if len(all_passwords) > 1: if len(all_passwords) > 1:
raise ValidationError(_("Passwords don't match.")) raise ValidationError(_("Passwords don't match."))
def check_empty(self, root: dict) -> dict:
"""Check dictionary recursively for empty"""
for key, value in root.items():
if isinstance(value, dict):
root[key] = self.check_empty(value)
elif isinstance(value, empty) or value == empty:
root[key] = ""
return root
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
# Check if any fields that are allowed to be blank are empty
# and replace with an empty string (currently all fields that support
# allow_blank are string-based)
attrs = self.check_empty(attrs)
# Check if we have any static or hidden fields, and ensure they # Check if we have any static or hidden fields, and ensure they
# still have the same value # still have the same value
static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter( static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter(
@ -115,7 +104,11 @@ class PromptChallengeResponse(ChallengeResponse):
) )
for static_hidden in static_hidden_fields: for static_hidden in static_hidden_fields:
field = self.fields[static_hidden.field_key] field = self.fields[static_hidden.field_key]
attrs[static_hidden.field_key] = field.default default = field.default
# Prevent rest_framework.fields.empty from ending up in policies and events
if default == empty:
default = ""
attrs[static_hidden.field_key] = default
# Check if we have two password fields, and make sure they are the same # Check if we have two password fields, and make sure they are the same
password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD) password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD)
@ -126,7 +119,6 @@ class PromptChallengeResponse(ChallengeResponse):
engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request) engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request)
engine.mode = PolicyEngineMode.MODE_ALL engine.mode = PolicyEngineMode.MODE_ALL
engine.request.context[PLAN_CONTEXT_PROMPT] = attrs engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
engine.request.context.update(attrs)
engine.build() engine.build()
result = engine.result result = engine.result
if not result.passing: if not result.passing:
@ -205,6 +197,7 @@ class PromptStageView(ChallengeStageView):
request=self.request, request=self.request,
stage=self.executor.current_stage, stage=self.executor.current_stage,
plan=self.executor.plan, plan=self.executor.plan,
user=self.get_pending_user(),
) )
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:

View file

@ -129,7 +129,10 @@ class TestPromptStage(FlowTestCase):
def test_valid_challenge_with_policy(self) -> PromptChallengeResponse: def test_valid_challenge_with_policy(self) -> PromptChallengeResponse:
"""Test challenge_response validation""" """Test challenge_response validation"""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
expr = "return request.context['password_prompt'] == request.context['password2_prompt']" expr = (
"return request.context['prompt_data']['password_prompt'] "
"== request.context['prompt_data']['password2_prompt']"
)
expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr) expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr)
self.stage.validation_policies.set([expr_policy]) self.stage.validation_policies.set([expr_policy])
self.stage.save() self.stage.save()
@ -274,9 +277,6 @@ class TestPromptStage(FlowTestCase):
prompt.get_placeholder(context, self.user, self.factory.get("/")), prompt.placeholder prompt.get_placeholder(context, self.user, self.factory.get("/")), prompt.placeholder
) )
def test_field_types(self):
"""Ensure all field types can successfully be created"""
def test_invalid_save(self): def test_invalid_save(self):
"""Ensure field can't be saved with invalid type""" """Ensure field can't be saved with invalid type"""
prompt: Prompt = Prompt( prompt: Prompt = Prompt(