diff --git a/passbook/stages/prompt/forms.py b/passbook/stages/prompt/forms.py index e15811cc2..ac520c2da 100644 --- a/passbook/stages/prompt/forms.py +++ b/passbook/stages/prompt/forms.py @@ -1,6 +1,9 @@ """Prompt forms""" from django import forms +from guardian.shortcuts import get_anonymous_user +from passbook.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan +from passbook.policies.engine import PolicyEngine from passbook.stages.prompt.models import Prompt, PromptStage @@ -20,10 +23,22 @@ class PromptForm(forms.Form): """Dynamically created form based on PromptStage""" stage: PromptStage + plan: FlowPlan - def __init__(self, stage: PromptStage, *args, **kwargs): + def __init__(self, stage: PromptStage, plan: FlowPlan, *args, **kwargs): self.stage = stage + self.plan = plan super().__init__(*args, **kwargs) for field in self.stage.fields.all(): field: Prompt self.fields[field.field_key] = field.field + + def clean(self): + cleaned_data = super().clean() + user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user()) + engine = PolicyEngine(self.stage.policies.all(), user) + engine.request.context = cleaned_data + engine.build() + passing, messages = engine.result + if not passing: + raise forms.ValidationError(messages) diff --git a/passbook/stages/prompt/stage.py b/passbook/stages/prompt/stage.py index 08c2601a7..c7b256d46 100644 --- a/passbook/stages/prompt/stage.py +++ b/passbook/stages/prompt/stage.py @@ -25,6 +25,7 @@ class PromptStageView(FormView, AuthenticationStage): def get_form_kwargs(self): kwargs = super().get_form_kwargs() kwargs["stage"] = self.executor.current_stage + kwargs["plan"] = self.executor.plan return kwargs def form_valid(self, form: PromptForm) -> HttpResponse: diff --git a/passbook/stages/prompt/tests.py b/passbook/stages/prompt/tests.py index 3551a2e47..d4d2e108d 100644 --- a/passbook/stages/prompt/tests.py +++ b/passbook/stages/prompt/tests.py @@ -8,6 +8,8 @@ from passbook.core.models import User from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding from passbook.flows.planner import FlowPlan from passbook.flows.views import SESSION_KEY_PLAN +from passbook.policies.expression.models import ExpressionPolicy +from passbook.policies.models import PolicyBinding from passbook.stages.prompt.forms import PromptForm from passbook.stages.prompt.models import FieldTypes, Prompt, PromptStage from passbook.stages.prompt.stage import PLAN_CONTEXT_PROMPT @@ -47,6 +49,13 @@ class TestPromptStage(TestCase): required=True, placeholder="PASSWORD_PLACEHOLDER", ) + password2_prompt = Prompt.objects.create( + field_key="password2_prompt", + label="PASSWORD_LABEL", + type=FieldTypes.PASSWORD, + required=True, + placeholder="PASSWORD_PLACEHOLDER", + ) number_prompt = Prompt.objects.create( field_key="number_prompt", label="NUMBER_LABEL", @@ -62,7 +71,14 @@ class TestPromptStage(TestCase): ) self.stage = PromptStage.objects.create(name="prompt-stage") self.stage.fields.set( - [text_prompt, email_prompt, password_prompt, number_prompt, hidden_prompt,] + [ + text_prompt, + email_prompt, + password_prompt, + password2_prompt, + number_prompt, + hidden_prompt, + ] ) self.stage.save() @@ -70,6 +86,7 @@ class TestPromptStage(TestCase): text_prompt.field_key: "test-input", email_prompt.field_key: "test@test.test", password_prompt.field_key: "test", + password2_prompt.field_key: "test", number_prompt.field_key: 3, hidden_prompt.field_key: hidden_prompt.placeholder, } @@ -115,10 +132,30 @@ class TestPromptStage(TestCase): def test_valid_form(self) -> PromptForm: """Test form validation""" - form = PromptForm(stage=self.stage, data=self.prompt_data) + plan = FlowPlan(flow_pk=self.flow.pk.hex, stages=[self.stage]) + expr = ( + "{{ request.context.password_prompt == request.context.password2_prompt }}" + ) + expr_policy = ExpressionPolicy.objects.create( + name="validate-form", expression=expr + ) + PolicyBinding.objects.create(policy=expr_policy, target=self.stage) + form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data) self.assertEqual(form.is_valid(), True) return form + def test_invalid_form(self) -> PromptForm: + """Test form validation""" + plan = FlowPlan(flow_pk=self.flow.pk.hex, stages=[self.stage]) + expr = "False" + expr_policy = ExpressionPolicy.objects.create( + name="validate-form", expression=expr + ) + PolicyBinding.objects.create(policy=expr_policy, target=self.stage) + form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data) + self.assertEqual(form.is_valid(), False) + return form + def test_valid_form_request(self): """Test a request with valid form data""" plan = FlowPlan(flow_pk=self.flow.pk.hex, stages=[self.stage])