stages/prompt: set field default based on placeholder, fix duplicate fields
closes #2572 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
969902f503
commit
74ff9d04dd
|
@ -94,7 +94,10 @@ def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:
|
|||
elif isinstance(value, (HttpRequest, WSGIRequest)):
|
||||
continue
|
||||
elif isinstance(value, type):
|
||||
final_dict[key] = value.__module__ + "." + value.__name__
|
||||
final_dict[key] = {
|
||||
"type": value.__name__,
|
||||
"module": value.__module__,
|
||||
}
|
||||
else:
|
||||
final_dict[key] = value
|
||||
return final_dict
|
||||
|
|
|
@ -55,6 +55,7 @@ class PromptChallengeResponse(ChallengeResponse):
|
|||
stage: PromptStage = kwargs.pop("stage", None)
|
||||
plan: FlowPlan = kwargs.pop("plan", None)
|
||||
request: HttpRequest = kwargs.pop("request", None)
|
||||
user: User = kwargs.pop("user", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.stage = stage
|
||||
self.plan = plan
|
||||
|
@ -65,7 +66,9 @@ class PromptChallengeResponse(ChallengeResponse):
|
|||
fields = list(self.stage.fields.all())
|
||||
for field in fields:
|
||||
field: Prompt
|
||||
current = plan.context.get(PLAN_CONTEXT_PROMPT, {}).get(field.field_key)
|
||||
current = field.get_placeholder(
|
||||
plan.context.get(PLAN_CONTEXT_PROMPT, {}), user, self.request
|
||||
)
|
||||
self.fields[field.field_key] = field.field(current)
|
||||
# Special handling for fields with username type
|
||||
# these check for existing users with the same username
|
||||
|
@ -93,21 +96,7 @@ class PromptChallengeResponse(ChallengeResponse):
|
|||
if len(all_passwords) > 1:
|
||||
raise ValidationError(_("Passwords don't match."))
|
||||
|
||||
def check_empty(self, root: dict) -> dict:
|
||||
"""Check dictionary recursively for empty"""
|
||||
for key, value in root.items():
|
||||
if isinstance(value, dict):
|
||||
root[key] = self.check_empty(value)
|
||||
elif isinstance(value, empty) or value == empty:
|
||||
root[key] = ""
|
||||
return root
|
||||
|
||||
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
|
||||
# Check if any fields that are allowed to be blank are empty
|
||||
# and replace with an empty string (currently all fields that support
|
||||
# allow_blank are string-based)
|
||||
attrs = self.check_empty(attrs)
|
||||
|
||||
# Check if we have any static or hidden fields, and ensure they
|
||||
# still have the same value
|
||||
static_hidden_fields: QuerySet[Prompt] = self.stage.fields.filter(
|
||||
|
@ -115,7 +104,11 @@ class PromptChallengeResponse(ChallengeResponse):
|
|||
)
|
||||
for static_hidden in static_hidden_fields:
|
||||
field = self.fields[static_hidden.field_key]
|
||||
attrs[static_hidden.field_key] = field.default
|
||||
default = field.default
|
||||
# Prevent rest_framework.fields.empty from ending up in policies and events
|
||||
if default == empty:
|
||||
default = ""
|
||||
attrs[static_hidden.field_key] = default
|
||||
|
||||
# Check if we have two password fields, and make sure they are the same
|
||||
password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD)
|
||||
|
@ -126,7 +119,6 @@ class PromptChallengeResponse(ChallengeResponse):
|
|||
engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request)
|
||||
engine.mode = PolicyEngineMode.MODE_ALL
|
||||
engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
|
||||
engine.request.context.update(attrs)
|
||||
engine.build()
|
||||
result = engine.result
|
||||
if not result.passing:
|
||||
|
@ -205,6 +197,7 @@ class PromptStageView(ChallengeStageView):
|
|||
request=self.request,
|
||||
stage=self.executor.current_stage,
|
||||
plan=self.executor.plan,
|
||||
user=self.get_pending_user(),
|
||||
)
|
||||
|
||||
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
|
||||
|
|
|
@ -129,7 +129,10 @@ class TestPromptStage(FlowTestCase):
|
|||
def test_valid_challenge_with_policy(self) -> PromptChallengeResponse:
|
||||
"""Test challenge_response validation"""
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
expr = "return request.context['password_prompt'] == request.context['password2_prompt']"
|
||||
expr = (
|
||||
"return request.context['prompt_data']['password_prompt'] "
|
||||
"== request.context['prompt_data']['password2_prompt']"
|
||||
)
|
||||
expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr)
|
||||
self.stage.validation_policies.set([expr_policy])
|
||||
self.stage.save()
|
||||
|
@ -274,9 +277,6 @@ class TestPromptStage(FlowTestCase):
|
|||
prompt.get_placeholder(context, self.user, self.factory.get("/")), prompt.placeholder
|
||||
)
|
||||
|
||||
def test_field_types(self):
|
||||
"""Ensure all field types can successfully be created"""
|
||||
|
||||
def test_invalid_save(self):
|
||||
"""Ensure field can't be saved with invalid type"""
|
||||
prompt: Prompt = Prompt(
|
||||
|
|
Reference in a new issue