policies/*: remove Policy.negate, order, timeout (#39)

policies: rewrite engine to use PolicyBinding for order/negate/timeout
policies: rewrite engine to use PolicyResult instead of tuple
This commit is contained in:
Jens L 2020-05-28 21:45:54 +02:00 committed by GitHub
parent fdfc6472d2
commit df8995deed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 290 additions and 224 deletions

View File

@ -1,6 +1,7 @@
"""passbook administration forms""" """passbook administration forms"""
from django import forms from django import forms
from passbook.admin.fields import CodeMirrorWidget, YAMLField
from passbook.core.models import User from passbook.core.models import User
@ -8,3 +9,4 @@ class PolicyTestForm(forms.Form):
"""Form to test policies against user""" """Form to test policies against user"""
user = forms.ModelChoiceField(queryset=User.objects.all()) user = forms.ModelChoiceField(queryset=User.objects.all())
context = YAMLField(widget=CodeMirrorWidget())

View File

@ -1,11 +1,15 @@
"""passbook Policy administration""" """passbook Policy administration"""
from typing import Any, Dict
from django.contrib import messages from django.contrib import messages
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.mixins import ( from django.contrib.auth.mixins import (
PermissionRequiredMixin as DjangoPermissionRequiredMixin, PermissionRequiredMixin as DjangoPermissionRequiredMixin,
) )
from django.contrib.messages.views import SuccessMessageMixin from django.contrib.messages.views import SuccessMessageMixin
from django.http import Http404 from django.db.models import QuerySet
from django.forms import Form
from django.http import Http404, HttpRequest, HttpResponse
from django.urls import reverse_lazy from django.urls import reverse_lazy
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from django.views.generic import DeleteView, FormView, ListView, UpdateView from django.views.generic import DeleteView, FormView, ListView, UpdateView
@ -15,8 +19,8 @@ from guardian.mixins import PermissionListMixin, PermissionRequiredMixin
from passbook.admin.forms.policies import PolicyTestForm from passbook.admin.forms.policies import PolicyTestForm
from passbook.lib.utils.reflection import all_subclasses, path_to_class from passbook.lib.utils.reflection import all_subclasses, path_to_class
from passbook.lib.views import CreateAssignPermView from passbook.lib.views import CreateAssignPermView
from passbook.policies.engine import PolicyEngine from passbook.policies.models import Policy, PolicyBinding
from passbook.policies.models import Policy from passbook.policies.process import PolicyProcess, PolicyRequest
class PolicyListView(LoginRequiredMixin, PermissionListMixin, ListView): class PolicyListView(LoginRequiredMixin, PermissionListMixin, ListView):
@ -25,14 +29,14 @@ class PolicyListView(LoginRequiredMixin, PermissionListMixin, ListView):
model = Policy model = Policy
permission_required = "passbook_policies.view_policy" permission_required = "passbook_policies.view_policy"
paginate_by = 10 paginate_by = 10
ordering = "order" ordering = "name"
template_name = "administration/policy/list.html" template_name = "administration/policy/list.html"
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
kwargs["types"] = {x.__name__: x for x in all_subclasses(Policy)} kwargs["types"] = {x.__name__: x for x in all_subclasses(Policy)}
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)
def get_queryset(self): def get_queryset(self) -> QuerySet:
return super().get_queryset().select_subclasses() return super().get_queryset().select_subclasses()
@ -51,14 +55,14 @@ class PolicyCreateView(
success_url = reverse_lazy("passbook_admin:policies") success_url = reverse_lazy("passbook_admin:policies")
success_message = _("Successfully created Policy") success_message = _("Successfully created Policy")
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
kwargs = super().get_context_data(**kwargs) kwargs = super().get_context_data(**kwargs)
form_cls = self.get_form_class() form_cls = self.get_form_class()
if hasattr(form_cls, "template_name"): if hasattr(form_cls, "template_name"):
kwargs["base_template"] = form_cls.template_name kwargs["base_template"] = form_cls.template_name
return kwargs return kwargs
def get_form_class(self): def get_form_class(self) -> Form:
policy_type = self.request.GET.get("type") policy_type = self.request.GET.get("type")
try: try:
model = next(x for x in all_subclasses(Policy) if x.__name__ == policy_type) model = next(x for x in all_subclasses(Policy) if x.__name__ == policy_type)
@ -79,19 +83,19 @@ class PolicyUpdateView(
success_url = reverse_lazy("passbook_admin:policies") success_url = reverse_lazy("passbook_admin:policies")
success_message = _("Successfully updated Policy") success_message = _("Successfully updated Policy")
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
kwargs = super().get_context_data(**kwargs) kwargs = super().get_context_data(**kwargs)
form_cls = self.get_form_class() form_cls = self.get_form_class()
if hasattr(form_cls, "template_name"): if hasattr(form_cls, "template_name"):
kwargs["base_template"] = form_cls.template_name kwargs["base_template"] = form_cls.template_name
return kwargs return kwargs
def get_form_class(self): def get_form_class(self) -> Form:
form_class_path = self.get_object().form form_class_path = self.get_object().form
form_class = path_to_class(form_class_path) form_class = path_to_class(form_class_path)
return form_class return form_class
def get_object(self, queryset=None): def get_object(self, queryset=None) -> Policy:
return ( return (
Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first() Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first()
) )
@ -109,12 +113,12 @@ class PolicyDeleteView(
success_url = reverse_lazy("passbook_admin:policies") success_url = reverse_lazy("passbook_admin:policies")
success_message = _("Successfully deleted Policy") success_message = _("Successfully deleted Policy")
def get_object(self, queryset=None): def get_object(self, queryset=None) -> Policy:
return ( return (
Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first() Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first()
) )
def delete(self, request, *args, **kwargs): def delete(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
messages.success(self.request, self.success_message) messages.success(self.request, self.success_message)
return super().delete(request, *args, **kwargs) return super().delete(request, *args, **kwargs)
@ -128,26 +132,29 @@ class PolicyTestView(LoginRequiredMixin, DetailView, PermissionRequiredMixin, Fo
template_name = "administration/policy/test.html" template_name = "administration/policy/test.html"
object = None object = None
def get_object(self, queryset=None): def get_object(self, queryset=None) -> QuerySet:
return ( return (
Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first() Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first()
) )
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
kwargs["policy"] = self.get_object() kwargs["policy"] = self.get_object()
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)
def post(self, *args, **kwargs): def post(self, *args, **kwargs) -> HttpResponse:
self.object = self.get_object() self.object = self.get_object()
return super().post(*args, **kwargs) return super().post(*args, **kwargs)
def form_valid(self, form): def form_valid(self, form: PolicyTestForm) -> HttpResponse:
policy = self.get_object() policy = self.get_object()
user = form.cleaned_data.get("user") user = form.cleaned_data.get("user")
policy_engine = PolicyEngine([policy], user, self.request)
policy_engine.use_cache = False p_request = PolicyRequest(user)
policy_engine.build() p_request.http_request = self.request
result = policy_engine.passing p_request.context = form.cleaned_data
proc = PolicyProcess(PolicyBinding(policy=policy), p_request, None)
result = proc.execute()
if result: if result:
messages.success(self.request, _("User successfully passed policy.")) messages.success(self.request, _("User successfully passed policy."))
else: else:

View File

@ -17,12 +17,15 @@ password_changed = Signal(providing_args=["user", "password"])
# pylint: disable=unused-argument # pylint: disable=unused-argument
def invalidate_policy_cache(sender, instance, **_): def invalidate_policy_cache(sender, instance, **_):
"""Invalidate Policy cache when policy is updated""" """Invalidate Policy cache when policy is updated"""
from passbook.policies.models import Policy from passbook.policies.models import Policy, PolicyBinding
from passbook.policies.process import cache_key from passbook.policies.process import cache_key
if isinstance(instance, Policy): if isinstance(instance, Policy):
LOGGER.debug("Invalidating policy cache", policy=instance) LOGGER.debug("Invalidating policy cache", policy=instance)
prefix = cache_key(instance) + "*" total = 0
keys = cache.keys(prefix) for binding in PolicyBinding.objects.filter(policy=instance):
cache.delete_many(keys) prefix = cache_key(binding) + "*"
LOGGER.debug("Deleted %d keys", len(keys)) keys = cache.keys(prefix)
total += len(keys)
cache.delete_many(keys)
LOGGER.debug("Deleted keys", len=total)

View File

@ -1,6 +1,4 @@
"""passbook access helper classes""" """passbook access helper classes"""
from typing import List, Tuple
from django.contrib import messages from django.contrib import messages
from django.http import HttpRequest from django.http import HttpRequest
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
@ -8,6 +6,7 @@ from structlog import get_logger
from passbook.core.models import Application, Provider, User from passbook.core.models import Application, Provider, User
from passbook.policies.engine import PolicyEngine from passbook.policies.engine import PolicyEngine
from passbook.policies.types import PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
@ -33,9 +32,7 @@ class AccessMixin:
) )
raise exc raise exc
def user_has_access( def user_has_access(self, application: Application, user: User) -> PolicyResult:
self, application: Application, user: User
) -> Tuple[bool, List[str]]:
"""Check if user has access to application.""" """Check if user has access to application."""
LOGGER.debug("Checking permissions", user=user, application=application) LOGGER.debug("Checking permissions", user=user, application=application)
policy_engine = PolicyEngine(application.policies.all(), user, self.request) policy_engine = PolicyEngine(application.policies.all(), user, self.request)

View File

@ -1,7 +1,7 @@
"""Flows Planner""" """Flows Planner"""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from time import time from time import time
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional
from django.core.cache import cache from django.core.cache import cache
from django.http import HttpRequest from django.http import HttpRequest
@ -11,6 +11,7 @@ from passbook.core.models import User
from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException
from passbook.flows.models import Flow, Stage from passbook.flows.models import Flow, Stage
from passbook.policies.engine import PolicyEngine from passbook.policies.engine import PolicyEngine
from passbook.policies.types import PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
@ -51,8 +52,8 @@ class FlowPlanner:
self.use_cache = True self.use_cache = True
self.flow = flow self.flow = flow
def _check_flow_root_policies(self, request: HttpRequest) -> Tuple[bool, List[str]]: def _check_flow_root_policies(self, request: HttpRequest) -> PolicyResult:
engine = PolicyEngine(self.flow.policies.all(), request.user, request) engine = PolicyEngine(self.flow, request.user, request)
engine.build() engine.build()
return engine.result return engine.result
@ -64,9 +65,9 @@ class FlowPlanner:
LOGGER.debug("f(plan): Starting planning process", flow=self.flow) LOGGER.debug("f(plan): Starting planning process", flow=self.flow)
# First off, check the flow's direct policy bindings # First off, check the flow's direct policy bindings
# to make sure the user even has access to the flow # to make sure the user even has access to the flow
root_passing, root_passing_messages = self._check_flow_root_policies(request) root_result = self._check_flow_root_policies(request)
if not root_passing: if not root_result.passing:
raise FlowNonApplicableException(root_passing_messages) raise FlowNonApplicableException(*root_result.messages)
# Bit of a workaround here, if there is a pending user set in the default context # Bit of a workaround here, if there is a pending user set in the default context
# we use that user for our cache key # we use that user for our cache key
# to make sure they don't get the generic response # to make sure they don't get the generic response
@ -106,11 +107,10 @@ class FlowPlanner:
.select_related() .select_related()
): ):
binding = stage.flowstagebinding_set.get(flow__pk=self.flow.pk) binding = stage.flowstagebinding_set.get(flow__pk=self.flow.pk)
engine = PolicyEngine(binding.policies.all(), user, request) engine = PolicyEngine(binding, user, request)
engine.request.context = plan.context engine.request.context = plan.context
engine.build() engine.build()
passing, _ = engine.result if engine.passing:
if passing:
LOGGER.debug("f(plan): Stage passing", stage=stage, flow=self.flow) LOGGER.debug("f(plan): Stage passing", stage=stage, flow=self.flow)
plan.stages.append(stage) plan.stages.append(stage)
end_time = time() end_time = time()

View File

@ -8,9 +8,10 @@ from guardian.shortcuts import get_anonymous_user
from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException
from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding
from passbook.flows.planner import FlowPlanner from passbook.flows.planner import FlowPlanner
from passbook.policies.types import PolicyResult
from passbook.stages.dummy.models import DummyStage from passbook.stages.dummy.models import DummyStage
POLICY_RESULT_MOCK = MagicMock(return_value=(False, [""],)) POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False))
TIME_NOW_MOCK = MagicMock(return_value=3) TIME_NOW_MOCK = MagicMock(return_value=3)

View File

@ -9,9 +9,10 @@ from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding
from passbook.flows.planner import FlowPlan from passbook.flows.planner import FlowPlan
from passbook.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN from passbook.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN
from passbook.lib.config import CONFIG from passbook.lib.config import CONFIG
from passbook.policies.types import PolicyResult
from passbook.stages.dummy.models import DummyStage from passbook.stages.dummy.models import DummyStage
POLICY_RESULT_MOCK = MagicMock(return_value=(False, [""],)) POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False))
class TestFlowExecutor(TestCase): class TestFlowExecutor(TestCase):

View File

@ -1,5 +1,6 @@
"""Generic models""" """Generic models"""
from django.db import models from django.db import models
from model_utils.managers import InheritanceManager
class CreatedUpdatedModel(models.Model): class CreatedUpdatedModel(models.Model):
@ -10,3 +11,27 @@ class CreatedUpdatedModel(models.Model):
class Meta: class Meta:
abstract = True abstract = True
class InheritanceAutoManager(InheritanceManager):
"""Object manager which automatically selects the subclass"""
def get_queryset(self):
return super().get_queryset().select_subclasses()
class InheritanceForwardManyToOneDescriptor(
models.fields.related.ForwardManyToOneDescriptor
):
"""Forward ManyToOne Descriptor that selects subclass. Requires InheritanceAutoManager."""
def get_queryset(self, **hints):
return self.field.remote_field.model.objects.db_manager(
hints=hints
).select_subclasses()
class InheritanceForeignKey(models.ForeignKey):
"""Custom ForeignKey that uses InheritanceForwardManyToOneDescriptor"""
forward_related_accessor_class = InheritanceForwardManyToOneDescriptor

View File

@ -12,7 +12,7 @@ class PolicyBindingSerializer(ModelSerializer):
class Meta: class Meta:
model = PolicyBinding model = PolicyBinding
fields = ["policy", "target", "enabled", "order"] fields = ["policy", "target", "enabled", "order", "timeout"]
class PolicyBindingViewSet(ModelViewSet): class PolicyBindingViewSet(ModelViewSet):

View File

@ -1,14 +1,14 @@
"""passbook policy engine""" """passbook policy engine"""
from multiprocessing import Pipe, set_start_method from multiprocessing import Pipe, set_start_method
from multiprocessing.connection import Connection from multiprocessing.connection import Connection
from typing import List, Optional, Tuple from typing import List, Optional
from django.core.cache import cache from django.core.cache import cache
from django.http import HttpRequest from django.http import HttpRequest
from structlog import get_logger from structlog import get_logger
from passbook.core.models import User from passbook.core.models import User
from passbook.policies.models import Policy from passbook.policies.models import Policy, PolicyBinding, PolicyBindingModel
from passbook.policies.process import PolicyProcess, cache_key from passbook.policies.process import PolicyProcess, cache_key
from passbook.policies.types import PolicyRequest, PolicyResult from passbook.policies.types import PolicyRequest, PolicyResult
@ -24,12 +24,14 @@ class PolicyProcessInfo:
process: PolicyProcess process: PolicyProcess
connection: Connection connection: Connection
result: Optional[PolicyResult] result: Optional[PolicyResult]
policy: Policy binding: PolicyBinding
def __init__(self, process: PolicyProcess, connection: Connection, policy: Policy): def __init__(
self, process: PolicyProcess, connection: Connection, binding: PolicyBinding
):
self.process = process self.process = process
self.connection = connection self.connection = connection
self.policy = policy self.binding = binding
self.result = None self.result = None
@ -37,54 +39,64 @@ class PolicyEngine:
"""Orchestrate policy checking, launch tasks and return result""" """Orchestrate policy checking, launch tasks and return result"""
use_cache: bool = True use_cache: bool = True
policies: List[Policy] = []
request: PolicyRequest request: PolicyRequest
__pbm: PolicyBindingModel
__cached_policies: List[PolicyResult] __cached_policies: List[PolicyResult]
__processes: List[PolicyProcessInfo] __processes: List[PolicyProcessInfo]
def __init__(self, policies, user: User, request: HttpRequest = None): def __init__(
self.policies = policies self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None
):
if not isinstance(pbm, PolicyBindingModel):
raise ValueError(f"{pbm} is not instance of PolicyBindingModel")
self.__pbm = pbm
self.request = PolicyRequest(user) self.request = PolicyRequest(user)
if request: if request:
self.request.http_request = request self.request.http_request = request
self.__cached_policies = [] self.__cached_policies = []
self.__processes = [] self.__processes = []
def _select_subclasses(self) -> List[Policy]: def _iter_bindings(self) -> List[PolicyBinding]:
"""Make sure all Policies are their respective classes""" """Make sure all Policies are their respective classes"""
return ( return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by(
Policy.objects.filter(pk__in=[x.pk for x in self.policies]) "order"
.select_subclasses()
.order_by("order")
) )
def _check_policy_type(self, policy: Policy):
"""Check policy type, make sure it's not the root class as that has no logic implemented"""
# policy_type = type(policy)
if policy.__class__ == Policy:
raise TypeError(f"Policy '{policy}' is root type")
def build(self) -> "PolicyEngine": def build(self) -> "PolicyEngine":
"""Build task group""" """Build task group"""
for policy in self._select_subclasses(): for binding in self._iter_bindings():
cached_policy = cache.get(cache_key(policy, self.request.user), None) self._check_policy_type(binding.policy)
policy = binding.policy
cached_policy = cache.get(cache_key(binding, self.request.user), None)
if cached_policy and self.use_cache: if cached_policy and self.use_cache:
LOGGER.debug("P_ENG: Taking result from cache", policy=policy) LOGGER.debug("P_ENG: Taking result from cache", policy=policy)
self.__cached_policies.append(cached_policy) self.__cached_policies.append(cached_policy)
continue continue
LOGGER.debug("P_ENG: Evaluating policy", policy=policy) LOGGER.debug("P_ENG: Evaluating policy", policy=policy)
our_end, task_end = Pipe(False) our_end, task_end = Pipe(False)
task = PolicyProcess(policy, self.request, task_end) task = PolicyProcess(binding, self.request, task_end)
LOGGER.debug("P_ENG: Starting Process", policy=policy) LOGGER.debug("P_ENG: Starting Process", policy=policy)
task.start() task.start()
self.__processes.append( self.__processes.append(
PolicyProcessInfo(process=task, connection=our_end, policy=policy) PolicyProcessInfo(process=task, connection=our_end, binding=binding)
) )
# If all policies are cached, we have an empty list here. # If all policies are cached, we have an empty list here.
for proc_info in self.__processes: for proc_info in self.__processes:
proc_info.process.join(proc_info.policy.timeout) proc_info.process.join(proc_info.binding.timeout)
# Only call .recv() if no result is saved, otherwise we just deadlock here # Only call .recv() if no result is saved, otherwise we just deadlock here
if not proc_info.result: if not proc_info.result:
proc_info.result = proc_info.connection.recv() proc_info.result = proc_info.connection.recv()
return self return self
@property @property
def result(self) -> Tuple[bool, List[str]]: def result(self) -> PolicyResult:
"""Get policy-checking result""" """Get policy-checking result"""
messages: List[str] = [] messages: List[str] = []
process_results: List[PolicyResult] = [ process_results: List[PolicyResult] = [
@ -95,10 +107,10 @@ class PolicyEngine:
if result.messages: if result.messages:
messages += result.messages messages += result.messages
if not result.passing: if not result.passing:
return False, messages return PolicyResult(False, *messages)
return True, messages return PolicyResult(True, *messages)
@property @property
def passing(self) -> bool: def passing(self) -> bool:
"""Only get true/false if user passes""" """Only get true/false if user passes"""
return self.result[0] return self.result.passing

View File

@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from jinja2 import Undefined from jinja2 import Undefined
from jinja2.exceptions import TemplateSyntaxError, UndefinedError from jinja2.exceptions import TemplateSyntaxError
from jinja2.nativetypes import NativeEnvironment from jinja2.nativetypes import NativeEnvironment
from requests import Session from requests import Session
from structlog import get_logger from structlog import get_logger
@ -90,7 +90,8 @@ class Evaluator:
if result: if result:
return PolicyResult(bool(result)) return PolicyResult(bool(result))
return PolicyResult(False) return PolicyResult(False)
except UndefinedError as exc: except Exception as exc: # pylint: disable=broad-except
LOGGER.warning("Expression error", exc=exc)
return PolicyResult(False, str(exc)) return PolicyResult(False, str(exc))
def validate(self, expression: str): def validate(self, expression: str):

View File

@ -3,8 +3,8 @@ from django import forms
from passbook.policies.models import PolicyBinding, PolicyBindingModel from passbook.policies.models import PolicyBinding, PolicyBindingModel
GENERAL_FIELDS = ["name", "negate", "order", "timeout"] GENERAL_FIELDS = ["name"]
GENERAL_SERIALIZER_FIELDS = ["pk", "name", "negate", "order", "timeout"] GENERAL_SERIALIZER_FIELDS = ["pk", "name"]
class PolicyBindingForm(forms.ModelForm): class PolicyBindingForm(forms.ModelForm):
@ -18,9 +18,4 @@ class PolicyBindingForm(forms.ModelForm):
class Meta: class Meta:
model = PolicyBinding model = PolicyBinding
fields = [ fields = ["enabled", "policy", "target", "order", "timeout"]
"enabled",
"policy",
"target",
"order",
]

View File

@ -0,0 +1,58 @@
# Generated by Django 3.0.6 on 2020-05-28 16:47
import django.db.models.deletion
from django.db import migrations, models
import passbook.lib.models
class Migration(migrations.Migration):
dependencies = [
("passbook_policies", "0001_initial"),
]
operations = [
migrations.AlterModelOptions(
name="policy",
options={
"base_manager_name": "objects",
"verbose_name": "Policy",
"verbose_name_plural": "Policies",
},
),
migrations.RemoveField(model_name="policy", name="negate",),
migrations.RemoveField(model_name="policy", name="order",),
migrations.RemoveField(model_name="policy", name="timeout",),
migrations.AddField(
model_name="policybinding",
name="negate",
field=models.BooleanField(
default=False,
help_text="Negates the outcome of the policy. Messages are unaffected.",
),
),
migrations.AddField(
model_name="policybinding",
name="timeout",
field=models.IntegerField(
default=30,
help_text="Timeout after which Policy execution is terminated.",
),
),
migrations.AlterField(
model_name="policybinding", name="order", field=models.IntegerField(),
),
migrations.AlterField(
model_name="policybinding",
name="policy",
field=passbook.lib.models.InheritanceForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="+",
to="passbook_policies.Policy",
),
),
migrations.AlterUniqueTogether(
name="policybinding", unique_together={("policy", "target", "order")},
),
]

View File

@ -5,7 +5,11 @@ from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from model_utils.managers import InheritanceManager from model_utils.managers import InheritanceManager
from passbook.lib.models import CreatedUpdatedModel from passbook.lib.models import (
CreatedUpdatedModel,
InheritanceAutoManager,
InheritanceForeignKey,
)
from passbook.policies.exceptions import PolicyException from passbook.policies.exceptions import PolicyException
from passbook.policies.types import PolicyRequest, PolicyResult from passbook.policies.types import PolicyRequest, PolicyResult
@ -22,7 +26,6 @@ class PolicyBindingModel(models.Model):
objects = InheritanceManager() objects = InheritanceManager()
class Meta: class Meta:
verbose_name = _("Policy Binding Model") verbose_name = _("Policy Binding Model")
verbose_name_plural = _("Policy Binding Models") verbose_name_plural = _("Policy Binding Models")
@ -36,13 +39,19 @@ class PolicyBinding(models.Model):
enabled = models.BooleanField(default=True) enabled = models.BooleanField(default=True)
policy = models.ForeignKey("Policy", on_delete=models.CASCADE, related_name="+") policy = InheritanceForeignKey("Policy", on_delete=models.CASCADE, related_name="+")
target = models.ForeignKey( target = models.ForeignKey(
PolicyBindingModel, on_delete=models.CASCADE, related_name="+" PolicyBindingModel, on_delete=models.CASCADE, related_name="+"
) )
negate = models.BooleanField(
default=False,
help_text=_("Negates the outcome of the policy. Messages are unaffected."),
)
timeout = models.IntegerField(
default=30, help_text=_("Timeout after which Policy execution is terminated.")
)
# default value and non-unique for compatibility order = models.IntegerField()
order = models.IntegerField(default=0)
def __str__(self) -> str: def __str__(self) -> str:
return f"PolicyBinding policy={self.policy} target={self.target} order={self.order}" return f"PolicyBinding policy={self.policy} target={self.target} order={self.order}"
@ -51,6 +60,7 @@ class PolicyBinding(models.Model):
verbose_name = _("Policy Binding") verbose_name = _("Policy Binding")
verbose_name_plural = _("Policy Bindings") verbose_name_plural = _("Policy Bindings")
unique_together = ("policy", "target", "order")
class Policy(CreatedUpdatedModel): class Policy(CreatedUpdatedModel):
@ -60,11 +70,8 @@ class Policy(CreatedUpdatedModel):
policy_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) policy_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
name = models.TextField(blank=True, null=True) name = models.TextField(blank=True, null=True)
negate = models.BooleanField(default=False)
order = models.IntegerField(default=0)
timeout = models.IntegerField(default=30)
objects = InheritanceManager() objects = InheritanceAutoManager()
def __str__(self): def __str__(self):
return f"Policy {self.name}" return f"Policy {self.name}"
@ -72,3 +79,9 @@ class Policy(CreatedUpdatedModel):
def passes(self, request: PolicyRequest) -> PolicyResult: def passes(self, request: PolicyRequest) -> PolicyResult:
"""Check if user instance passes this policy""" """Check if user instance passes this policy"""
raise PolicyException() raise PolicyException()
class Meta:
base_manager_name = "objects"
verbose_name = _("Policy")
verbose_name_plural = _("Policies")

View File

@ -8,15 +8,15 @@ from structlog import get_logger
from passbook.core.models import User from passbook.core.models import User
from passbook.policies.exceptions import PolicyException from passbook.policies.exceptions import PolicyException
from passbook.policies.models import Policy from passbook.policies.models import PolicyBinding
from passbook.policies.types import PolicyRequest, PolicyResult from passbook.policies.types import PolicyRequest, PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
def cache_key(policy: Policy, user: Optional[User] = None) -> str: def cache_key(binding: PolicyBinding, user: Optional[User] = None) -> str:
"""Generate Cache key for policy""" """Generate Cache key for policy"""
prefix = f"policy_{policy.pk}" prefix = f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}"
if user: if user:
prefix += f"#{user.pk}" prefix += f"#{user.pk}"
return prefix return prefix
@ -26,40 +26,50 @@ class PolicyProcess(Process):
"""Evaluate a single policy within a seprate process""" """Evaluate a single policy within a seprate process"""
connection: Connection connection: Connection
policy: Policy binding: PolicyBinding
request: PolicyRequest request: PolicyRequest
def __init__(self, policy: Policy, request: PolicyRequest, connection: Connection): def __init__(
self,
binding: PolicyBinding,
request: PolicyRequest,
connection: Optional[Connection],
):
super().__init__() super().__init__()
self.policy = policy self.binding = binding
self.request = request self.request = request
self.connection = connection if connection:
self.connection = connection
def run(self): def execute(self) -> PolicyResult:
"""Task wrapper to run policy checking""" """Run actual policy, returns result"""
LOGGER.debug( LOGGER.debug(
"P_ENG(proc): Running policy", "P_ENG(proc): Running policy",
policy=self.policy, policy=self.binding.policy,
user=self.request.user, user=self.request.user,
process="PolicyProcess", process="PolicyProcess",
) )
try: try:
policy_result = self.policy.passes(self.request) policy_result = self.binding.policy.passes(self.request)
except PolicyException as exc: except PolicyException as exc:
LOGGER.debug("P_ENG(proc): error", exc=exc) LOGGER.debug("P_ENG(proc): error", exc=exc)
policy_result = PolicyResult(False, str(exc)) policy_result = PolicyResult(False, str(exc))
# Invert result if policy.negate is set # Invert result if policy.negate is set
if self.policy.negate: if self.binding.negate:
policy_result.passing = not policy_result.passing policy_result.passing = not policy_result.passing
LOGGER.debug( LOGGER.debug(
"P_ENG(proc): Finished", "P_ENG(proc): Finished",
policy=self.policy, policy=self.binding.policy,
result=policy_result, result=policy_result,
process="PolicyProcess", process="PolicyProcess",
passing=policy_result.passing, passing=policy_result.passing,
user=self.request.user, user=self.request.user,
) )
key = cache_key(self.policy, self.request.user) key = cache_key(self.binding, self.request.user)
cache.set(key, policy_result) cache.set(key, policy_result)
LOGGER.debug("P_ENG(proc): Cached policy evaluation", key=key) LOGGER.debug("P_ENG(proc): Cached policy evaluation", key=key)
self.connection.send(policy_result) return policy_result
def run(self):
"""Task wrapper to run policy checking"""
self.connection.send(self.execute())

View File

@ -5,7 +5,8 @@ from django.test import TestCase
from passbook.core.models import User from passbook.core.models import User
from passbook.policies.dummy.models import DummyPolicy from passbook.policies.dummy.models import DummyPolicy
from passbook.policies.engine import PolicyEngine from passbook.policies.engine import PolicyEngine
from passbook.policies.models import Policy from passbook.policies.expression.models import ExpressionPolicy
from passbook.policies.models import Policy, PolicyBinding, PolicyBindingModel
class PolicyTestEngine(TestCase): class PolicyTestEngine(TestCase):
@ -20,40 +21,64 @@ class PolicyTestEngine(TestCase):
self.policy_true = DummyPolicy.objects.create( self.policy_true = DummyPolicy.objects.create(
result=True, wait_min=0, wait_max=1 result=True, wait_min=0, wait_max=1
) )
self.policy_negate = DummyPolicy.objects.create( self.policy_wrong_type = Policy.objects.create(name="wrong_type")
negate=True, result=True, wait_min=0, wait_max=1 self.policy_raises = ExpressionPolicy.objects.create(
name="raises", expression="{{ 0/0 }}"
) )
self.policy_raises = Policy.objects.create(name="raises")
def test_engine_empty(self): def test_engine_empty(self):
"""Ensure empty policy list passes""" """Ensure empty policy list passes"""
engine = PolicyEngine([], self.user) pbm = PolicyBindingModel.objects.create()
self.assertEqual(engine.build().passing, True) engine = PolicyEngine(pbm, self.user)
result = engine.build().result
self.assertEqual(result.passing, True)
self.assertEqual(result.messages, ())
def test_engine(self): def test_engine(self):
"""Ensure all policies passes (Mix of false and true -> false)""" """Ensure all policies passes (Mix of false and true -> false)"""
engine = PolicyEngine( pbm = PolicyBindingModel.objects.create()
DummyPolicy.objects.filter(negate__exact=False), self.user PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
) PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1)
self.assertEqual(engine.build().passing, False) engine = PolicyEngine(pbm, self.user)
result = engine.build().result
self.assertEqual(result.passing, False)
self.assertEqual(result.messages, ("dummy",))
def test_engine_negate(self): def test_engine_negate(self):
"""Test negate flag""" """Test negate flag"""
engine = PolicyEngine(DummyPolicy.objects.filter(negate__exact=True), self.user) pbm = PolicyBindingModel.objects.create()
self.assertEqual(engine.build().passing, False) PolicyBinding.objects.create(
target=pbm, policy=self.policy_true, negate=True, order=0
)
engine = PolicyEngine(pbm, self.user)
result = engine.build().result
self.assertEqual(result.passing, False)
self.assertEqual(result.messages, ("dummy",))
def test_engine_policy_error(self): def test_engine_policy_error(self):
"""Test negate flag""" """Test policy raising an error flag"""
engine = PolicyEngine(Policy.objects.filter(name="raises"), self.user) pbm = PolicyBindingModel.objects.create()
self.assertEqual(engine.build().passing, False) PolicyBinding.objects.create(target=pbm, policy=self.policy_raises, order=0)
engine = PolicyEngine(pbm, self.user)
result = engine.build().result
self.assertEqual(result.passing, False)
self.assertEqual(result.messages, ("division by zero",))
def test_engine_policy_type(self):
"""Test invalid policy type"""
pbm = PolicyBindingModel.objects.create()
PolicyBinding.objects.create(target=pbm, policy=self.policy_wrong_type, order=0)
with self.assertRaises(TypeError):
engine = PolicyEngine(pbm, self.user)
engine.build()
def test_engine_cache(self): def test_engine_cache(self):
"""Ensure empty policy list passes""" """Ensure empty policy list passes"""
engine = PolicyEngine( pbm = PolicyBindingModel.objects.create()
DummyPolicy.objects.filter(negate__exact=False), self.user PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
) engine = PolicyEngine(pbm, self.user)
self.assertEqual(len(cache.keys("policy_*")), 0) self.assertEqual(len(cache.keys("policy_*")), 0)
self.assertEqual(engine.build().passing, False) self.assertEqual(engine.build().passing, False)
self.assertEqual(len(cache.keys("policy_*")), 2) self.assertEqual(len(cache.keys("policy_*")), 1)
self.assertEqual(engine.build().passing, False) self.assertEqual(engine.build().passing, False)
self.assertEqual(len(cache.keys("policy_*")), 2) self.assertEqual(len(cache.keys("policy_*")), 1)

View File

@ -50,9 +50,9 @@ class PassbookAuthorizationView(AccessMixin, AuthorizationView):
provider.save() provider.save()
self._application = application self._application = application
# Check permissions # Check permissions
passing, policy_messages = self.user_has_access(self._application, request.user) result = self.user_has_access(self._application, request.user)
if not passing: if not result.passing:
for policy_message in policy_messages: for policy_message in result.messages:
messages.error(request, policy_message) messages.error(request, policy_message)
return redirect("passbook_providers_oauth:oauth2-permission-denied") return redirect("passbook_providers_oauth:oauth2-permission-denied")
# Some clients don't pass response_type, so we default to code # Some clients don't pass response_type, so we default to code

View File

@ -18,7 +18,7 @@ LOGGER = get_logger()
def client_related_provider(client: Client) -> Optional[Provider]: def client_related_provider(client: Client) -> Optional[Provider]:
"""Lookup related Application from Client""" """Lookup related Application from Client"""
# because oidc_provider is also used by app_gw, we can't be # because oidc_provider is also used by app_gw, we can't be
# sure an OpenIDPRovider instance exists. hence we look through all related models # sure an OpenIDProvider instance exists. hence we look through all related models
# and choose the one that inherits from Provider, which is guaranteed to # and choose the one that inherits from Provider, which is guaranteed to
# have the application property # have the application property
collector = Collector(using="default") collector = Collector(using="default")
@ -50,9 +50,9 @@ def check_permissions(
policy_engine.build() policy_engine.build()
# Check permissions # Check permissions
passing, policy_messages = policy_engine.result result = policy_engine.result
if not passing: if not result.passing:
for policy_message in policy_messages: for policy_message in result.messages:
messages.error(request, policy_message) messages.error(request, policy_message)
return redirect("passbook_providers_oauth:oauth2-permission-denied") return redirect("passbook_providers_oauth:oauth2-permission-denied")

View File

@ -55,9 +55,9 @@ class PromptForm(forms.Form):
def clean(self): def clean(self):
cleaned_data = super().clean() cleaned_data = super().clean()
user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user()) user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
engine = PolicyEngine(self.stage.policies.all(), user) engine = PolicyEngine(self.stage, user)
engine.request.context = cleaned_data engine.request.context = cleaned_data
engine.build() engine.build()
passing, messages = engine.result result = engine.result
if not passing: if not result.passing:
raise forms.ValidationError(messages) raise forms.ValidationError(result.messages)

View File

@ -139,7 +139,7 @@ class TestPromptStage(TestCase):
expr_policy = ExpressionPolicy.objects.create( expr_policy = ExpressionPolicy.objects.create(
name="validate-form", expression=expr name="validate-form", expression=expr
) )
PolicyBinding.objects.create(policy=expr_policy, target=self.stage) PolicyBinding.objects.create(policy=expr_policy, target=self.stage, order=0)
form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data) form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data)
self.assertEqual(form.is_valid(), True) self.assertEqual(form.is_valid(), True)
return form return form
@ -151,7 +151,7 @@ class TestPromptStage(TestCase):
expr_policy = ExpressionPolicy.objects.create( expr_policy = ExpressionPolicy.objects.create(
name="validate-form", expression=expr name="validate-form", expression=expr
) )
PolicyBinding.objects.create(policy=expr_policy, target=self.stage) PolicyBinding.objects.create(policy=expr_policy, target=self.stage, order=0)
form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data) form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data)
self.assertEqual(form.is_valid(), False) self.assertEqual(form.is_valid(), False)
return form return form

View File

@ -837,7 +837,7 @@ paths:
parameters: parameters:
- name: policy_uuid - name: policy_uuid
in: path in: path
description: A UUID string identifying this policy. description: A UUID string identifying this Policy.
required: true required: true
type: string type: string
format: uuid format: uuid
@ -5079,19 +5079,6 @@ definitions:
title: Name title: Name
type: string type: string
x-nullable: true x-nullable: true
negate:
title: Negate
type: boolean
order:
title: Order
type: integer
maximum: 2147483647
minimum: -2147483648
timeout:
title: Timeout
type: integer
maximum: 2147483647
minimum: -2147483648
__type__: __type__:
title: 'type ' title: 'type '
type: string type: string
@ -5100,6 +5087,7 @@ definitions:
required: required:
- policy - policy
- target - target
- order
type: object type: object
properties: properties:
policy: policy:
@ -5118,6 +5106,12 @@ definitions:
type: integer type: integer
maximum: 2147483647 maximum: 2147483647
minimum: -2147483648 minimum: -2147483648
timeout:
title: Timeout
description: Timeout after which Policy execution is terminated.
type: integer
maximum: 2147483647
minimum: -2147483648
DummyPolicy: DummyPolicy:
type: object type: object
properties: properties:
@ -5130,19 +5124,6 @@ definitions:
title: Name title: Name
type: string type: string
x-nullable: true x-nullable: true
negate:
title: Negate
type: boolean
order:
title: Order
type: integer
maximum: 2147483647
minimum: -2147483648
timeout:
title: Timeout
type: integer
maximum: 2147483647
minimum: -2147483648
result: result:
title: Result title: Result
type: boolean type: boolean
@ -5170,19 +5151,6 @@ definitions:
title: Name title: Name
type: string type: string
x-nullable: true x-nullable: true
negate:
title: Negate
type: boolean
order:
title: Order
type: integer
maximum: 2147483647
minimum: -2147483648
timeout:
title: Timeout
type: integer
maximum: 2147483647
minimum: -2147483648
expression: expression:
title: Expression title: Expression
type: string type: string
@ -5199,19 +5167,6 @@ definitions:
title: Name title: Name
type: string type: string
x-nullable: true x-nullable: true
negate:
title: Negate
type: boolean
order:
title: Order
type: integer
maximum: 2147483647
minimum: -2147483648
timeout:
title: Timeout
type: integer
maximum: 2147483647
minimum: -2147483648
allowed_count: allowed_count:
title: Allowed count title: Allowed count
type: integer type: integer
@ -5231,19 +5186,6 @@ definitions:
title: Name title: Name
type: string type: string
x-nullable: true x-nullable: true
negate:
title: Negate
type: boolean
order:
title: Order
type: integer
maximum: 2147483647
minimum: -2147483648
timeout:
title: Timeout
type: integer
maximum: 2147483647
minimum: -2147483648
amount_uppercase: amount_uppercase:
title: Amount uppercase title: Amount uppercase
type: integer type: integer
@ -5286,19 +5228,6 @@ definitions:
title: Name title: Name
type: string type: string
x-nullable: true x-nullable: true
negate:
title: Negate
type: boolean
order:
title: Order
type: integer
maximum: 2147483647
minimum: -2147483648
timeout:
title: Timeout
type: integer
maximum: 2147483647
minimum: -2147483648
days: days:
title: Days title: Days
type: integer type: integer
@ -5319,19 +5248,6 @@ definitions:
title: Name title: Name
type: string type: string
x-nullable: true x-nullable: true
negate:
title: Negate
type: boolean
order:
title: Order
type: integer
maximum: 2147483647
minimum: -2147483648
timeout:
title: Timeout
type: integer
maximum: 2147483647
minimum: -2147483648
check_ip: check_ip:
title: Check ip title: Check ip
type: boolean type: boolean