policies: provider raw result for better policy reusability (#5189)
* policies: include raw_result in PolicyResult Signed-off-by: Jens Langhammer <jens@goauthentik.io> * move ak_call_policy to base evaluator Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
c117d98e27
commit
977757f561
|
@ -4,7 +4,10 @@ from guardian.shortcuts import get_anonymous_user
|
|||
|
||||
from authentik.core.exceptions import PropertyMappingExpressionException
|
||||
from authentik.core.models import PropertyMapping
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.expression.models import ExpressionPolicy
|
||||
|
||||
|
||||
class TestPropertyMappings(TestCase):
|
||||
|
@ -12,23 +15,24 @@ class TestPropertyMappings(TestCase):
|
|||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.user = create_test_admin_user()
|
||||
self.factory = RequestFactory()
|
||||
|
||||
def test_expression(self):
|
||||
"""Test expression"""
|
||||
mapping = PropertyMapping.objects.create(name="test", expression="return 'test'")
|
||||
mapping = PropertyMapping.objects.create(name=generate_id(), expression="return 'test'")
|
||||
self.assertEqual(mapping.evaluate(None, None), "test")
|
||||
|
||||
def test_expression_syntax(self):
|
||||
"""Test expression syntax error"""
|
||||
mapping = PropertyMapping.objects.create(name="test", expression="-")
|
||||
mapping = PropertyMapping.objects.create(name=generate_id(), expression="-")
|
||||
with self.assertRaises(PropertyMappingExpressionException):
|
||||
mapping.evaluate(None, None)
|
||||
|
||||
def test_expression_error_general(self):
|
||||
"""Test expression error"""
|
||||
expr = "return aaa"
|
||||
mapping = PropertyMapping.objects.create(name="test", expression=expr)
|
||||
mapping = PropertyMapping.objects.create(name=generate_id(), expression=expr)
|
||||
with self.assertRaises(PropertyMappingExpressionException):
|
||||
mapping.evaluate(None, None)
|
||||
events = Event.objects.filter(
|
||||
|
@ -41,7 +45,7 @@ class TestPropertyMappings(TestCase):
|
|||
"""Test expression error (with user and http request"""
|
||||
expr = "return aaa"
|
||||
request = self.factory.get("/")
|
||||
mapping = PropertyMapping.objects.create(name="test", expression=expr)
|
||||
mapping = PropertyMapping.objects.create(name=generate_id(), expression=expr)
|
||||
with self.assertRaises(PropertyMappingExpressionException):
|
||||
mapping.evaluate(get_anonymous_user(), request)
|
||||
events = Event.objects.filter(
|
||||
|
@ -52,3 +56,23 @@ class TestPropertyMappings(TestCase):
|
|||
event = events.first()
|
||||
self.assertEqual(event.user["username"], "AnonymousUser")
|
||||
self.assertEqual(event.client_ip, "127.0.0.1")
|
||||
|
||||
def test_call_policy(self):
|
||||
"""test ak_call_policy"""
|
||||
expr = ExpressionPolicy.objects.create(
|
||||
name=generate_id(),
|
||||
execution_logging=True,
|
||||
expression="return request.http_request.path",
|
||||
)
|
||||
http_request = self.factory.get("/")
|
||||
tmpl = (
|
||||
"""
|
||||
res = ak_call_policy('%s')
|
||||
result = [request.http_request.path, res.raw_result]
|
||||
return result
|
||||
"""
|
||||
% expr.name
|
||||
)
|
||||
evaluator = PropertyMapping(expression=tmpl, name=generate_id())
|
||||
res = evaluator.evaluate(self.user, http_request)
|
||||
self.assertEqual(res, ["/", "/"])
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Iterable, Optional
|
|||
from cachetools import TLRUCache, cached
|
||||
from django.core.exceptions import FieldError
|
||||
from django_otp import devices_for_user
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from rest_framework.serializers import ValidationError
|
||||
from sentry_sdk.hub import Hub
|
||||
from sentry_sdk.tracing import Span
|
||||
|
@ -16,7 +17,9 @@ from structlog.stdlib import get_logger
|
|||
from authentik.core.models import User
|
||||
from authentik.events.models import Event
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.policies.types import PolicyRequest
|
||||
from authentik.policies.models import Policy, PolicyBinding
|
||||
from authentik.policies.process import PolicyProcess
|
||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
@ -37,19 +40,20 @@ class BaseEvaluator:
|
|||
# update website/docs/expressions/_objects.md
|
||||
# update website/docs/expressions/_functions.md
|
||||
self._globals = {
|
||||
"regex_match": BaseEvaluator.expr_regex_match,
|
||||
"regex_replace": BaseEvaluator.expr_regex_replace,
|
||||
"list_flatten": BaseEvaluator.expr_flatten,
|
||||
"ak_call_policy": self.expr_func_call_policy,
|
||||
"ak_create_event": self.expr_event_create,
|
||||
"ak_is_group_member": BaseEvaluator.expr_is_group_member,
|
||||
"ak_logger": get_logger(self._filename).bind(),
|
||||
"ak_user_by": BaseEvaluator.expr_user_by,
|
||||
"ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator,
|
||||
"resolve_dns": BaseEvaluator.expr_resolve_dns,
|
||||
"reverse_dns": BaseEvaluator.expr_reverse_dns,
|
||||
"ak_create_event": self.expr_event_create,
|
||||
"ak_logger": get_logger(self._filename).bind(),
|
||||
"requests": get_http_session(),
|
||||
"ip_address": ip_address,
|
||||
"ip_network": ip_network,
|
||||
"list_flatten": BaseEvaluator.expr_flatten,
|
||||
"regex_match": BaseEvaluator.expr_regex_match,
|
||||
"regex_replace": BaseEvaluator.expr_regex_replace,
|
||||
"requests": get_http_session(),
|
||||
"resolve_dns": BaseEvaluator.expr_resolve_dns,
|
||||
"reverse_dns": BaseEvaluator.expr_reverse_dns,
|
||||
}
|
||||
self._context = {}
|
||||
|
||||
|
@ -152,6 +156,19 @@ class BaseEvaluator:
|
|||
return
|
||||
event.save()
|
||||
|
||||
def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult:
|
||||
"""Call policy by name, with current request"""
|
||||
policy = Policy.objects.filter(name=name).select_subclasses().first()
|
||||
if not policy:
|
||||
raise ValueError(f"Policy '{name}' not found.")
|
||||
user = self._context.get("user", get_anonymous_user())
|
||||
req = PolicyRequest(user)
|
||||
if "request" in self._context:
|
||||
req = self._context["request"]
|
||||
req.context.update(kwargs)
|
||||
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
|
||||
return proc.profiling_wrapper()
|
||||
|
||||
def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
|
||||
"""Wrap expression in a function, call it, and save the result as `result`"""
|
||||
handler_signature = ",".join(params)
|
||||
|
|
|
@ -9,8 +9,6 @@ from authentik.flows.planner import PLAN_CONTEXT_SSO
|
|||
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||
from authentik.lib.utils.http import get_client_ip
|
||||
from authentik.policies.exceptions import PolicyException
|
||||
from authentik.policies.models import Policy, PolicyBinding
|
||||
from authentik.policies.process import PolicyProcess
|
||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
@ -32,22 +30,11 @@ class PolicyEvaluator(BaseEvaluator):
|
|||
# update website/docs/expressions/_functions.md
|
||||
self._context["ak_message"] = self.expr_func_message
|
||||
self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator
|
||||
self._context["ak_call_policy"] = self.expr_func_call_policy
|
||||
|
||||
def expr_func_message(self, message: str):
|
||||
"""Wrapper to append to messages list, which is returned with PolicyResult"""
|
||||
self._messages.append(message)
|
||||
|
||||
def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult:
|
||||
"""Call policy by name, with current request"""
|
||||
policy = Policy.objects.filter(name=name).select_subclasses().first()
|
||||
if not policy:
|
||||
raise ValueError(f"Policy '{name}' not found.")
|
||||
req: PolicyRequest = self._context["request"]
|
||||
req.context.update(kwargs)
|
||||
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
|
||||
return proc.profiling_wrapper()
|
||||
|
||||
def set_policy_request(self, request: PolicyRequest):
|
||||
"""Update context based on policy request (if http request is given, update that too)"""
|
||||
# update website/docs/expressions/_objects.md
|
||||
|
@ -83,6 +70,7 @@ class PolicyEvaluator(BaseEvaluator):
|
|||
return PolicyResult(False, str(exc))
|
||||
else:
|
||||
policy_result = PolicyResult(False, *self._messages)
|
||||
policy_result.raw_result = result
|
||||
if result is None:
|
||||
LOGGER.warning(
|
||||
"Expression policy returned None",
|
||||
|
|
|
@ -69,10 +69,11 @@ class PolicyRequest:
|
|||
|
||||
@dataclass
|
||||
class PolicyResult:
|
||||
"""Small data-class to hold policy results"""
|
||||
"""Result from evaluating a policy."""
|
||||
|
||||
passing: bool
|
||||
messages: tuple[str, ...]
|
||||
raw_result: Any
|
||||
|
||||
source_binding: Optional["PolicyBinding"]
|
||||
source_results: Optional[list["PolicyResult"]]
|
||||
|
@ -83,6 +84,7 @@ class PolicyResult:
|
|||
super().__init__()
|
||||
self.passing = passing
|
||||
self.messages = messages
|
||||
self.raw_result = None
|
||||
self.source_binding = None
|
||||
self.source_results = []
|
||||
self.log_messages = []
|
||||
|
|
|
@ -29,6 +29,29 @@ user = list_flatten(["foo"])
|
|||
# user = "foo"
|
||||
```
|
||||
|
||||
### `ak_call_policy(name: str, **kwargs) -> PolicyResult`
|
||||
|
||||
:::info
|
||||
Requires authentik 2021.12
|
||||
:::
|
||||
|
||||
Call another policy with the name _name_. Current request is passed to policy. Key-word arguments
|
||||
can be used to modify the request's context.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
result = ak_call_policy("test-policy")
|
||||
# result is a PolicyResult object, so you can access `.passing` and `.messages`.
|
||||
# Starting with authentik 2023.4 you can also access `.raw_result`, which is the raw value returned from the called policy
|
||||
# `result.passing` will always be a boolean if the policy is passing or not.
|
||||
return result.passing
|
||||
|
||||
result = ak_call_policy("test-policy-2", foo="bar")
|
||||
# Inside the `test-policy-2` you can then use `request.context["foo"]`
|
||||
return result.passing
|
||||
```
|
||||
|
||||
### `ak_is_group_member(user: User, **group_filters) -> bool`
|
||||
|
||||
Check if `user` is member of a group matching `**group_filters`.
|
||||
|
|
|
@ -29,27 +29,6 @@ ak_message("Access denied")
|
|||
return False
|
||||
```
|
||||
|
||||
### `ak_call_policy(name: str, **kwargs) -> PolicyResult`
|
||||
|
||||
:::info
|
||||
Requires authentik 2021.12
|
||||
:::
|
||||
|
||||
Call another policy with the name _name_. Current request is passed to policy. Key-word arguments
|
||||
can be used to modify the request's context.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
result = ak_call_policy("test-policy")
|
||||
# result is a PolicyResult object, so you can access `.passing` and `.messages`.
|
||||
return result.passing
|
||||
|
||||
result = ak_call_policy("test-policy-2", foo="bar")
|
||||
# Inside the `test-policy-2` you can then use `request.context["foo"]`
|
||||
return result.passing
|
||||
```
|
||||
|
||||
import Functions from "../expressions/_functions.md";
|
||||
|
||||
<Functions />
|
||||
|
|
Reference in a new issue