diff --git a/.gitignore b/.gitignore index e26c51d5c..31758ac03 100644 --- a/.gitignore +++ b/.gitignore @@ -191,3 +191,4 @@ pip-selfcheck.json # End of https://www.gitignore.io/api/python,django /static/ local.env.yml +.vscode/ diff --git a/passbook/admin/views/policy.py b/passbook/admin/views/policy.py index 355ea916f..355717126 100644 --- a/passbook/admin/views/policy.py +++ b/passbook/admin/views/policy.py @@ -11,8 +11,8 @@ from django.views.generic.detail import DetailView from passbook.admin.forms.policies import PolicyTestForm from passbook.admin.mixins import AdminRequiredMixin from passbook.core.models import Policy -from passbook.core.policies import PolicyEngine from passbook.lib.utils.reflection import path_to_class +from passbook.policy.engine import PolicyEngine class PolicyListView(AdminRequiredMixin, ListView): diff --git a/passbook/app_gw/middleware.py b/passbook/app_gw/middleware.py index 919e5c6b5..6cd2ac729 100644 --- a/passbook/app_gw/middleware.py +++ b/passbook/app_gw/middleware.py @@ -27,7 +27,7 @@ class ApplicationGatewayMiddleware: handler = RequestHandler(app_gw, request) if not handler.check_permission(): - to_url = 'https://%s/?next=%s' % (CONFIG.get('domains')[0], request.get_full_path()) + to_url = 'https://%s/?next=%s' % (CONFIG.y('domains')[0], request.get_full_path()) return RedirectView.as_view(url=to_url)(request) return handler.get_response() diff --git a/passbook/app_gw/proxy/handler.py b/passbook/app_gw/proxy/handler.py index ea6b376c4..15ca0d64d 100644 --- a/passbook/app_gw/proxy/handler.py +++ b/passbook/app_gw/proxy/handler.py @@ -15,7 +15,7 @@ from passbook.app_gw.proxy.response import get_django_response from passbook.app_gw.proxy.rewrite import Rewriter from passbook.app_gw.proxy.utils import encode_items, normalize_request_headers from passbook.core.models import Application -from passbook.core.policies import PolicyEngine +from passbook.policy.engine import PolicyEngine SESSION_UPSTREAM_KEY = 'passbook_app_gw_upstream' IGNORED_HOSTNAMES_KEY = 'passbook_app_gw_ignored' diff --git a/passbook/core/apps.py b/passbook/core/apps.py index 3329ec500..198a3eb43 100644 --- a/passbook/core/apps.py +++ b/passbook/core/apps.py @@ -17,7 +17,7 @@ class PassbookCoreConfig(AppConfig): mountpoint = '' def ready(self): - import_module('passbook.core.policies') + import_module('passbook.policy.engine') factors_to_load = CONFIG.y('passbook.factors', []) for factors_to_load in factors_to_load: try: diff --git a/passbook/core/auth/view.py b/passbook/core/auth/view.py index ac673298e..4173abfe9 100644 --- a/passbook/core/auth/view.py +++ b/passbook/core/auth/view.py @@ -1,5 +1,6 @@ """passbook multi-factor authentication engine""" from logging import getLogger +from typing import List, Tuple from django.contrib.auth import login from django.contrib.auth.mixins import UserPassesTestMixin @@ -8,10 +9,10 @@ from django.utils.http import urlencode from django.views.generic import View from passbook.core.models import Factor, User -from passbook.core.policies import PolicyEngine from passbook.core.views.utils import PermissionDeniedView from passbook.lib.utils.reflection import class_to_path, path_to_class from passbook.lib.utils.urls import is_url_absolute +from passbook.policy.engine import PolicyEngine LOGGER = getLogger(__name__) @@ -31,12 +32,12 @@ class AuthenticationView(UserPassesTestMixin, View): SESSION_USER_BACKEND = 'passbook_user_backend' SESSION_IS_SSO_LOGIN = 'passbook_sso_login' - pending_user = None - pending_factors = [] + pending_user: User + pending_factors: List[Tuple[str, str]] = [] - _current_factor_class = None + _current_factor_class: Factor - current_factor = None + current_factor: Factor # Allow only not authenticated users to login def test_func(self): diff --git a/passbook/core/models.py b/passbook/core/models.py index fb617c60e..1b8e274a9 100644 --- a/passbook/core/models.py +++ b/passbook/core/models.py @@ -4,7 +4,7 @@ from datetime import timedelta from logging import getLogger from random import SystemRandom from time import sleep -from typing import Tuple, Union +from typing import List from uuid import uuid4 from django.contrib.auth.models import AbstractUser @@ -25,6 +25,20 @@ def default_nonce_duration(): """Default duration a Nonce is valid""" return now() + timedelta(hours=4) + +class PolicyResult: + """Small data-class to hold policy results""" + + passing: bool = False + messages: List[str] = [] + + def __init__(self, passing: bool, *messages: str): + self.passing = passing + self.messages = messages + + def __str__(self): + return f"" + class Group(UUIDModel): """Custom Group model which supports a basic hierarchy""" @@ -229,7 +243,7 @@ class Policy(UUIDModel, CreatedUpdatedModel): return self.name return "%s action %s" % (self.name, self.action) - def passes(self, user: User) -> Union[bool, Tuple[bool, str]]: + def passes(self, user: User) -> PolicyResult: """Check if user instance passes this policy""" raise NotImplementedError() @@ -273,7 +287,7 @@ class FieldMatcherPolicy(Policy): description = "%s: %s" % (self.name, description) return description - def passes(self, user: User) -> Union[bool, Tuple[bool, str]]: + def passes(self, user: User) -> PolicyResult: """Check if user instance passes this role""" if not hasattr(user, self.user_field): raise ValueError("Field does not exist") @@ -294,7 +308,7 @@ class FieldMatcherPolicy(Policy): passes = user_field_value == self.value LOGGER.debug("User got '%r'", passes) - return passes + return PolicyResult(passes) class Meta: @@ -313,10 +327,10 @@ class PasswordPolicy(Policy): form = 'passbook.core.forms.policies.PasswordPolicyForm' - def passes(self, user: User) -> Union[bool, Tuple[bool, str]]: + def passes(self, user: User) -> PolicyResult: # Only check if password is being set if not hasattr(user, '__password__'): - return True + return PolicyResult(True) password = getattr(user, '__password__') filter_regex = r'' @@ -329,8 +343,8 @@ class PasswordPolicy(Policy): result = bool(re.compile(filter_regex).match(password)) LOGGER.debug("User got %r", result) if not result: - return result, self.error_message - return result + return PolicyResult(result, self.error_message) + return PolicyResult(result) class Meta: @@ -364,7 +378,7 @@ class WebhookPolicy(Policy): form = 'passbook.core.forms.policies.WebhookPolicyForm' - def passes(self, user: User): + def passes(self, user: User) -> PolicyResult: """Call webhook asynchronously and report back""" raise NotImplementedError() @@ -383,12 +397,12 @@ class DebugPolicy(Policy): form = 'passbook.core.forms.policies.DebugPolicyForm' - def passes(self, user: User): + def passes(self, user: User) -> PolicyResult: """Wait random time then return result""" wait = SystemRandom().randrange(self.wait_min, self.wait_max) LOGGER.debug("Policy '%s' waiting for %ds", self.name, wait) sleep(wait) - return self.result, 'Debugging' + return PolicyResult(self.result, 'Debugging') class Meta: @@ -402,8 +416,8 @@ class GroupMembershipPolicy(Policy): form = 'passbook.core.forms.policies.GroupMembershipPolicyForm' - def passes(self, user: User) -> Union[bool, Tuple[bool, str]]: - return self.group.user_set.filter(pk=user.pk).exists() + def passes(self, user: User) -> PolicyResult: + return PolicyResult(self.group.user_set.filter(pk=user.pk).exists()) class Meta: @@ -415,10 +429,10 @@ class SSOLoginPolicy(Policy): form = 'passbook.core.forms.policies.SSOLoginPolicyForm' - def passes(self, user): + def passes(self, user) -> PolicyResult: """Check if user instance passes this policy""" from passbook.core.auth.view import AuthenticationView - return user.session.get(AuthenticationView.SESSION_IS_SSO_LOGIN, False), "" + return PolicyResult(user.session.get(AuthenticationView.SESSION_IS_SSO_LOGIN, False)) class Meta: diff --git a/passbook/core/policies.py b/passbook/core/policies.py deleted file mode 100644 index e50d51dbd..000000000 --- a/passbook/core/policies.py +++ /dev/null @@ -1,134 +0,0 @@ -"""passbook core policy engine""" -from logging import getLogger - -from amqp.exceptions import UnexpectedFrame -from celery import group -from celery.exceptions import TimeoutError as CeleryTimeoutError -from django.core.cache import cache -from ipware import get_client_ip - -from passbook.core.models import Policy, User -from passbook.root.celery import CELERY_APP - -LOGGER = getLogger(__name__) - -def _cache_key(policy, user): - return "policy_%s#%s" % (policy.uuid, user.pk) - -@CELERY_APP.task() -def _policy_engine_task(user_pk, policy_pk, **kwargs): - """Task wrapper to run policy checking""" - if not user_pk: - raise ValueError() - policy_obj = Policy.objects.filter(pk=policy_pk).select_subclasses().first() - user_obj = User.objects.get(pk=user_pk) - for key, value in kwargs.items(): - setattr(user_obj, key, value) - LOGGER.debug("Running policy `%s`#%s for user %s...", policy_obj.name, - policy_obj.pk.hex, user_obj) - policy_result = policy_obj.passes(user_obj) - # Handle policy result correctly if result, message or just result - message = None - if isinstance(policy_result, (tuple, list)): - policy_result, message = policy_result - # Invert result if policy.negate is set - if policy_obj.negate: - policy_result = not policy_result - LOGGER.debug("Policy %r#%s got %s", policy_obj.name, policy_obj.pk.hex, policy_result) - cache_key = _cache_key(policy_obj, user_obj) - cache.set(cache_key, (policy_obj.action, policy_result, message)) - LOGGER.debug("Cached entry as %s", cache_key) - return policy_obj.action, policy_result, message - -class PolicyEngine: - """Orchestrate policy checking, launch tasks and return result""" - - __group = None - __cached = None - - policies = None - __get_timeout = 0 - __request = None - __user = None - - def __init__(self, policies): - self.policies = policies - self.__request = None - self.__user = None - - def for_user(self, user): - """Check policies for user""" - self.__user = user - return self - - def with_request(self, request): - """Set request""" - self.__request = request - return self - - def build(self): - """Build task group""" - if not self.__user: - raise ValueError("User not set.") - signatures = [] - cached_policies = [] - kwargs = { - '__password__': getattr(self.__user, '__password__', None), - 'session': dict(getattr(self.__request, 'session', {}).items()), - } - if self.__request: - kwargs['remote_ip'], _ = get_client_ip(self.__request) - if not kwargs['remote_ip']: - kwargs['remote_ip'] = '255.255.255.255' - for policy in self.policies: - cached_policy = cache.get(_cache_key(policy, self.__user), None) - if cached_policy: - LOGGER.debug("Taking result from cache for %s", policy.pk.hex) - cached_policies.append(cached_policy) - else: - LOGGER.debug("Evaluating policy %s", policy.pk.hex) - signatures.append(_policy_engine_task.signature( - args=(self.__user.pk, policy.pk.hex), - kwargs=kwargs, - time_limit=policy.timeout)) - self.__get_timeout += policy.timeout - LOGGER.debug("Set total policy timeout to %r", self.__get_timeout) - # If all policies are cached, we have an empty list here. - if signatures: - self.__group = group(signatures)() - self.__get_timeout += 3 - self.__get_timeout = (self.__get_timeout / len(self.policies)) * 1.5 - self.__cached = cached_policies - return self - - @property - def result(self): - """Get policy-checking result""" - messages = [] - result = [] - try: - if self.__group: - # ValueError can be thrown from _policy_engine_task when user is None - result += self.__group.get(timeout=self.__get_timeout) - result += self.__cached - except ValueError as exc: - # ValueError can be thrown from _policy_engine_task when user is None - return False, [str(exc)] - except UnexpectedFrame as exc: - return False, [str(exc)] - except CeleryTimeoutError as exc: - return False, [str(exc)] - for policy_action, policy_result, policy_message in result: - passing = (policy_action == Policy.ACTION_ALLOW and policy_result) or \ - (policy_action == Policy.ACTION_DENY and not policy_result) - LOGGER.debug('Action=%s, Result=%r => %r', policy_action, policy_result, passing) - if policy_message: - messages.append(policy_message) - if not passing: - return False, messages - return True, messages - - @property - def passing(self): - """Only get true/false if user passes""" - return self.result[0] diff --git a/passbook/core/signals.py b/passbook/core/signals.py index 53b30cb6f..6806690f7 100644 --- a/passbook/core/signals.py +++ b/passbook/core/signals.py @@ -20,7 +20,7 @@ password_changed = Signal(providing_args=['user', 'password']) def password_policy_checker(sender, password, **kwargs): """Run password through all password policies which are applied to the user""" from passbook.core.models import PasswordFactor - from passbook.core.policies import PolicyEngine + from passbook.policy.engine import PolicyEngine setattr(sender, '__password__', password) _all_factors = PasswordFactor.objects.filter(enabled=True).order_by('order') for factor in _all_factors: diff --git a/passbook/core/templatetags/passbook_user_settings.py b/passbook/core/templatetags/passbook_user_settings.py index 7aeb01620..05bfc0ae6 100644 --- a/passbook/core/templatetags/passbook_user_settings.py +++ b/passbook/core/templatetags/passbook_user_settings.py @@ -3,7 +3,7 @@ from django import template from passbook.core.models import Factor, Source -from passbook.core.policies import PolicyEngine +from passbook.policy.engine import PolicyEngine register = template.Library() diff --git a/passbook/core/views/access.py b/passbook/core/views/access.py index 293f2324e..054181371 100644 --- a/passbook/core/views/access.py +++ b/passbook/core/views/access.py @@ -5,7 +5,7 @@ from django.contrib import messages from django.utils.translation import gettext as _ from passbook.core.models import Application -from passbook.core.policies import PolicyEngine +from passbook.policy.engine import PolicyEngine LOGGER = getLogger(__name__) diff --git a/passbook/core/views/overview.py b/passbook/core/views/overview.py index f41972fb2..ad788b607 100644 --- a/passbook/core/views/overview.py +++ b/passbook/core/views/overview.py @@ -4,7 +4,7 @@ from django.contrib.auth.mixins import LoginRequiredMixin from django.views.generic import TemplateView from passbook.core.models import Application -from passbook.core.policies import PolicyEngine +from passbook.policy.engine import PolicyEngine class OverviewView(LoginRequiredMixin, TemplateView): diff --git a/passbook/hibp_policy/models.py b/passbook/hibp_policy/models.py index 66da48be4..7936fbb2a 100644 --- a/passbook/hibp_policy/models.py +++ b/passbook/hibp_policy/models.py @@ -6,7 +6,7 @@ from django.db import models from django.utils.translation import gettext as _ from requests import get -from passbook.core.models import Policy, User +from passbook.core.models import Policy, PolicyResult, User LOGGER = getLogger(__name__) @@ -18,13 +18,13 @@ class HaveIBeenPwendPolicy(Policy): form = 'passbook.hibp_policy.forms.HaveIBeenPwnedPolicyForm' - def passes(self, user: User) -> bool: + def passes(self, user: User) -> PolicyResult: """Check if password is in HIBP DB. Hashes given Password with SHA1, uses the first 5 characters of Password in request and checks if full hash is in response. Returns 0 if Password is not in result otherwise the count of how many times it was used.""" # Only check if password is being set if not hasattr(user, '__password__'): - return True + return PolicyResult(True) password = getattr(user, '__password__') pw_hash = sha1(password.encode('utf-8')).hexdigest() # nosec url = 'https://api.pwnedpasswords.com/range/%s' % pw_hash[:5] @@ -36,8 +36,9 @@ class HaveIBeenPwendPolicy(Policy): final_count = int(count) LOGGER.debug("Got count %d for hash %s", final_count, pw_hash[:5]) if final_count > self.allowed_count: - return False, _("Password exists on %(count)d online lists." % {'count': final_count}) - return True + message = _("Password exists on %(count)d online lists." % {'count': final_count}) + return PolicyResult(False, message) + return PolicyResult(True) class Meta: diff --git a/passbook/lib/config.py b/passbook/lib/config.py index 77888d74b..8edffaab8 100644 --- a/passbook/lib/config.py +++ b/passbook/lib/config.py @@ -34,8 +34,7 @@ class ConfigLoader: def __init__(self): super().__init__() - base_dir = os.path.realpath(os.path.join( - os.path.dirname(__file__), '../..')) + base_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), '../..')) for path in SEARCH_PATHS: # Check if path is relative, and if so join with base_dir if not os.path.isabs(path): diff --git a/passbook/lib/utils/template.py b/passbook/lib/utils/template.py index 0af5c8e66..af6b492b6 100644 --- a/passbook/lib/utils/template.py +++ b/passbook/lib/utils/template.py @@ -2,9 +2,9 @@ from django.template import Context, Template, loader -def render_from_string(template: str, ctx: Context) -> str: +def render_from_string(tmpl: str, ctx: Context) -> str: """Render template from string to string""" - template = Template(template) + template = Template(tmpl) return template.render(ctx) diff --git a/passbook/oidc_provider/lib.py b/passbook/oidc_provider/lib.py index 07c17edcd..47d809492 100644 --- a/passbook/oidc_provider/lib.py +++ b/passbook/oidc_provider/lib.py @@ -5,7 +5,7 @@ from django.contrib import messages from django.shortcuts import redirect from passbook.core.models import Application -from passbook.core.policies import PolicyEngine +from passbook.policy.engine import PolicyEngine LOGGER = getLogger(__name__) diff --git a/passbook/password_expiry_policy/models.py b/passbook/password_expiry_policy/models.py index 7c5d935e8..93c59e174 100644 --- a/passbook/password_expiry_policy/models.py +++ b/passbook/password_expiry_policy/models.py @@ -6,7 +6,7 @@ from django.db import models from django.utils.timezone import now from django.utils.translation import gettext as _ -from passbook.core.models import Policy, User +from passbook.core.models import Policy, PolicyResult, User LOGGER = getLogger(__name__) @@ -20,7 +20,7 @@ class PasswordExpiryPolicy(Policy): form = 'passbook.password_expiry_policy.forms.PasswordExpiryPolicyForm' - def passes(self, user: User) -> bool: + def passes(self, user: User) -> PolicyResult: """If password change date is more than x days in the past, call set_unusable_password and show a notice""" actual_days = (now() - user.password_change_date).days @@ -29,12 +29,13 @@ class PasswordExpiryPolicy(Policy): if not self.deny_only: user.set_unusable_password() user.save() - return False, _(('Password expired %(days)d days ago. ' - 'Please update your password.') % { - 'days': days_since_expiry - }) - return False, _('Password has expired.') - return True + message = _(('Password expired %(days)d days ago. ' + 'Please update your password.') % { + 'days': days_since_expiry + }) + return PolicyResult(False, message) + return PolicyResult(False, _('Password has expired.')) + return PolicyResult(True) class Meta: diff --git a/passbook/policy/__init__.py b/passbook/policy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/passbook/policy/engine.py b/passbook/policy/engine.py new file mode 100644 index 000000000..92a749c50 --- /dev/null +++ b/passbook/policy/engine.py @@ -0,0 +1,97 @@ +"""passbook policy engine""" +from multiprocessing import Pipe +from multiprocessing.connection import Connection +from typing import List, Tuple + +from django.core.cache import cache +from django.http import HttpRequest +from structlog import get_logger + +from passbook.core.models import Policy, PolicyResult, User +from passbook.policy.task import PolicyTask + +LOGGER = get_logger() + +def _cache_key(policy, user): + return "policy_%s#%s" % (policy.uuid, user.pk) + +class PolicyEngine: + """Orchestrate policy checking, launch tasks and return result""" + + # __group = None + # __cached = None + + policies: List[Policy] = [] + __request: HttpRequest + __user: User + + __proc_list: List[Tuple[Connection, PolicyTask]] = [] + + def __init__(self, policies, user: User = None, request: HttpRequest = None): + self.policies = policies + self.__request = request + self.__user = user + + def for_user(self, user: User) -> 'PolicyEngine': + """Check policies for user""" + self.__user = user + return self + + def with_request(self, request: HttpRequest) -> 'PolicyEngine': + """Set request""" + self.__request = request + return self + + def build(self) -> 'PolicyEngine': + """Build task group""" + if not self.__user: + raise ValueError("User not set.") + cached_policies = [] + kwargs = { + '__password__': getattr(self.__user, '__password__', None), + 'session': dict(getattr(self.__request, 'session', {}).items()), + 'request': self.__request, + } + for policy in self.policies: + cached_policy = cache.get(_cache_key(policy, self.__user), None) + if cached_policy: + LOGGER.debug("Taking result from cache for %s", policy.pk.hex) + cached_policies.append(cached_policy) + else: + LOGGER.debug("Evaluating policy %s", policy.pk.hex) + our_end, task_end = Pipe(False) + task = PolicyTask() + task.ret = task_end + task.user = self.__user + task.policy = policy + task.params = kwargs + LOGGER.debug("Starting Process %s", task.__class__.__name__) + task.start() + self.__proc_list.append((our_end, task)) + # If all policies are cached, we have an empty list here. + if self.__proc_list: + for _, running_proc in self.__proc_list: + running_proc.join() + return self + + @property + def result(self): + """Get policy-checking result""" + results: List[PolicyResult] = [] + messages: List[str] = [] + for our_end, _ in self.__proc_list: + results.append(our_end.recv()) + for policy_result in results: + # passing = (policy_action == Policy.ACTION_ALLOW and policy_result) or \ + # (policy_action == Policy.ACTION_DENY and not policy_result) + LOGGER.debug('Result=%r => %r', policy_result, policy_result.passing) + if policy_result.messages: + messages += policy_result.messages + if not policy_result.passing: + return False, messages + return True, messages + + @property + def passing(self): + """Only get true/false if user passes""" + return self.result[0] diff --git a/passbook/policy/task.py b/passbook/policy/task.py new file mode 100644 index 000000000..789140bbf --- /dev/null +++ b/passbook/policy/task.py @@ -0,0 +1,38 @@ +"""passbook policy task""" +from logging import getLogger +from multiprocessing import Process +from multiprocessing.connection import Connection +from typing import Any, Dict + +from passbook.core.models import Policy, User + +LOGGER = getLogger(__name__) + + +def _cache_key(policy, user): + return "policy_%s#%s" % (policy.uuid, user.pk) + +class PolicyTask(Process): + """Evaluate a single policy within a seprate process""" + + ret: Connection + user: User + policy: Policy + params: Dict[str, Any] + + def run(self): + """Task wrapper to run policy checking""" + for key, value in self.params.items(): + setattr(self.user, key, value) + LOGGER.debug("Running policy `%s`#%s for user %s...", self.policy.name, + self.policy.pk.hex, self.user) + policy_result = self.policy.passes(self.user) + # Invert result if policy.negate is set + if self.policy.negate: + policy_result = not policy_result + LOGGER.debug("Policy %r#%s got %s", self.policy.name, self.policy.pk.hex, policy_result) + # cache_key = _cache_key(self.policy, self.user) + # cache.set(cache_key, (self.policy.action, policy_result, message)) + # LOGGER.debug("Cached entry as %s", cache_key) + self.ret.send(policy_result) + self.ret.close() diff --git a/passbook/root/settings.py b/passbook/root/settings.py index 72b464b42..5059c56a9 100644 --- a/passbook/root/settings.py +++ b/passbook/root/settings.py @@ -11,7 +11,6 @@ https://docs.djangoproject.com/en/2.1/ref/settings/ """ import importlib -import logging import os import sys diff --git a/passbook/saml_idp/base.py b/passbook/saml_idp/base.py index ea588786e..d8aa34e6d 100644 --- a/passbook/saml_idp/base.py +++ b/passbook/saml_idp/base.py @@ -260,7 +260,6 @@ class Processor: def _validate_user(self): """Validates the User. Sub-classes should override this and throw an CannotHandleAssertion Exception if the validation does not succeed.""" - pass def can_handle(self, request): """Returns true if this processor can handle this request.""" diff --git a/passbook/saml_idp/exceptions.py b/passbook/saml_idp/exceptions.py index 49b02b42c..98fea6dec 100644 --- a/passbook/saml_idp/exceptions.py +++ b/passbook/saml_idp/exceptions.py @@ -3,9 +3,7 @@ class CannotHandleAssertion(Exception): """This processor does not handle this assertion.""" - pass class UserNotAuthorized(Exception): """User not authorized for SAML 2.0 authentication.""" - pass diff --git a/passbook/saml_idp/views.py b/passbook/saml_idp/views.py index dd4d44314..cdd08324c 100644 --- a/passbook/saml_idp/views.py +++ b/passbook/saml_idp/views.py @@ -16,9 +16,9 @@ from signxml.util import strip_pem_header from passbook.audit.models import AuditEntry from passbook.core.models import Application -from passbook.core.policies import PolicyEngine from passbook.lib.mixins import CSRFExemptMixin from passbook.lib.utils.template import render_to_string +from passbook.policy.engine import PolicyEngine from passbook.saml_idp import exceptions from passbook.saml_idp.models import SAMLProvider diff --git a/passbook/suspicious_policy/models.py b/passbook/suspicious_policy/models.py index 4429d20cf..7fdbea927 100644 --- a/passbook/suspicious_policy/models.py +++ b/passbook/suspicious_policy/models.py @@ -2,7 +2,7 @@ from django.db import models from django.utils.translation import gettext as _ -from passbook.core.models import Policy, User +from passbook.core.models import Policy, PolicyResult, User class SuspiciousRequestPolicy(Policy): @@ -14,7 +14,7 @@ class SuspiciousRequestPolicy(Policy): form = 'passbook.suspicious_policy.forms.SuspiciousRequestPolicyForm' - def passes(self, user: User): + def passes(self, user: User) -> PolicyResult: remote_ip = user.remote_ip passing = True if self.check_ip: @@ -23,7 +23,7 @@ class SuspiciousRequestPolicy(Policy): if self.check_username: user_scores = UserScore.objects.filter(user=user, score__lte=self.threshold) passing = passing and user_scores.exists() - return passing + return PolicyResult(passing) class Meta: