flows: add Planner and Executor unittests

This commit is contained in:
Jens Langhammer 2020-05-11 15:01:14 +02:00
parent fc9f86cccc
commit 9814d3be03
6 changed files with 239 additions and 13 deletions

View File

@ -0,0 +1,24 @@
"""miscellaneous flow tests"""
from django.test import TestCase
from passbook.flows.api import StageSerializer, StageViewSet
from passbook.flows.models import Stage
from passbook.stages.dummy.models import DummyStage
class TestFlowsMisc(TestCase):
"""miscellaneous tests"""
def test_models(self):
"""Test that ui_user_settings returns none"""
self.assertIsNone(Stage().ui_user_settings)
def test_api_serializer(self):
"""Test that stage serializer returns the correct type"""
obj = DummyStage()
self.assertEqual(StageSerializer().get_type(obj), "dummy")
def test_api_viewset(self):
"""Test that stage serializer returns the correct type"""
dummy = DummyStage.objects.create()
self.assertIn(dummy, StageViewSet().get_queryset())

View File

@ -0,0 +1,146 @@
"""flow views tests"""
from unittest.mock import MagicMock, patch
from django.shortcuts import reverse
from django.test import Client, TestCase
from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException
from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding
from passbook.flows.planner import FlowPlan
from passbook.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN
from passbook.lib.config import CONFIG
from passbook.stages.dummy.models import DummyStage
POLICY_RESULT_MOCK = MagicMock(return_value=(False, [""],))
class TestFlowExecutor(TestCase):
"""Test views logic"""
def setUp(self):
self.client = Client()
def test_invalid_domain(self):
"""Check that an invalid domain triggers the correct message"""
flow = Flow.objects.create(
name="test-empty",
slug="test-empty",
designation=FlowDesignation.AUTHENTICATION,
)
wrong_domain = CONFIG.y("domain") + "-invalid:8000"
response = self.client.get(
reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}),
HTTP_HOST=wrong_domain,
)
self.assertEqual(response.status_code, 400)
self.assertIn("match", response.rendered_content)
self.assertIn(CONFIG.y("domain"), response.rendered_content)
self.assertIn(wrong_domain.split(":")[0], response.rendered_content)
def test_existing_plan_diff_flow(self):
"""Check that a plan for a different flow cancels the current plan"""
flow = Flow.objects.create(
name="test-existing-plan-diff",
slug="test-existing-plan-diff",
designation=FlowDesignation.AUTHENTICATION,
)
stage = DummyStage.objects.create(name="dummy")
plan = FlowPlan(flow_pk=flow.pk.hex + "a", stages=[stage])
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session.save()
cancel_mock = MagicMock()
with patch("passbook.flows.views.FlowExecutorView.cancel", cancel_mock):
response = self.client.get(
reverse(
"passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}
),
)
self.assertEqual(response.status_code, 400)
self.assertEqual(cancel_mock.call_count, 1)
@patch(
"passbook.flows.planner.FlowPlanner._check_flow_root_policies",
POLICY_RESULT_MOCK,
)
def test_invalid_non_applicable_flow(self):
"""Tests that a non-applicable flow returns the correct error message"""
flow = Flow.objects.create(
name="test-non-applicable",
slug="test-non-applicable",
designation=FlowDesignation.AUTHENTICATION,
)
CONFIG.update_from_dict({"domain": "testserver"})
response = self.client.get(
reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}),
)
self.assertEqual(response.status_code, 400)
self.assertInHTML(FlowNonApplicableException.__doc__, response.rendered_content)
def test_invalid_empty_flow(self):
"""Tests that an empty flow returns the correct error message"""
flow = Flow.objects.create(
name="test-empty",
slug="test-empty",
designation=FlowDesignation.AUTHENTICATION,
)
CONFIG.update_from_dict({"domain": "testserver"})
response = self.client.get(
reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}),
)
self.assertEqual(response.status_code, 400)
self.assertInHTML(EmptyFlowException.__doc__, response.rendered_content)
def test_invalid_flow_redirect(self):
"""Tests that an invalid flow still redirects"""
flow = Flow.objects.create(
name="test-empty",
slug="test-empty",
designation=FlowDesignation.AUTHENTICATION,
)
CONFIG.update_from_dict({"domain": "testserver"})
dest = "/unique-string"
response = self.client.get(
reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug})
+ f"?{NEXT_ARG_NAME}={dest}"
)
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, dest)
def test_multi_stage_flow(self):
"""Test a full flow with multiple stages"""
flow = Flow.objects.create(
name="test-full",
slug="test-full",
designation=FlowDesignation.AUTHENTICATION,
)
FlowStageBinding.objects.create(
flow=flow, stage=DummyStage.objects.create(name="dummy1"), order=0
)
FlowStageBinding.objects.create(
flow=flow, stage=DummyStage.objects.create(name="dummy2"), order=1
)
exec_url = reverse(
"passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}
)
# First Request, start planning, renders form
response = self.client.get(exec_url)
self.assertEqual(response.status_code, 200)
# Check that two stages are in plan
session = self.client.session
plan: FlowPlan = session[SESSION_KEY_PLAN]
self.assertEqual(len(plan.stages), 2)
# Second request, submit form, one stage left
response = self.client.post(exec_url)
# Second request redirects to the same URL
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, exec_url)
# Check that two stages are in plan
session = self.client.session
plan: FlowPlan = session[SESSION_KEY_PLAN]
self.assertEqual(len(plan.stages), 1)

View File

@ -0,0 +1,39 @@
"""flow views tests"""
from django.shortcuts import reverse
from django.test import Client, TestCase
from passbook.flows.models import Flow, FlowDesignation
from passbook.flows.planner import FlowPlan
from passbook.flows.views import SESSION_KEY_PLAN
class TestHelperView(TestCase):
"""Test helper views logic"""
def setUp(self):
self.client = Client()
def test_default_view(self):
"""Test that ToDefaultFlow returns the expected URL"""
flow = Flow.objects.filter(designation=FlowDesignation.INVALIDATION,).first()
response = self.client.get(reverse("passbook_flows:default-invalidation"),)
expected_url = reverse(
"passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}
)
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, expected_url)
def test_default_view_invalid_plan(self):
"""Test that ToDefaultFlow returns the expected URL (with an invalid plan)"""
flow = Flow.objects.filter(designation=FlowDesignation.INVALIDATION,).first()
plan = FlowPlan(flow_pk=flow.pk.hex + "aa", stages=[])
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.get(reverse("passbook_flows:default-invalidation"),)
expected_url = reverse(
"passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}
)
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, expected_url)

View File

@ -12,7 +12,7 @@ from passbook.flows.models import Flow, FlowDesignation, Stage
from passbook.flows.planner import FlowPlan, FlowPlanner from passbook.flows.planner import FlowPlan, FlowPlanner
from passbook.lib.config import CONFIG from passbook.lib.config import CONFIG
from passbook.lib.utils.reflection import class_to_path, path_to_class from passbook.lib.utils.reflection import class_to_path, path_to_class
from passbook.lib.utils.urls import is_url_absolute, redirect_with_qs from passbook.lib.utils.urls import redirect_with_qs
from passbook.lib.views import bad_request_message from passbook.lib.views import bad_request_message
LOGGER = get_logger() LOGGER = get_logger()
@ -59,7 +59,8 @@ class FlowExecutorView(View):
incorrect_domain_message = self._check_config_domain() incorrect_domain_message = self._check_config_domain()
if incorrect_domain_message: if incorrect_domain_message:
return incorrect_domain_message return incorrect_domain_message
return bad_request_message(self.request, str(exc)) message = exc.__doc__ if exc.__doc__ else str(exc)
return bad_request_message(self.request, message)
def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse:
# Early check if theres an active Plan for the current session # Early check if theres an active Plan for the current session
@ -128,10 +129,8 @@ class FlowExecutorView(View):
def _flow_done(self) -> HttpResponse: def _flow_done(self) -> HttpResponse:
"""User Successfully passed all stages""" """User Successfully passed all stages"""
self.cancel() self.cancel()
next_param = self.request.GET.get(NEXT_ARG_NAME, None) next_param = self.request.GET.get(NEXT_ARG_NAME, "passbook_core:overview")
if next_param and not is_url_absolute(next_param): return redirect_with_qs(next_param)
return redirect(next_param)
return redirect_with_qs("passbook_core:overview")
def stage_ok(self) -> HttpResponse: def stage_ok(self) -> HttpResponse:
"""Callback called by stages upon successful completion. """Callback called by stages upon successful completion.
@ -183,9 +182,16 @@ class ToDefaultFlow(View):
designation: Optional[FlowDesignation] = None designation: Optional[FlowDesignation] = None
def dispatch(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
if SESSION_KEY_PLAN in self.request.session:
del self.request.session[SESSION_KEY_PLAN]
flow = get_object_or_404(Flow, designation=self.designation) flow = get_object_or_404(Flow, designation=self.designation)
# If user already has a pending plan, clear it so we don't have to later.
if SESSION_KEY_PLAN in self.request.session:
plan: FlowPlan = self.request.session[SESSION_KEY_PLAN]
if plan.flow_pk != flow.pk.hex:
LOGGER.warning(
"f(def): Found existing plan for other flow, deleteing plan",
flow_slug=flow.slug,
)
del self.request.session[SESSION_KEY_PLAN]
# TODO: Get Flow depending on subdomain? # TODO: Get Flow depending on subdomain?
return redirect_with_qs( return redirect_with_qs(
"passbook_flows:flow-executor", request.GET, flow_slug=flow.slug "passbook_flows:flow-executor", request.GET, flow_slug=flow.slug

View File

@ -3,7 +3,11 @@ from urllib.parse import urlparse
from django.http import HttpResponse from django.http import HttpResponse
from django.shortcuts import redirect, reverse from django.shortcuts import redirect, reverse
from django.urls import NoReverseMatch
from django.utils.http import urlencode from django.utils.http import urlencode
from structlog import get_logger
LOGGER = get_logger()
def is_url_absolute(url): def is_url_absolute(url):
@ -13,7 +17,12 @@ def is_url_absolute(url):
def redirect_with_qs(view: str, get_query_set=None, **kwargs) -> HttpResponse: def redirect_with_qs(view: str, get_query_set=None, **kwargs) -> HttpResponse:
"""Wrapper to redirect whilst keeping GET Parameters""" """Wrapper to redirect whilst keeping GET Parameters"""
try:
target = reverse(view, kwargs=kwargs) target = reverse(view, kwargs=kwargs)
except NoReverseMatch:
LOGGER.debug("redirect target is not a valid view", view=view)
raise
else:
if get_query_set: if get_query_set:
target += "?" + urlencode(get_query_set.items()) target += "?" + urlencode(get_query_set.items())
return redirect(target) return redirect(target)

View File

@ -6,6 +6,7 @@ from django.shortcuts import reverse
from django.test import TestCase from django.test import TestCase
from passbook.core.models import Nonce, User from passbook.core.models import Nonce, User
from passbook.lib.config import CONFIG
class TestRecovery(TestCase): class TestRecovery(TestCase):
@ -16,10 +17,11 @@ class TestRecovery(TestCase):
def test_create_key(self): def test_create_key(self):
"""Test creation of a new key""" """Test creation of a new key"""
CONFIG.update_from_dict({"domain": "testserver"})
out = StringIO() out = StringIO()
self.assertEqual(len(Nonce.objects.all()), 0) self.assertEqual(len(Nonce.objects.all()), 0)
call_command("create_recovery_key", "1", self.user.username, stdout=out) call_command("create_recovery_key", "1", self.user.username, stdout=out)
self.assertIn("https://localhost/recovery/use-nonce/", out.getvalue()) self.assertIn("https://testserver/recovery/use-nonce/", out.getvalue())
self.assertEqual(len(Nonce.objects.all()), 1) self.assertEqual(len(Nonce.objects.all()), 1)
def test_recovery_view(self): def test_recovery_view(self):