stages/prompt: set field default based on placeholder, fix duplicate fields
closes #2572 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
969902f503
commit
74ff9d04dd
|
@ -94,7 +94,10 @@ def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:
|
||||||
elif isinstance(value, (HttpRequest, WSGIRequest)):
|
elif isinstance(value, (HttpRequest, WSGIRequest)):
|
||||||
continue
|
continue
|
||||||
elif isinstance(value, type):
|
elif isinstance(value, type):
|
||||||
final_dict[key] = value.__module__ + "." + value.__name__
|
final_dict[key] = {
|
||||||
|
"type": value.__name__,
|
||||||
|
"module": value.__module__,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
final_dict[key] = value
|
final_dict[key] = value
|
||||||
return final_dict
|
return final_dict
|
||||||
|
|
|
@ -55,6 +55,7 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||||
stage: PromptStage = kwargs.pop("stage", None)
|
stage: PromptStage = kwargs.pop("stage", None)
|
||||||
plan: FlowPlan = kwargs.pop("plan", None)
|
plan: FlowPlan = kwargs.pop("plan", None)
|
||||||
request: HttpRequest = kwargs.pop("request", None)
|
request: HttpRequest = kwargs.pop("request", None)
|
||||||
|
user: User = kwargs.pop("user", None)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.stage = stage
|
self.stage = stage
|
||||||
self.plan = plan
|
self.plan = plan
|
||||||
|
@ -65,7 +66,9 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||||
fields = list(self.stage.fields.all())
|
fields = list(self.stage.fields.all())
|
||||||
for field in fields:
|
for field in fields:
|
||||||
field: Prompt
|
field: Prompt
|
||||||
current = plan.context.get(PLAN_CONTEXT_PROMPT, {}).get(field.field_key)
|
current = field.get_placeholder(
|
||||||
|
plan.context.get(PLAN_CONTEXT_PROMPT, {}), user, self.request
|
||||||
|
)
|
||||||
self.fields[field.field_key] = field.field(current)
|
self.fields[field.field_key] = field.field(current)
|
||||||
# Special handling for fields with username type
|
# Special handling for fields with username type
|
||||||
# these check for existing users with the same username
|
# these check for existing users with the same username
|
||||||
|
@ -93,21 +96,7 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||||
if len(all_passwords) > 1:
|
if len(all_passwords) > 1:
|
||||||
raise ValidationError(_("Passwords don't match."))
|
raise ValidationError(_("Passwords don't match."))
|
||||||
|
|
||||||
def check_empty(self, root: dict) -> dict:
|
|
||||||
"""Check dictionary recursively for empty"""
|
|
||||||
for key, value in root.items():
|
|
||||||
if isinstance(value, dict):
|
|
||||||
root[key] = self.check_empty(value)
|
|
||||||
elif isinstance(value, empty) or value == empty:
|
|
||||||
root[key] = ""
|
|
||||||
return root
|
|
||||||
|
|
||||||
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
|
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
|
||||||
# Check if any fields that are allowed to be blank are empty
|
|
||||||
# and replace with an empty string (currently all fields that support
|
|
||||||
# allow_blank are string-based)
|
|
||||||
attrs = self.check_empty(attrs)
|
|
||||||
|
|
||||||
# Check if we have any static or hidden fields, and ensure they
|
# Check if we have any static or hidden fields, and ensure they
|
||||||
# still have the same value
|
# still have the same value
|
||||||
static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter(
|
static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter(
|
||||||
|
@ -115,7 +104,11 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||||
)
|
)
|
||||||
for static_hidden in static_hidden_fields:
|
for static_hidden in static_hidden_fields:
|
||||||
field = self.fields[static_hidden.field_key]
|
field = self.fields[static_hidden.field_key]
|
||||||
attrs[static_hidden.field_key] = field.default
|
default = field.default
|
||||||
|
# Prevent rest_framework.fields.empty from ending up in policies and events
|
||||||
|
if default == empty:
|
||||||
|
default = ""
|
||||||
|
attrs[static_hidden.field_key] = default
|
||||||
|
|
||||||
# Check if we have two password fields, and make sure they are the same
|
# Check if we have two password fields, and make sure they are the same
|
||||||
password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD)
|
password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD)
|
||||||
|
@ -126,7 +119,6 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||||
engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request)
|
engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request)
|
||||||
engine.mode = PolicyEngineMode.MODE_ALL
|
engine.mode = PolicyEngineMode.MODE_ALL
|
||||||
engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
|
engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
|
||||||
engine.request.context.update(attrs)
|
|
||||||
engine.build()
|
engine.build()
|
||||||
result = engine.result
|
result = engine.result
|
||||||
if not result.passing:
|
if not result.passing:
|
||||||
|
@ -205,6 +197,7 @@ class PromptStageView(ChallengeStageView):
|
||||||
request=self.request,
|
request=self.request,
|
||||||
stage=self.executor.current_stage,
|
stage=self.executor.current_stage,
|
||||||
plan=self.executor.plan,
|
plan=self.executor.plan,
|
||||||
|
user=self.get_pending_user(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
|
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
|
||||||
|
|
|
@ -129,7 +129,10 @@ class TestPromptStage(FlowTestCase):
|
||||||
def test_valid_challenge_with_policy(self) -> PromptChallengeResponse:
|
def test_valid_challenge_with_policy(self) -> PromptChallengeResponse:
|
||||||
"""Test challenge_response validation"""
|
"""Test challenge_response validation"""
|
||||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||||
expr = "return request.context['password_prompt'] == request.context['password2_prompt']"
|
expr = (
|
||||||
|
"return request.context['prompt_data']['password_prompt'] "
|
||||||
|
"== request.context['prompt_data']['password2_prompt']"
|
||||||
|
)
|
||||||
expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr)
|
expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr)
|
||||||
self.stage.validation_policies.set([expr_policy])
|
self.stage.validation_policies.set([expr_policy])
|
||||||
self.stage.save()
|
self.stage.save()
|
||||||
|
@ -274,9 +277,6 @@ class TestPromptStage(FlowTestCase):
|
||||||
prompt.get_placeholder(context, self.user, self.factory.get("/")), prompt.placeholder
|
prompt.get_placeholder(context, self.user, self.factory.get("/")), prompt.placeholder
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_field_types(self):
|
|
||||||
"""Ensure all field types can successfully be created"""
|
|
||||||
|
|
||||||
def test_invalid_save(self):
|
def test_invalid_save(self):
|
||||||
"""Ensure field can't be saved with invalid type"""
|
"""Ensure field can't be saved with invalid type"""
|
||||||
prompt: Prompt = Prompt(
|
prompt: Prompt = Prompt(
|
||||||
|
|
Reference in a new issue