policies: provider raw result for better policy reusability (#5189)

* policies: include raw_result in PolicyResult

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* move ak_call_policy to base evaluator

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-04-06 09:42:29 +02:00 committed by GitHub
parent c117d98e27
commit 977757f561
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 81 additions and 48 deletions

View File

@ -4,7 +4,10 @@ from guardian.shortcuts import get_anonymous_user
from authentik.core.exceptions import PropertyMappingExpressionException from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import PropertyMapping from authentik.core.models import PropertyMapping
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.policies.expression.models import ExpressionPolicy
class TestPropertyMappings(TestCase): class TestPropertyMappings(TestCase):
@ -12,23 +15,24 @@ class TestPropertyMappings(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self.user = create_test_admin_user()
self.factory = RequestFactory() self.factory = RequestFactory()
def test_expression(self): def test_expression(self):
"""Test expression""" """Test expression"""
mapping = PropertyMapping.objects.create(name="test", expression="return 'test'") mapping = PropertyMapping.objects.create(name=generate_id(), expression="return 'test'")
self.assertEqual(mapping.evaluate(None, None), "test") self.assertEqual(mapping.evaluate(None, None), "test")
def test_expression_syntax(self): def test_expression_syntax(self):
"""Test expression syntax error""" """Test expression syntax error"""
mapping = PropertyMapping.objects.create(name="test", expression="-") mapping = PropertyMapping.objects.create(name=generate_id(), expression="-")
with self.assertRaises(PropertyMappingExpressionException): with self.assertRaises(PropertyMappingExpressionException):
mapping.evaluate(None, None) mapping.evaluate(None, None)
def test_expression_error_general(self): def test_expression_error_general(self):
"""Test expression error""" """Test expression error"""
expr = "return aaa" expr = "return aaa"
mapping = PropertyMapping.objects.create(name="test", expression=expr) mapping = PropertyMapping.objects.create(name=generate_id(), expression=expr)
with self.assertRaises(PropertyMappingExpressionException): with self.assertRaises(PropertyMappingExpressionException):
mapping.evaluate(None, None) mapping.evaluate(None, None)
events = Event.objects.filter( events = Event.objects.filter(
@ -41,7 +45,7 @@ class TestPropertyMappings(TestCase):
"""Test expression error (with user and http request""" """Test expression error (with user and http request"""
expr = "return aaa" expr = "return aaa"
request = self.factory.get("/") request = self.factory.get("/")
mapping = PropertyMapping.objects.create(name="test", expression=expr) mapping = PropertyMapping.objects.create(name=generate_id(), expression=expr)
with self.assertRaises(PropertyMappingExpressionException): with self.assertRaises(PropertyMappingExpressionException):
mapping.evaluate(get_anonymous_user(), request) mapping.evaluate(get_anonymous_user(), request)
events = Event.objects.filter( events = Event.objects.filter(
@ -52,3 +56,23 @@ class TestPropertyMappings(TestCase):
event = events.first() event = events.first()
self.assertEqual(event.user["username"], "AnonymousUser") self.assertEqual(event.user["username"], "AnonymousUser")
self.assertEqual(event.client_ip, "127.0.0.1") self.assertEqual(event.client_ip, "127.0.0.1")
def test_call_policy(self):
"""test ak_call_policy"""
expr = ExpressionPolicy.objects.create(
name=generate_id(),
execution_logging=True,
expression="return request.http_request.path",
)
http_request = self.factory.get("/")
tmpl = (
"""
res = ak_call_policy('%s')
result = [request.http_request.path, res.raw_result]
return result
"""
% expr.name
)
evaluator = PropertyMapping(expression=tmpl, name=generate_id())
res = evaluator.evaluate(self.user, http_request)
self.assertEqual(res, ["/", "/"])

View File

@ -8,6 +8,7 @@ from typing import Any, Iterable, Optional
from cachetools import TLRUCache, cached from cachetools import TLRUCache, cached
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django_otp import devices_for_user from django_otp import devices_for_user
from guardian.shortcuts import get_anonymous_user
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from sentry_sdk.hub import Hub from sentry_sdk.hub import Hub
from sentry_sdk.tracing import Span from sentry_sdk.tracing import Span
@ -16,7 +17,9 @@ from structlog.stdlib import get_logger
from authentik.core.models import User from authentik.core.models import User
from authentik.events.models import Event from authentik.events.models import Event
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.policies.types import PolicyRequest from authentik.policies.models import Policy, PolicyBinding
from authentik.policies.process import PolicyProcess
from authentik.policies.types import PolicyRequest, PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
@ -37,19 +40,20 @@ class BaseEvaluator:
# update website/docs/expressions/_objects.md # update website/docs/expressions/_objects.md
# update website/docs/expressions/_functions.md # update website/docs/expressions/_functions.md
self._globals = { self._globals = {
"regex_match": BaseEvaluator.expr_regex_match, "ak_call_policy": self.expr_func_call_policy,
"regex_replace": BaseEvaluator.expr_regex_replace, "ak_create_event": self.expr_event_create,
"list_flatten": BaseEvaluator.expr_flatten,
"ak_is_group_member": BaseEvaluator.expr_is_group_member, "ak_is_group_member": BaseEvaluator.expr_is_group_member,
"ak_logger": get_logger(self._filename).bind(),
"ak_user_by": BaseEvaluator.expr_user_by, "ak_user_by": BaseEvaluator.expr_user_by,
"ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator, "ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator,
"resolve_dns": BaseEvaluator.expr_resolve_dns,
"reverse_dns": BaseEvaluator.expr_reverse_dns,
"ak_create_event": self.expr_event_create,
"ak_logger": get_logger(self._filename).bind(),
"requests": get_http_session(),
"ip_address": ip_address, "ip_address": ip_address,
"ip_network": ip_network, "ip_network": ip_network,
"list_flatten": BaseEvaluator.expr_flatten,
"regex_match": BaseEvaluator.expr_regex_match,
"regex_replace": BaseEvaluator.expr_regex_replace,
"requests": get_http_session(),
"resolve_dns": BaseEvaluator.expr_resolve_dns,
"reverse_dns": BaseEvaluator.expr_reverse_dns,
} }
self._context = {} self._context = {}
@ -152,6 +156,19 @@ class BaseEvaluator:
return return
event.save() event.save()
def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult:
"""Call policy by name, with current request"""
policy = Policy.objects.filter(name=name).select_subclasses().first()
if not policy:
raise ValueError(f"Policy '{name}' not found.")
user = self._context.get("user", get_anonymous_user())
req = PolicyRequest(user)
if "request" in self._context:
req = self._context["request"]
req.context.update(kwargs)
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
return proc.profiling_wrapper()
def wrap_expression(self, expression: str, params: Iterable[str]) -> str: def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
"""Wrap expression in a function, call it, and save the result as `result`""" """Wrap expression in a function, call it, and save the result as `result`"""
handler_signature = ",".join(params) handler_signature = ",".join(params)

View File

@ -9,8 +9,6 @@ from authentik.flows.planner import PLAN_CONTEXT_SSO
from authentik.lib.expression.evaluator import BaseEvaluator from authentik.lib.expression.evaluator import BaseEvaluator
from authentik.lib.utils.http import get_client_ip from authentik.lib.utils.http import get_client_ip
from authentik.policies.exceptions import PolicyException from authentik.policies.exceptions import PolicyException
from authentik.policies.models import Policy, PolicyBinding
from authentik.policies.process import PolicyProcess
from authentik.policies.types import PolicyRequest, PolicyResult from authentik.policies.types import PolicyRequest, PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
@ -32,22 +30,11 @@ class PolicyEvaluator(BaseEvaluator):
# update website/docs/expressions/_functions.md # update website/docs/expressions/_functions.md
self._context["ak_message"] = self.expr_func_message self._context["ak_message"] = self.expr_func_message
self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator
self._context["ak_call_policy"] = self.expr_func_call_policy
def expr_func_message(self, message: str): def expr_func_message(self, message: str):
"""Wrapper to append to messages list, which is returned with PolicyResult""" """Wrapper to append to messages list, which is returned with PolicyResult"""
self._messages.append(message) self._messages.append(message)
def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult:
"""Call policy by name, with current request"""
policy = Policy.objects.filter(name=name).select_subclasses().first()
if not policy:
raise ValueError(f"Policy '{name}' not found.")
req: PolicyRequest = self._context["request"]
req.context.update(kwargs)
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
return proc.profiling_wrapper()
def set_policy_request(self, request: PolicyRequest): def set_policy_request(self, request: PolicyRequest):
"""Update context based on policy request (if http request is given, update that too)""" """Update context based on policy request (if http request is given, update that too)"""
# update website/docs/expressions/_objects.md # update website/docs/expressions/_objects.md
@ -83,6 +70,7 @@ class PolicyEvaluator(BaseEvaluator):
return PolicyResult(False, str(exc)) return PolicyResult(False, str(exc))
else: else:
policy_result = PolicyResult(False, *self._messages) policy_result = PolicyResult(False, *self._messages)
policy_result.raw_result = result
if result is None: if result is None:
LOGGER.warning( LOGGER.warning(
"Expression policy returned None", "Expression policy returned None",

View File

@ -69,10 +69,11 @@ class PolicyRequest:
@dataclass @dataclass
class PolicyResult: class PolicyResult:
"""Small data-class to hold policy results""" """Result from evaluating a policy."""
passing: bool passing: bool
messages: tuple[str, ...] messages: tuple[str, ...]
raw_result: Any
source_binding: Optional["PolicyBinding"] source_binding: Optional["PolicyBinding"]
source_results: Optional[list["PolicyResult"]] source_results: Optional[list["PolicyResult"]]
@ -83,6 +84,7 @@ class PolicyResult:
super().__init__() super().__init__()
self.passing = passing self.passing = passing
self.messages = messages self.messages = messages
self.raw_result = None
self.source_binding = None self.source_binding = None
self.source_results = [] self.source_results = []
self.log_messages = [] self.log_messages = []

View File

@ -29,6 +29,29 @@ user = list_flatten(["foo"])
# user = "foo" # user = "foo"
``` ```
### `ak_call_policy(name: str, **kwargs) -> PolicyResult`
:::info
Requires authentik 2021.12
:::
Call another policy with the name _name_. Current request is passed to policy. Key-word arguments
can be used to modify the request's context.
Example:
```python
result = ak_call_policy("test-policy")
# result is a PolicyResult object, so you can access `.passing` and `.messages`.
# Starting with authentik 2023.4 you can also access `.raw_result`, which is the raw value returned from the called policy
# `result.passing` will always be a boolean if the policy is passing or not.
return result.passing
result = ak_call_policy("test-policy-2", foo="bar")
# Inside the `test-policy-2` you can then use `request.context["foo"]`
return result.passing
```
### `ak_is_group_member(user: User, **group_filters) -> bool` ### `ak_is_group_member(user: User, **group_filters) -> bool`
Check if `user` is member of a group matching `**group_filters`. Check if `user` is member of a group matching `**group_filters`.

View File

@ -29,27 +29,6 @@ ak_message("Access denied")
return False return False
``` ```
### `ak_call_policy(name: str, **kwargs) -> PolicyResult`
:::info
Requires authentik 2021.12
:::
Call another policy with the name _name_. Current request is passed to policy. Key-word arguments
can be used to modify the request's context.
Example:
```python
result = ak_call_policy("test-policy")
# result is a PolicyResult object, so you can access `.passing` and `.messages`.
return result.passing
result = ak_call_policy("test-policy-2", foo="bar")
# Inside the `test-policy-2` you can then use `request.context["foo"]`
return result.passing
```
import Functions from "../expressions/_functions.md"; import Functions from "../expressions/_functions.md";
<Functions /> <Functions />