diff --git a/passbook/core/models.py b/passbook/core/models.py index 08ad322dc..13517ea0e 100644 --- a/passbook/core/models.py +++ b/passbook/core/models.py @@ -13,7 +13,6 @@ from django.utils.translation import gettext_lazy as _ from guardian.mixins import GuardianUserMixin from jinja2 import Undefined from jinja2.exceptions import TemplateSyntaxError, UndefinedError -from jinja2.nativetypes import NativeEnvironment from model_utils.managers import InheritanceManager from structlog import get_logger @@ -24,7 +23,6 @@ from passbook.lib.models import CreatedUpdatedModel from passbook.policies.models import PolicyBindingModel LOGGER = get_logger() -NATIVE_ENVIRONMENT = NativeEnvironment() def default_token_duration(): @@ -208,8 +206,11 @@ class PropertyMapping(models.Model): self, user: Optional[User], request: Optional[HttpRequest], **kwargs ) -> Any: """Evaluate `self.expression` using `**kwargs` as Context.""" + from passbook.policies.expression.evaluator import Evaluator + + evaluator = Evaluator() try: - expression = NATIVE_ENVIRONMENT.from_string(self.expression) + expression = evaluator.env.from_string(self.expression) except TemplateSyntaxError as exc: raise PropertyMappingExpressionException from exc try: @@ -221,8 +222,11 @@ class PropertyMapping(models.Model): raise PropertyMappingExpressionException from exc def save(self, *args, **kwargs): + from passbook.policies.expression.evaluator import Evaluator + + evaluator = Evaluator() try: - NATIVE_ENVIRONMENT.from_string(self.expression) + evaluator.env.from_string(self.expression) except TemplateSyntaxError as exc: raise ValidationError("Expression Syntax Error") from exc return super().save(*args, **kwargs) diff --git a/passbook/policies/expression/evaluator.py b/passbook/policies/expression/evaluator.py index b2120bb74..047c5f1ac 100644 --- a/passbook/policies/expression/evaluator.py +++ b/passbook/policies/expression/evaluator.py @@ -1,8 +1,9 @@ """passbook expression policy evaluator""" import re -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from django.core.exceptions import ValidationError +from django.http import HttpRequest from jinja2 import Undefined from jinja2.exceptions import TemplateSyntaxError from jinja2.nativetypes import NativeEnvironment @@ -25,12 +26,32 @@ class Evaluator: _env: NativeEnvironment + _context: Dict[str, Any] + _messages: List[str] + def __init__(self): - self._env = NativeEnvironment() + self._env = NativeEnvironment( + extensions=["jinja2.ext.do",], + trim_blocks=True, + lstrip_blocks=True, + line_statement_prefix=">", + ) # update passbook/policies/expression/templates/policy/expression/form.html # update docs/policies/expression/index.md self._env.filters["regex_match"] = Evaluator.jinja2_filter_regex_match self._env.filters["regex_replace"] = Evaluator.jinja2_filter_regex_replace + self._env.globals["pb_message"] = self.jinja2_func_message + self._context = { + "pb_is_group_member": Evaluator.jinja2_func_is_group_member, + "pb_logger": get_logger(), + "requests": Session(), + } + self._messages = [] + + @property + def env(self) -> NativeEnvironment: + """Access to our custom NativeEnvironment""" + return self._env @staticmethod def jinja2_filter_regex_match(value: Any, regex: str) -> bool: @@ -47,52 +68,57 @@ class Evaluator: """Check if `user` is member of group with name `group_name`""" return user.groups.filter(name=group_name).exists() - def _get_expression_context( - self, request: PolicyRequest, **kwargs - ) -> Dict[str, Any]: - """Return dictionary with additional global variables passed to expression""" + def jinja2_func_message(self, message: str): + """Wrapper to append to messages list, which is returned with PolicyResult""" + self._messages.append(message) + + def set_policy_request(self, request: PolicyRequest): + """Update context based on policy request (if http request is given, update that too)""" # update passbook/policies/expression/templates/policy/expression/form.html # update docs/policies/expression/index.md - kwargs["pb_is_group_member"] = Evaluator.jinja2_func_is_group_member - kwargs["pb_logger"] = get_logger() - kwargs["requests"] = Session() - kwargs["pb_is_sso_flow"] = request.context.get(PLAN_CONTEXT_SSO, False) + self._context["pb_is_sso_flow"] = request.context.get(PLAN_CONTEXT_SSO, False) + self._context["request"] = request if request.http_request: - kwargs["pb_client_ip"] = ( - get_client_ip(request.http_request) or "255.255.255.255" - ) - if SESSION_KEY_PLAN in request.http_request.session: - kwargs["pb_flow_plan"] = request.http_request.session[SESSION_KEY_PLAN] - return kwargs + self.set_http_request(request.http_request) - def evaluate(self, expression_source: str, request: PolicyRequest) -> PolicyResult: - """Parse and evaluate expression. - If the Expression evaluates to a list with 2 items, the first is used as passing bool and - the second as messages. - If the Expression evaluates to a truthy-object, it is used as passing bool.""" + def set_http_request(self, request: HttpRequest): + """Update context based on http request""" + # update passbook/policies/expression/templates/policy/expression/form.html + # update docs/policies/expression/index.md + self._context["pb_client_ip"] = ( + get_client_ip(request.http_request) or "255.255.255.255" + ) + self._context["request"] = request + if SESSION_KEY_PLAN in request.http_request.session: + self._context["pb_flow_plan"] = request.http_request.session[ + SESSION_KEY_PLAN + ] + + def evaluate(self, expression_source: str) -> PolicyResult: + """Parse and evaluate expression. Policy is expected to return a truthy object. + Messages can be added using 'do pb_message()'.""" try: - expression = self._env.from_string(expression_source) + expression = self._env.from_string(expression_source.lstrip().rstrip()) except TemplateSyntaxError as exc: return PolicyResult(False, str(exc)) try: - result: Optional[Any] = expression.render( - request=request, **self._get_expression_context(request) - ) + result: Optional[Any] = expression.render(self._context) + except Exception as exc: # pylint: disable=broad-except + LOGGER.warning("Expression error", exc=exc) + return PolicyResult(False, str(exc)) + else: + policy_result = PolicyResult(False) + policy_result.messages = tuple(self._messages) if isinstance(result, Undefined): LOGGER.warning( "Expression policy returned undefined", src=expression_source, - req=request, + req=self._context, ) - return PolicyResult(False) - if isinstance(result, (list, tuple)) and len(result) == 2: - return PolicyResult(*result) + policy_result.passing = False if result: - return PolicyResult(bool(result)) - return PolicyResult(False) - except Exception as exc: # pylint: disable=broad-except - LOGGER.warning("Expression error", exc=exc) - return PolicyResult(False, str(exc)) + policy_result.passing = bool(result) + return policy_result def validate(self, expression: str): """Validate expression's syntax, raise ValidationError if Syntax is invalid""" diff --git a/passbook/policies/expression/models.py b/passbook/policies/expression/models.py index edf9b629b..e43f17560 100644 --- a/passbook/policies/expression/models.py +++ b/passbook/policies/expression/models.py @@ -16,7 +16,9 @@ class ExpressionPolicy(Policy): def passes(self, request: PolicyRequest) -> PolicyResult: """Evaluate and render expression. Returns PolicyResult(false) on error.""" - return Evaluator().evaluate(self.expression, request) + evaluator = Evaluator() + evaluator.set_policy_request(request) + return evaluator.evaluate(self.expression) def save(self, *args, **kwargs): Evaluator().validate(self.expression) diff --git a/passbook/policies/expression/tests/test_evaluator.py b/passbook/policies/expression/tests/test_evaluator.py index ca22e86ec..e39cdfec9 100644 --- a/passbook/policies/expression/tests/test_evaluator.py +++ b/passbook/policies/expression/tests/test_evaluator.py @@ -17,13 +17,15 @@ class TestEvaluator(TestCase): """test simple value expression""" template = "True" evaluator = Evaluator() - self.assertEqual(evaluator.evaluate(template, self.request).passing, True) + evaluator.set_policy_request(self.request) + self.assertEqual(evaluator.evaluate(template).passing, True) def test_messages(self): """test expression with message return""" - template = "False, 'some message'" + template = '{% do pb_message("some message") %}False' evaluator = Evaluator() - result = evaluator.evaluate(template, self.request) + evaluator.set_policy_request(self.request) + result = evaluator.evaluate(template) self.assertEqual(result.passing, False) self.assertEqual(result.messages, ("some message",)) @@ -31,7 +33,8 @@ class TestEvaluator(TestCase): """test invalid syntax""" template = "{%" evaluator = Evaluator() - result = evaluator.evaluate(template, self.request) + evaluator.set_policy_request(self.request) + result = evaluator.evaluate(template) self.assertEqual(result.passing, False) self.assertEqual(result.messages, ("tag name expected",)) @@ -39,7 +42,8 @@ class TestEvaluator(TestCase): """test undefined result""" template = "{{ foo.bar }}" evaluator = Evaluator() - result = evaluator.evaluate(template, self.request) + evaluator.set_policy_request(self.request) + result = evaluator.evaluate(template) self.assertEqual(result.passing, False) self.assertEqual(result.messages, ("'foo' is undefined",))