From 2b5fddb7bfca967bebc1239ef96c1cd9e01331f8 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Sun, 23 Feb 2020 15:54:26 +0100 Subject: [PATCH] policies: add unittests for evaluator --- passbook/policies/expression/evaluator.py | 17 +++--- .../policies/expression/tests/__init__.py | 0 .../expression/tests/test_evaluator.py | 58 +++++++++++++++++++ passbook/policies/types.py | 8 ++- 4 files changed, 72 insertions(+), 11 deletions(-) create mode 100644 passbook/policies/expression/tests/__init__.py create mode 100644 passbook/policies/expression/tests/test_evaluator.py diff --git a/passbook/policies/expression/evaluator.py b/passbook/policies/expression/evaluator.py index 39943202d..2a0e0193a 100644 --- a/passbook/policies/expression/evaluator.py +++ b/passbook/policies/expression/evaluator.py @@ -19,7 +19,7 @@ LOGGER = get_logger() class Evaluator: - """Validate and evaulate jinja2-based expressions""" + """Validate and evaluate jinja2-based expressions""" _env: NativeEnvironment @@ -51,14 +51,15 @@ class Evaluator: """Return dictionary with additional global variables passed to expression""" # update passbook/policies/expression/templates/policy/expression/form.html # update docs/policies/expression/index.md - kwargs["pb_is_sso_flow"] = request.http_request.session.get( - AuthenticationView.SESSION_IS_SSO_LOGIN, False - ) kwargs["pb_is_group_member"] = Evaluator.jinja2_func_is_group_member kwargs["pb_logger"] = get_logger() - kwargs["pb_client_ip"] = ( - get_client_ip(request.http_request) or "255.255.255.255" - ) + if request.http_request: + kwargs["pb_is_sso_flow"] = request.http_request.session.get( + AuthenticationView.SESSION_IS_SSO_LOGIN, False + ) + kwargs["pb_client_ip"] = ( + get_client_ip(request.http_request) or "255.255.255.255" + ) return kwargs def evaluate(self, expression_source: str, request: PolicyRequest) -> PolicyResult: @@ -81,7 +82,7 @@ class Evaluator: req=request, ) return PolicyResult(False) - if isinstance(result, list) and len(result) == 2: + if isinstance(result, (list, tuple)) and len(result) == 2: return PolicyResult(*result) if result: return PolicyResult(result) diff --git a/passbook/policies/expression/tests/__init__.py b/passbook/policies/expression/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/passbook/policies/expression/tests/test_evaluator.py b/passbook/policies/expression/tests/test_evaluator.py new file mode 100644 index 000000000..ca22e86ec --- /dev/null +++ b/passbook/policies/expression/tests/test_evaluator.py @@ -0,0 +1,58 @@ +"""evaluator tests""" +from django.core.exceptions import ValidationError +from django.test import TestCase +from guardian.shortcuts import get_anonymous_user + +from passbook.policies.expression.evaluator import Evaluator +from passbook.policies.types import PolicyRequest + + +class TestEvaluator(TestCase): + """Evaluator tests""" + + def setUp(self): + self.request = PolicyRequest(user=get_anonymous_user()) + + def test_valid(self): + """test simple value expression""" + template = "True" + evaluator = Evaluator() + self.assertEqual(evaluator.evaluate(template, self.request).passing, True) + + def test_messages(self): + """test expression with message return""" + template = "False, 'some message'" + evaluator = Evaluator() + result = evaluator.evaluate(template, self.request) + self.assertEqual(result.passing, False) + self.assertEqual(result.messages, ("some message",)) + + def test_invalid_syntax(self): + """test invalid syntax""" + template = "{%" + evaluator = Evaluator() + result = evaluator.evaluate(template, self.request) + self.assertEqual(result.passing, False) + self.assertEqual(result.messages, ("tag name expected",)) + + def test_undefined(self): + """test undefined result""" + template = "{{ foo.bar }}" + evaluator = Evaluator() + result = evaluator.evaluate(template, self.request) + self.assertEqual(result.passing, False) + self.assertEqual(result.messages, ("'foo' is undefined",)) + + def test_validate(self): + """test validate""" + template = "True" + evaluator = Evaluator() + result = evaluator.validate(template) + self.assertEqual(result, True) + + def test_validate_invalid(self): + """test validate""" + template = "{%" + evaluator = Evaluator() + with self.assertRaises(ValidationError): + evaluator.validate(template) diff --git a/passbook/policies/types.py b/passbook/policies/types.py index 99bb890cc..fd549589a 100644 --- a/passbook/policies/types.py +++ b/passbook/policies/types.py @@ -1,7 +1,7 @@ """policy structures""" from __future__ import annotations -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from django.db.models import Model from django.http import HttpRequest @@ -14,11 +14,13 @@ class PolicyRequest: """Data-class to hold policy request data""" user: User - http_request: HttpRequest - obj: Model + http_request: Optional[HttpRequest] + obj: Optional[Model] def __init__(self, user: User): self.user = user + self.http_request = None + self.obj = None def __str__(self): return f""