diff --git a/passbook/flows/planner.py b/passbook/flows/planner.py index 5ed272888..a2e654bf7 100644 --- a/passbook/flows/planner.py +++ b/passbook/flows/planner.py @@ -1,11 +1,13 @@ """Flows Planner""" from dataclasses import dataclass, field from time import time -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple +from django.core.cache import cache from django.http import HttpRequest from structlog import get_logger +from passbook.core.models import User from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException from passbook.flows.models import Flow, Stage from passbook.policies.engine import PolicyEngine @@ -16,6 +18,14 @@ PLAN_CONTEXT_PENDING_USER = "pending_user" PLAN_CONTEXT_SSO = "is_sso" +def cache_key(flow: Flow, user: Optional[User] = None) -> str: + """Generate Cache key for flow""" + prefix = f"flow_{flow.pk}" + if user: + prefix += f"#{user.pk}" + return prefix + + @dataclass class FlowPlan: """This data-class is the output of a FlowPlanner. It holds a flat list @@ -34,9 +44,11 @@ class FlowPlanner: """Execute all policies to plan out a flat list of all Stages that should be applied.""" + use_cache: bool flow: Flow def __init__(self, flow: Flow): + self.use_cache = True self.flow = flow def _check_flow_root_policies(self, request: HttpRequest) -> Tuple[bool, List[str]]: @@ -48,13 +60,17 @@ class FlowPlanner: """Check each of the flows' policies, check policies for each stage with PolicyBinding and return ordered list""" LOGGER.debug("f(plan): Starting planning process", flow=self.flow) - start_time = time() - plan = FlowPlan(flow_pk=self.flow.pk.hex) # First off, check the flow's direct policy bindings # to make sure the user even has access to the flow root_passing, root_passing_messages = self._check_flow_root_policies(request) if not root_passing: raise FlowNonApplicableException(root_passing_messages) + cached_plan = cache.get(cache_key(self.flow, request.user), None) + if cached_plan and self.use_cache: + LOGGER.debug("f(plan): Taking plan from cache", flow=self.flow) + return cached_plan + start_time = time() + plan = FlowPlan(flow_pk=self.flow.pk.hex) # Check Flow policies for stage in ( self.flow.stages.order_by("flowstagebinding__order") @@ -66,7 +82,7 @@ class FlowPlanner: engine.build() passing, _ = engine.result if passing: - LOGGER.debug("f(plan): Stage passing", stage=stage) + LOGGER.debug("f(plan): Stage passing", stage=stage, flow=self.flow) plan.stages.append(stage) end_time = time() LOGGER.debug( @@ -74,6 +90,7 @@ class FlowPlanner: flow=self.flow, duration_s=end_time - start_time, ) + cache.set(cache_key(self.flow, request.user), plan) if not plan.stages: raise EmptyFlowException() return plan diff --git a/passbook/flows/tests/__init__.py b/passbook/flows/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/passbook/flows/tests/test_planner.py b/passbook/flows/tests/test_planner.py new file mode 100644 index 000000000..79c0b6bdb --- /dev/null +++ b/passbook/flows/tests/test_planner.py @@ -0,0 +1,82 @@ +"""flow planner tests""" +from unittest.mock import MagicMock, patch + +from django.shortcuts import reverse +from django.test import RequestFactory, TestCase +from guardian.shortcuts import get_anonymous_user + +from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException +from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding +from passbook.flows.planner import FlowPlanner +from passbook.stages.dummy.models import DummyStage + +POLICY_RESULT_MOCK = MagicMock(return_value=(False, [""],)) +TIME_NOW_MOCK = MagicMock(return_value=3) + + +class TestFlowPlanner(TestCase): + """Test planner logic""" + + def setUp(self): + self.request_factory = RequestFactory() + + def test_empty_plan(self): + """Test that empty plan raises exception""" + flow = Flow.objects.create( + name="test-empty", + slug="test-empty", + designation=FlowDesignation.AUTHENTICATION, + ) + request = self.request_factory.get( + reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}), + ) + request.user = get_anonymous_user() + + with self.assertRaises(EmptyFlowException): + planner = FlowPlanner(flow) + planner.plan(request) + + @patch( + "passbook.flows.planner.FlowPlanner._check_flow_root_policies", + POLICY_RESULT_MOCK, + ) + def test_non_applicable_plan(self): + """Test that empty plan raises exception""" + flow = Flow.objects.create( + name="test-empty", + slug="test-empty", + designation=FlowDesignation.AUTHENTICATION, + ) + request = self.request_factory.get( + reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}), + ) + request.user = get_anonymous_user() + + with self.assertRaises(FlowNonApplicableException): + planner = FlowPlanner(flow) + planner.plan(request) + + @patch("passbook.flows.planner.time", TIME_NOW_MOCK) + def test_planner_cache(self): + """Test planner cache""" + flow = Flow.objects.create( + name="test-cache", + slug="test-cache", + designation=FlowDesignation.AUTHENTICATION, + ) + FlowStageBinding.objects.create( + flow=flow, stage=DummyStage.objects.create(name="dummy"), order=0 + ) + request = self.request_factory.get( + reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}), + ) + request.user = get_anonymous_user() + + planner = FlowPlanner(flow) + planner.plan(request) + self.assertEqual(TIME_NOW_MOCK.call_count, 2) # Start and end + planner = FlowPlanner(flow) + planner.plan(request) + self.assertEqual( + TIME_NOW_MOCK.call_count, 2 + ) # When taking from cache, time is not measured