diff --git a/passbook/admin/forms/policies.py b/passbook/admin/forms/policies.py index 9bca84a98..27a932eaa 100644 --- a/passbook/admin/forms/policies.py +++ b/passbook/admin/forms/policies.py @@ -1,6 +1,7 @@ """passbook administration forms""" from django import forms +from passbook.admin.fields import CodeMirrorWidget, YAMLField from passbook.core.models import User @@ -8,3 +9,4 @@ class PolicyTestForm(forms.Form): """Form to test policies against user""" user = forms.ModelChoiceField(queryset=User.objects.all()) + context = YAMLField(widget=CodeMirrorWidget()) diff --git a/passbook/admin/views/policies.py b/passbook/admin/views/policies.py index e8a009bd2..ab8278544 100644 --- a/passbook/admin/views/policies.py +++ b/passbook/admin/views/policies.py @@ -1,11 +1,15 @@ """passbook Policy administration""" +from typing import Any, Dict + from django.contrib import messages from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import ( PermissionRequiredMixin as DjangoPermissionRequiredMixin, ) from django.contrib.messages.views import SuccessMessageMixin -from django.http import Http404 +from django.db.models import QuerySet +from django.forms import Form +from django.http import Http404, HttpRequest, HttpResponse from django.urls import reverse_lazy from django.utils.translation import ugettext as _ from django.views.generic import DeleteView, FormView, ListView, UpdateView @@ -15,8 +19,8 @@ from guardian.mixins import PermissionListMixin, PermissionRequiredMixin from passbook.admin.forms.policies import PolicyTestForm from passbook.lib.utils.reflection import all_subclasses, path_to_class from passbook.lib.views import CreateAssignPermView -from passbook.policies.engine import PolicyEngine -from passbook.policies.models import Policy +from passbook.policies.models import Policy, PolicyBinding +from passbook.policies.process import PolicyProcess, PolicyRequest class PolicyListView(LoginRequiredMixin, PermissionListMixin, ListView): @@ -25,14 +29,14 @@ class PolicyListView(LoginRequiredMixin, PermissionListMixin, ListView): model = Policy permission_required = "passbook_policies.view_policy" paginate_by = 10 - ordering = "order" + ordering = "name" template_name = "administration/policy/list.html" - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: kwargs["types"] = {x.__name__: x for x in all_subclasses(Policy)} return super().get_context_data(**kwargs) - def get_queryset(self): + def get_queryset(self) -> QuerySet: return super().get_queryset().select_subclasses() @@ -51,14 +55,14 @@ class PolicyCreateView( success_url = reverse_lazy("passbook_admin:policies") success_message = _("Successfully created Policy") - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: kwargs = super().get_context_data(**kwargs) form_cls = self.get_form_class() if hasattr(form_cls, "template_name"): kwargs["base_template"] = form_cls.template_name return kwargs - def get_form_class(self): + def get_form_class(self) -> Form: policy_type = self.request.GET.get("type") try: model = next(x for x in all_subclasses(Policy) if x.__name__ == policy_type) @@ -79,19 +83,19 @@ class PolicyUpdateView( success_url = reverse_lazy("passbook_admin:policies") success_message = _("Successfully updated Policy") - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: kwargs = super().get_context_data(**kwargs) form_cls = self.get_form_class() if hasattr(form_cls, "template_name"): kwargs["base_template"] = form_cls.template_name return kwargs - def get_form_class(self): + def get_form_class(self) -> Form: form_class_path = self.get_object().form form_class = path_to_class(form_class_path) return form_class - def get_object(self, queryset=None): + def get_object(self, queryset=None) -> Policy: return ( Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first() ) @@ -109,12 +113,12 @@ class PolicyDeleteView( success_url = reverse_lazy("passbook_admin:policies") success_message = _("Successfully deleted Policy") - def get_object(self, queryset=None): + def get_object(self, queryset=None) -> Policy: return ( Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first() ) - def delete(self, request, *args, **kwargs): + def delete(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: messages.success(self.request, self.success_message) return super().delete(request, *args, **kwargs) @@ -128,26 +132,29 @@ class PolicyTestView(LoginRequiredMixin, DetailView, PermissionRequiredMixin, Fo template_name = "administration/policy/test.html" object = None - def get_object(self, queryset=None): + def get_object(self, queryset=None) -> QuerySet: return ( Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first() ) - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: kwargs["policy"] = self.get_object() return super().get_context_data(**kwargs) - def post(self, *args, **kwargs): + def post(self, *args, **kwargs) -> HttpResponse: self.object = self.get_object() return super().post(*args, **kwargs) - def form_valid(self, form): + def form_valid(self, form: PolicyTestForm) -> HttpResponse: policy = self.get_object() user = form.cleaned_data.get("user") - policy_engine = PolicyEngine([policy], user, self.request) - policy_engine.use_cache = False - policy_engine.build() - result = policy_engine.passing + + p_request = PolicyRequest(user) + p_request.http_request = self.request + p_request.context = form.cleaned_data + + proc = PolicyProcess(PolicyBinding(policy=policy), p_request, None) + result = proc.execute() if result: messages.success(self.request, _("User successfully passed policy.")) else: diff --git a/passbook/core/signals.py b/passbook/core/signals.py index c3a50b4cc..01299f90e 100644 --- a/passbook/core/signals.py +++ b/passbook/core/signals.py @@ -17,12 +17,15 @@ password_changed = Signal(providing_args=["user", "password"]) # pylint: disable=unused-argument def invalidate_policy_cache(sender, instance, **_): """Invalidate Policy cache when policy is updated""" - from passbook.policies.models import Policy + from passbook.policies.models import Policy, PolicyBinding from passbook.policies.process import cache_key if isinstance(instance, Policy): LOGGER.debug("Invalidating policy cache", policy=instance) - prefix = cache_key(instance) + "*" - keys = cache.keys(prefix) - cache.delete_many(keys) - LOGGER.debug("Deleted %d keys", len(keys)) + total = 0 + for binding in PolicyBinding.objects.filter(policy=instance): + prefix = cache_key(binding) + "*" + keys = cache.keys(prefix) + total += len(keys) + cache.delete_many(keys) + LOGGER.debug("Deleted keys", len=total) diff --git a/passbook/core/views/access.py b/passbook/core/views/access.py index ff3888645..c2e07dd19 100644 --- a/passbook/core/views/access.py +++ b/passbook/core/views/access.py @@ -1,6 +1,4 @@ """passbook access helper classes""" -from typing import List, Tuple - from django.contrib import messages from django.http import HttpRequest from django.utils.translation import gettext as _ @@ -8,6 +6,7 @@ from structlog import get_logger from passbook.core.models import Application, Provider, User from passbook.policies.engine import PolicyEngine +from passbook.policies.types import PolicyResult LOGGER = get_logger() @@ -33,9 +32,7 @@ class AccessMixin: ) raise exc - def user_has_access( - self, application: Application, user: User - ) -> Tuple[bool, List[str]]: + def user_has_access(self, application: Application, user: User) -> PolicyResult: """Check if user has access to application.""" LOGGER.debug("Checking permissions", user=user, application=application) policy_engine = PolicyEngine(application.policies.all(), user, self.request) diff --git a/passbook/flows/planner.py b/passbook/flows/planner.py index a2d0dc05f..f03f8e6f9 100644 --- a/passbook/flows/planner.py +++ b/passbook/flows/planner.py @@ -1,7 +1,7 @@ """Flows Planner""" from dataclasses import dataclass, field from time import time -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional from django.core.cache import cache from django.http import HttpRequest @@ -11,6 +11,7 @@ from passbook.core.models import User from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException from passbook.flows.models import Flow, Stage from passbook.policies.engine import PolicyEngine +from passbook.policies.types import PolicyResult LOGGER = get_logger() @@ -51,8 +52,8 @@ class FlowPlanner: self.use_cache = True self.flow = flow - def _check_flow_root_policies(self, request: HttpRequest) -> Tuple[bool, List[str]]: - engine = PolicyEngine(self.flow.policies.all(), request.user, request) + def _check_flow_root_policies(self, request: HttpRequest) -> PolicyResult: + engine = PolicyEngine(self.flow, request.user, request) engine.build() return engine.result @@ -64,9 +65,9 @@ class FlowPlanner: LOGGER.debug("f(plan): Starting planning process", flow=self.flow) # First off, check the flow's direct policy bindings # to make sure the user even has access to the flow - root_passing, root_passing_messages = self._check_flow_root_policies(request) - if not root_passing: - raise FlowNonApplicableException(root_passing_messages) + root_result = self._check_flow_root_policies(request) + if not root_result.passing: + raise FlowNonApplicableException(*root_result.messages) # Bit of a workaround here, if there is a pending user set in the default context # we use that user for our cache key # to make sure they don't get the generic response @@ -106,11 +107,10 @@ class FlowPlanner: .select_related() ): binding = stage.flowstagebinding_set.get(flow__pk=self.flow.pk) - engine = PolicyEngine(binding.policies.all(), user, request) + engine = PolicyEngine(binding, user, request) engine.request.context = plan.context engine.build() - passing, _ = engine.result - if passing: + if engine.passing: LOGGER.debug("f(plan): Stage passing", stage=stage, flow=self.flow) plan.stages.append(stage) end_time = time() diff --git a/passbook/flows/tests/test_planner.py b/passbook/flows/tests/test_planner.py index 79c0b6bdb..25f2bd1a7 100644 --- a/passbook/flows/tests/test_planner.py +++ b/passbook/flows/tests/test_planner.py @@ -8,9 +8,10 @@ from guardian.shortcuts import get_anonymous_user from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding from passbook.flows.planner import FlowPlanner +from passbook.policies.types import PolicyResult from passbook.stages.dummy.models import DummyStage -POLICY_RESULT_MOCK = MagicMock(return_value=(False, [""],)) +POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False)) TIME_NOW_MOCK = MagicMock(return_value=3) diff --git a/passbook/flows/tests/test_views.py b/passbook/flows/tests/test_views.py index f9979e4d9..e6a2ad20c 100644 --- a/passbook/flows/tests/test_views.py +++ b/passbook/flows/tests/test_views.py @@ -9,9 +9,10 @@ from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding from passbook.flows.planner import FlowPlan from passbook.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN from passbook.lib.config import CONFIG +from passbook.policies.types import PolicyResult from passbook.stages.dummy.models import DummyStage -POLICY_RESULT_MOCK = MagicMock(return_value=(False, [""],)) +POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False)) class TestFlowExecutor(TestCase): diff --git a/passbook/lib/models.py b/passbook/lib/models.py index d6036dd0e..0966dd8ac 100644 --- a/passbook/lib/models.py +++ b/passbook/lib/models.py @@ -1,5 +1,6 @@ """Generic models""" from django.db import models +from model_utils.managers import InheritanceManager class CreatedUpdatedModel(models.Model): @@ -10,3 +11,27 @@ class CreatedUpdatedModel(models.Model): class Meta: abstract = True + + +class InheritanceAutoManager(InheritanceManager): + """Object manager which automatically selects the subclass""" + + def get_queryset(self): + return super().get_queryset().select_subclasses() + + +class InheritanceForwardManyToOneDescriptor( + models.fields.related.ForwardManyToOneDescriptor +): + """Forward ManyToOne Descriptor that selects subclass. Requires InheritanceAutoManager.""" + + def get_queryset(self, **hints): + return self.field.remote_field.model.objects.db_manager( + hints=hints + ).select_subclasses() + + +class InheritanceForeignKey(models.ForeignKey): + """Custom ForeignKey that uses InheritanceForwardManyToOneDescriptor""" + + forward_related_accessor_class = InheritanceForwardManyToOneDescriptor diff --git a/passbook/policies/api.py b/passbook/policies/api.py index ebae9a3a0..27fd129ac 100644 --- a/passbook/policies/api.py +++ b/passbook/policies/api.py @@ -12,7 +12,7 @@ class PolicyBindingSerializer(ModelSerializer): class Meta: model = PolicyBinding - fields = ["policy", "target", "enabled", "order"] + fields = ["policy", "target", "enabled", "order", "timeout"] class PolicyBindingViewSet(ModelViewSet): diff --git a/passbook/policies/engine.py b/passbook/policies/engine.py index 24aaae59f..143ad6473 100644 --- a/passbook/policies/engine.py +++ b/passbook/policies/engine.py @@ -1,14 +1,14 @@ """passbook policy engine""" from multiprocessing import Pipe, set_start_method from multiprocessing.connection import Connection -from typing import List, Optional, Tuple +from typing import List, Optional from django.core.cache import cache from django.http import HttpRequest from structlog import get_logger from passbook.core.models import User -from passbook.policies.models import Policy +from passbook.policies.models import Policy, PolicyBinding, PolicyBindingModel from passbook.policies.process import PolicyProcess, cache_key from passbook.policies.types import PolicyRequest, PolicyResult @@ -24,12 +24,14 @@ class PolicyProcessInfo: process: PolicyProcess connection: Connection result: Optional[PolicyResult] - policy: Policy + binding: PolicyBinding - def __init__(self, process: PolicyProcess, connection: Connection, policy: Policy): + def __init__( + self, process: PolicyProcess, connection: Connection, binding: PolicyBinding + ): self.process = process self.connection = connection - self.policy = policy + self.binding = binding self.result = None @@ -37,54 +39,64 @@ class PolicyEngine: """Orchestrate policy checking, launch tasks and return result""" use_cache: bool = True - policies: List[Policy] = [] request: PolicyRequest + __pbm: PolicyBindingModel __cached_policies: List[PolicyResult] __processes: List[PolicyProcessInfo] - def __init__(self, policies, user: User, request: HttpRequest = None): - self.policies = policies + def __init__( + self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None + ): + if not isinstance(pbm, PolicyBindingModel): + raise ValueError(f"{pbm} is not instance of PolicyBindingModel") + self.__pbm = pbm self.request = PolicyRequest(user) if request: self.request.http_request = request self.__cached_policies = [] self.__processes = [] - def _select_subclasses(self) -> List[Policy]: + def _iter_bindings(self) -> List[PolicyBinding]: """Make sure all Policies are their respective classes""" - return ( - Policy.objects.filter(pk__in=[x.pk for x in self.policies]) - .select_subclasses() - .order_by("order") + return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by( + "order" ) + def _check_policy_type(self, policy: Policy): + """Check policy type, make sure it's not the root class as that has no logic implemented""" + # policy_type = type(policy) + if policy.__class__ == Policy: + raise TypeError(f"Policy '{policy}' is root type") + def build(self) -> "PolicyEngine": """Build task group""" - for policy in self._select_subclasses(): - cached_policy = cache.get(cache_key(policy, self.request.user), None) + for binding in self._iter_bindings(): + self._check_policy_type(binding.policy) + policy = binding.policy + cached_policy = cache.get(cache_key(binding, self.request.user), None) if cached_policy and self.use_cache: LOGGER.debug("P_ENG: Taking result from cache", policy=policy) self.__cached_policies.append(cached_policy) continue LOGGER.debug("P_ENG: Evaluating policy", policy=policy) our_end, task_end = Pipe(False) - task = PolicyProcess(policy, self.request, task_end) + task = PolicyProcess(binding, self.request, task_end) LOGGER.debug("P_ENG: Starting Process", policy=policy) task.start() self.__processes.append( - PolicyProcessInfo(process=task, connection=our_end, policy=policy) + PolicyProcessInfo(process=task, connection=our_end, binding=binding) ) # If all policies are cached, we have an empty list here. for proc_info in self.__processes: - proc_info.process.join(proc_info.policy.timeout) + proc_info.process.join(proc_info.binding.timeout) # Only call .recv() if no result is saved, otherwise we just deadlock here if not proc_info.result: proc_info.result = proc_info.connection.recv() return self @property - def result(self) -> Tuple[bool, List[str]]: + def result(self) -> PolicyResult: """Get policy-checking result""" messages: List[str] = [] process_results: List[PolicyResult] = [ @@ -95,10 +107,10 @@ class PolicyEngine: if result.messages: messages += result.messages if not result.passing: - return False, messages - return True, messages + return PolicyResult(False, *messages) + return PolicyResult(True, *messages) @property def passing(self) -> bool: """Only get true/false if user passes""" - return self.result[0] + return self.result.passing diff --git a/passbook/policies/expression/evaluator.py b/passbook/policies/expression/evaluator.py index 2b31f4671..b2120bb74 100644 --- a/passbook/policies/expression/evaluator.py +++ b/passbook/policies/expression/evaluator.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional from django.core.exceptions import ValidationError from jinja2 import Undefined -from jinja2.exceptions import TemplateSyntaxError, UndefinedError +from jinja2.exceptions import TemplateSyntaxError from jinja2.nativetypes import NativeEnvironment from requests import Session from structlog import get_logger @@ -90,7 +90,8 @@ class Evaluator: if result: return PolicyResult(bool(result)) return PolicyResult(False) - except UndefinedError as exc: + except Exception as exc: # pylint: disable=broad-except + LOGGER.warning("Expression error", exc=exc) return PolicyResult(False, str(exc)) def validate(self, expression: str): diff --git a/passbook/policies/forms.py b/passbook/policies/forms.py index 8bcdbb9e0..0378e2c58 100644 --- a/passbook/policies/forms.py +++ b/passbook/policies/forms.py @@ -3,8 +3,8 @@ from django import forms from passbook.policies.models import PolicyBinding, PolicyBindingModel -GENERAL_FIELDS = ["name", "negate", "order", "timeout"] -GENERAL_SERIALIZER_FIELDS = ["pk", "name", "negate", "order", "timeout"] +GENERAL_FIELDS = ["name"] +GENERAL_SERIALIZER_FIELDS = ["pk", "name"] class PolicyBindingForm(forms.ModelForm): @@ -18,9 +18,4 @@ class PolicyBindingForm(forms.ModelForm): class Meta: model = PolicyBinding - fields = [ - "enabled", - "policy", - "target", - "order", - ] + fields = ["enabled", "policy", "target", "order", "timeout"] diff --git a/passbook/policies/migrations/0002_auto_20200528_1647.py b/passbook/policies/migrations/0002_auto_20200528_1647.py new file mode 100644 index 000000000..b43a2f732 --- /dev/null +++ b/passbook/policies/migrations/0002_auto_20200528_1647.py @@ -0,0 +1,58 @@ +# Generated by Django 3.0.6 on 2020-05-28 16:47 + +import django.db.models.deletion +from django.db import migrations, models + +import passbook.lib.models + + +class Migration(migrations.Migration): + + dependencies = [ + ("passbook_policies", "0001_initial"), + ] + + operations = [ + migrations.AlterModelOptions( + name="policy", + options={ + "base_manager_name": "objects", + "verbose_name": "Policy", + "verbose_name_plural": "Policies", + }, + ), + migrations.RemoveField(model_name="policy", name="negate",), + migrations.RemoveField(model_name="policy", name="order",), + migrations.RemoveField(model_name="policy", name="timeout",), + migrations.AddField( + model_name="policybinding", + name="negate", + field=models.BooleanField( + default=False, + help_text="Negates the outcome of the policy. Messages are unaffected.", + ), + ), + migrations.AddField( + model_name="policybinding", + name="timeout", + field=models.IntegerField( + default=30, + help_text="Timeout after which Policy execution is terminated.", + ), + ), + migrations.AlterField( + model_name="policybinding", name="order", field=models.IntegerField(), + ), + migrations.AlterField( + model_name="policybinding", + name="policy", + field=passbook.lib.models.InheritanceForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to="passbook_policies.Policy", + ), + ), + migrations.AlterUniqueTogether( + name="policybinding", unique_together={("policy", "target", "order")}, + ), + ] diff --git a/passbook/policies/models.py b/passbook/policies/models.py index b47dc9ecd..826938f51 100644 --- a/passbook/policies/models.py +++ b/passbook/policies/models.py @@ -5,7 +5,11 @@ from django.db import models from django.utils.translation import gettext_lazy as _ from model_utils.managers import InheritanceManager -from passbook.lib.models import CreatedUpdatedModel +from passbook.lib.models import ( + CreatedUpdatedModel, + InheritanceAutoManager, + InheritanceForeignKey, +) from passbook.policies.exceptions import PolicyException from passbook.policies.types import PolicyRequest, PolicyResult @@ -22,7 +26,6 @@ class PolicyBindingModel(models.Model): objects = InheritanceManager() class Meta: - verbose_name = _("Policy Binding Model") verbose_name_plural = _("Policy Binding Models") @@ -36,13 +39,19 @@ class PolicyBinding(models.Model): enabled = models.BooleanField(default=True) - policy = models.ForeignKey("Policy", on_delete=models.CASCADE, related_name="+") + policy = InheritanceForeignKey("Policy", on_delete=models.CASCADE, related_name="+") target = models.ForeignKey( PolicyBindingModel, on_delete=models.CASCADE, related_name="+" ) + negate = models.BooleanField( + default=False, + help_text=_("Negates the outcome of the policy. Messages are unaffected."), + ) + timeout = models.IntegerField( + default=30, help_text=_("Timeout after which Policy execution is terminated.") + ) - # default value and non-unique for compatibility - order = models.IntegerField(default=0) + order = models.IntegerField() def __str__(self) -> str: return f"PolicyBinding policy={self.policy} target={self.target} order={self.order}" @@ -51,6 +60,7 @@ class PolicyBinding(models.Model): verbose_name = _("Policy Binding") verbose_name_plural = _("Policy Bindings") + unique_together = ("policy", "target", "order") class Policy(CreatedUpdatedModel): @@ -60,11 +70,8 @@ class Policy(CreatedUpdatedModel): policy_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) name = models.TextField(blank=True, null=True) - negate = models.BooleanField(default=False) - order = models.IntegerField(default=0) - timeout = models.IntegerField(default=30) - objects = InheritanceManager() + objects = InheritanceAutoManager() def __str__(self): return f"Policy {self.name}" @@ -72,3 +79,9 @@ class Policy(CreatedUpdatedModel): def passes(self, request: PolicyRequest) -> PolicyResult: """Check if user instance passes this policy""" raise PolicyException() + + class Meta: + base_manager_name = "objects" + + verbose_name = _("Policy") + verbose_name_plural = _("Policies") diff --git a/passbook/policies/process.py b/passbook/policies/process.py index f6c2a1265..1fb906c9f 100644 --- a/passbook/policies/process.py +++ b/passbook/policies/process.py @@ -8,15 +8,15 @@ from structlog import get_logger from passbook.core.models import User from passbook.policies.exceptions import PolicyException -from passbook.policies.models import Policy +from passbook.policies.models import PolicyBinding from passbook.policies.types import PolicyRequest, PolicyResult LOGGER = get_logger() -def cache_key(policy: Policy, user: Optional[User] = None) -> str: +def cache_key(binding: PolicyBinding, user: Optional[User] = None) -> str: """Generate Cache key for policy""" - prefix = f"policy_{policy.pk}" + prefix = f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}" if user: prefix += f"#{user.pk}" return prefix @@ -26,40 +26,50 @@ class PolicyProcess(Process): """Evaluate a single policy within a seprate process""" connection: Connection - policy: Policy + binding: PolicyBinding request: PolicyRequest - def __init__(self, policy: Policy, request: PolicyRequest, connection: Connection): + def __init__( + self, + binding: PolicyBinding, + request: PolicyRequest, + connection: Optional[Connection], + ): super().__init__() - self.policy = policy + self.binding = binding self.request = request - self.connection = connection + if connection: + self.connection = connection - def run(self): - """Task wrapper to run policy checking""" + def execute(self) -> PolicyResult: + """Run actual policy, returns result""" LOGGER.debug( "P_ENG(proc): Running policy", - policy=self.policy, + policy=self.binding.policy, user=self.request.user, process="PolicyProcess", ) try: - policy_result = self.policy.passes(self.request) + policy_result = self.binding.policy.passes(self.request) except PolicyException as exc: LOGGER.debug("P_ENG(proc): error", exc=exc) policy_result = PolicyResult(False, str(exc)) # Invert result if policy.negate is set - if self.policy.negate: + if self.binding.negate: policy_result.passing = not policy_result.passing LOGGER.debug( "P_ENG(proc): Finished", - policy=self.policy, + policy=self.binding.policy, result=policy_result, process="PolicyProcess", passing=policy_result.passing, user=self.request.user, ) - key = cache_key(self.policy, self.request.user) + key = cache_key(self.binding, self.request.user) cache.set(key, policy_result) LOGGER.debug("P_ENG(proc): Cached policy evaluation", key=key) - self.connection.send(policy_result) + return policy_result + + def run(self): + """Task wrapper to run policy checking""" + self.connection.send(self.execute()) diff --git a/passbook/policies/tests/test_engine.py b/passbook/policies/tests/test_engine.py index 0f4c63a83..05537e432 100644 --- a/passbook/policies/tests/test_engine.py +++ b/passbook/policies/tests/test_engine.py @@ -5,7 +5,8 @@ from django.test import TestCase from passbook.core.models import User from passbook.policies.dummy.models import DummyPolicy from passbook.policies.engine import PolicyEngine -from passbook.policies.models import Policy +from passbook.policies.expression.models import ExpressionPolicy +from passbook.policies.models import Policy, PolicyBinding, PolicyBindingModel class PolicyTestEngine(TestCase): @@ -20,40 +21,64 @@ class PolicyTestEngine(TestCase): self.policy_true = DummyPolicy.objects.create( result=True, wait_min=0, wait_max=1 ) - self.policy_negate = DummyPolicy.objects.create( - negate=True, result=True, wait_min=0, wait_max=1 + self.policy_wrong_type = Policy.objects.create(name="wrong_type") + self.policy_raises = ExpressionPolicy.objects.create( + name="raises", expression="{{ 0/0 }}" ) - self.policy_raises = Policy.objects.create(name="raises") def test_engine_empty(self): """Ensure empty policy list passes""" - engine = PolicyEngine([], self.user) - self.assertEqual(engine.build().passing, True) + pbm = PolicyBindingModel.objects.create() + engine = PolicyEngine(pbm, self.user) + result = engine.build().result + self.assertEqual(result.passing, True) + self.assertEqual(result.messages, ()) def test_engine(self): """Ensure all policies passes (Mix of false and true -> false)""" - engine = PolicyEngine( - DummyPolicy.objects.filter(negate__exact=False), self.user - ) - self.assertEqual(engine.build().passing, False) + pbm = PolicyBindingModel.objects.create() + PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0) + PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1) + engine = PolicyEngine(pbm, self.user) + result = engine.build().result + self.assertEqual(result.passing, False) + self.assertEqual(result.messages, ("dummy",)) def test_engine_negate(self): """Test negate flag""" - engine = PolicyEngine(DummyPolicy.objects.filter(negate__exact=True), self.user) - self.assertEqual(engine.build().passing, False) + pbm = PolicyBindingModel.objects.create() + PolicyBinding.objects.create( + target=pbm, policy=self.policy_true, negate=True, order=0 + ) + engine = PolicyEngine(pbm, self.user) + result = engine.build().result + self.assertEqual(result.passing, False) + self.assertEqual(result.messages, ("dummy",)) def test_engine_policy_error(self): - """Test negate flag""" - engine = PolicyEngine(Policy.objects.filter(name="raises"), self.user) - self.assertEqual(engine.build().passing, False) + """Test policy raising an error flag""" + pbm = PolicyBindingModel.objects.create() + PolicyBinding.objects.create(target=pbm, policy=self.policy_raises, order=0) + engine = PolicyEngine(pbm, self.user) + result = engine.build().result + self.assertEqual(result.passing, False) + self.assertEqual(result.messages, ("division by zero",)) + + def test_engine_policy_type(self): + """Test invalid policy type""" + pbm = PolicyBindingModel.objects.create() + PolicyBinding.objects.create(target=pbm, policy=self.policy_wrong_type, order=0) + with self.assertRaises(TypeError): + engine = PolicyEngine(pbm, self.user) + engine.build() def test_engine_cache(self): """Ensure empty policy list passes""" - engine = PolicyEngine( - DummyPolicy.objects.filter(negate__exact=False), self.user - ) + pbm = PolicyBindingModel.objects.create() + PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0) + engine = PolicyEngine(pbm, self.user) self.assertEqual(len(cache.keys("policy_*")), 0) self.assertEqual(engine.build().passing, False) - self.assertEqual(len(cache.keys("policy_*")), 2) + self.assertEqual(len(cache.keys("policy_*")), 1) self.assertEqual(engine.build().passing, False) - self.assertEqual(len(cache.keys("policy_*")), 2) + self.assertEqual(len(cache.keys("policy_*")), 1) diff --git a/passbook/providers/oauth/views/oauth2.py b/passbook/providers/oauth/views/oauth2.py index 96a0eaaa2..208dd99ac 100644 --- a/passbook/providers/oauth/views/oauth2.py +++ b/passbook/providers/oauth/views/oauth2.py @@ -50,9 +50,9 @@ class PassbookAuthorizationView(AccessMixin, AuthorizationView): provider.save() self._application = application # Check permissions - passing, policy_messages = self.user_has_access(self._application, request.user) - if not passing: - for policy_message in policy_messages: + result = self.user_has_access(self._application, request.user) + if not result.passing: + for policy_message in result.messages: messages.error(request, policy_message) return redirect("passbook_providers_oauth:oauth2-permission-denied") # Some clients don't pass response_type, so we default to code diff --git a/passbook/providers/oidc/auth.py b/passbook/providers/oidc/auth.py index 334607afa..91a0b9dcf 100644 --- a/passbook/providers/oidc/auth.py +++ b/passbook/providers/oidc/auth.py @@ -18,7 +18,7 @@ LOGGER = get_logger() def client_related_provider(client: Client) -> Optional[Provider]: """Lookup related Application from Client""" # because oidc_provider is also used by app_gw, we can't be - # sure an OpenIDPRovider instance exists. hence we look through all related models + # sure an OpenIDProvider instance exists. hence we look through all related models # and choose the one that inherits from Provider, which is guaranteed to # have the application property collector = Collector(using="default") @@ -50,9 +50,9 @@ def check_permissions( policy_engine.build() # Check permissions - passing, policy_messages = policy_engine.result - if not passing: - for policy_message in policy_messages: + result = policy_engine.result + if not result.passing: + for policy_message in result.messages: messages.error(request, policy_message) return redirect("passbook_providers_oauth:oauth2-permission-denied") diff --git a/passbook/stages/prompt/forms.py b/passbook/stages/prompt/forms.py index 14830f36d..d59517e00 100644 --- a/passbook/stages/prompt/forms.py +++ b/passbook/stages/prompt/forms.py @@ -55,9 +55,9 @@ class PromptForm(forms.Form): def clean(self): cleaned_data = super().clean() user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user()) - engine = PolicyEngine(self.stage.policies.all(), user) + engine = PolicyEngine(self.stage, user) engine.request.context = cleaned_data engine.build() - passing, messages = engine.result - if not passing: - raise forms.ValidationError(messages) + result = engine.result + if not result.passing: + raise forms.ValidationError(result.messages) diff --git a/passbook/stages/prompt/tests.py b/passbook/stages/prompt/tests.py index d4d2e108d..ee63c8566 100644 --- a/passbook/stages/prompt/tests.py +++ b/passbook/stages/prompt/tests.py @@ -139,7 +139,7 @@ class TestPromptStage(TestCase): expr_policy = ExpressionPolicy.objects.create( name="validate-form", expression=expr ) - PolicyBinding.objects.create(policy=expr_policy, target=self.stage) + PolicyBinding.objects.create(policy=expr_policy, target=self.stage, order=0) form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data) self.assertEqual(form.is_valid(), True) return form @@ -151,7 +151,7 @@ class TestPromptStage(TestCase): expr_policy = ExpressionPolicy.objects.create( name="validate-form", expression=expr ) - PolicyBinding.objects.create(policy=expr_policy, target=self.stage) + PolicyBinding.objects.create(policy=expr_policy, target=self.stage, order=0) form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data) self.assertEqual(form.is_valid(), False) return form diff --git a/swagger.yaml b/swagger.yaml index ce498dd6f..54beb80bb 100755 --- a/swagger.yaml +++ b/swagger.yaml @@ -837,7 +837,7 @@ paths: parameters: - name: policy_uuid in: path - description: A UUID string identifying this policy. + description: A UUID string identifying this Policy. required: true type: string format: uuid @@ -5079,19 +5079,6 @@ definitions: title: Name type: string x-nullable: true - negate: - title: Negate - type: boolean - order: - title: Order - type: integer - maximum: 2147483647 - minimum: -2147483648 - timeout: - title: Timeout - type: integer - maximum: 2147483647 - minimum: -2147483648 __type__: title: 'type ' type: string @@ -5100,6 +5087,7 @@ definitions: required: - policy - target + - order type: object properties: policy: @@ -5118,6 +5106,12 @@ definitions: type: integer maximum: 2147483647 minimum: -2147483648 + timeout: + title: Timeout + description: Timeout after which Policy execution is terminated. + type: integer + maximum: 2147483647 + minimum: -2147483648 DummyPolicy: type: object properties: @@ -5130,19 +5124,6 @@ definitions: title: Name type: string x-nullable: true - negate: - title: Negate - type: boolean - order: - title: Order - type: integer - maximum: 2147483647 - minimum: -2147483648 - timeout: - title: Timeout - type: integer - maximum: 2147483647 - minimum: -2147483648 result: title: Result type: boolean @@ -5170,19 +5151,6 @@ definitions: title: Name type: string x-nullable: true - negate: - title: Negate - type: boolean - order: - title: Order - type: integer - maximum: 2147483647 - minimum: -2147483648 - timeout: - title: Timeout - type: integer - maximum: 2147483647 - minimum: -2147483648 expression: title: Expression type: string @@ -5199,19 +5167,6 @@ definitions: title: Name type: string x-nullable: true - negate: - title: Negate - type: boolean - order: - title: Order - type: integer - maximum: 2147483647 - minimum: -2147483648 - timeout: - title: Timeout - type: integer - maximum: 2147483647 - minimum: -2147483648 allowed_count: title: Allowed count type: integer @@ -5231,19 +5186,6 @@ definitions: title: Name type: string x-nullable: true - negate: - title: Negate - type: boolean - order: - title: Order - type: integer - maximum: 2147483647 - minimum: -2147483648 - timeout: - title: Timeout - type: integer - maximum: 2147483647 - minimum: -2147483648 amount_uppercase: title: Amount uppercase type: integer @@ -5286,19 +5228,6 @@ definitions: title: Name type: string x-nullable: true - negate: - title: Negate - type: boolean - order: - title: Order - type: integer - maximum: 2147483647 - minimum: -2147483648 - timeout: - title: Timeout - type: integer - maximum: 2147483647 - minimum: -2147483648 days: title: Days type: integer @@ -5319,19 +5248,6 @@ definitions: title: Name type: string x-nullable: true - negate: - title: Negate - type: boolean - order: - title: Order - type: integer - maximum: 2147483647 - minimum: -2147483648 - timeout: - title: Timeout - type: integer - maximum: 2147483647 - minimum: -2147483648 check_ip: title: Check ip type: boolean