policies: provider raw result for better policy reusability (#5189)
* policies: include raw_result in PolicyResult Signed-off-by: Jens Langhammer <jens@goauthentik.io> * move ak_call_policy to base evaluator Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
c117d98e27
commit
977757f561
|
@ -4,7 +4,10 @@ from guardian.shortcuts import get_anonymous_user
|
||||||
|
|
||||||
from authentik.core.exceptions import PropertyMappingExpressionException
|
from authentik.core.exceptions import PropertyMappingExpressionException
|
||||||
from authentik.core.models import PropertyMapping
|
from authentik.core.models import PropertyMapping
|
||||||
|
from authentik.core.tests.utils import create_test_admin_user
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
|
from authentik.lib.generators import generate_id
|
||||||
|
from authentik.policies.expression.models import ExpressionPolicy
|
||||||
|
|
||||||
|
|
||||||
class TestPropertyMappings(TestCase):
|
class TestPropertyMappings(TestCase):
|
||||||
|
@ -12,23 +15,24 @@ class TestPropertyMappings(TestCase):
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
self.user = create_test_admin_user()
|
||||||
self.factory = RequestFactory()
|
self.factory = RequestFactory()
|
||||||
|
|
||||||
def test_expression(self):
|
def test_expression(self):
|
||||||
"""Test expression"""
|
"""Test expression"""
|
||||||
mapping = PropertyMapping.objects.create(name="test", expression="return 'test'")
|
mapping = PropertyMapping.objects.create(name=generate_id(), expression="return 'test'")
|
||||||
self.assertEqual(mapping.evaluate(None, None), "test")
|
self.assertEqual(mapping.evaluate(None, None), "test")
|
||||||
|
|
||||||
def test_expression_syntax(self):
|
def test_expression_syntax(self):
|
||||||
"""Test expression syntax error"""
|
"""Test expression syntax error"""
|
||||||
mapping = PropertyMapping.objects.create(name="test", expression="-")
|
mapping = PropertyMapping.objects.create(name=generate_id(), expression="-")
|
||||||
with self.assertRaises(PropertyMappingExpressionException):
|
with self.assertRaises(PropertyMappingExpressionException):
|
||||||
mapping.evaluate(None, None)
|
mapping.evaluate(None, None)
|
||||||
|
|
||||||
def test_expression_error_general(self):
|
def test_expression_error_general(self):
|
||||||
"""Test expression error"""
|
"""Test expression error"""
|
||||||
expr = "return aaa"
|
expr = "return aaa"
|
||||||
mapping = PropertyMapping.objects.create(name="test", expression=expr)
|
mapping = PropertyMapping.objects.create(name=generate_id(), expression=expr)
|
||||||
with self.assertRaises(PropertyMappingExpressionException):
|
with self.assertRaises(PropertyMappingExpressionException):
|
||||||
mapping.evaluate(None, None)
|
mapping.evaluate(None, None)
|
||||||
events = Event.objects.filter(
|
events = Event.objects.filter(
|
||||||
|
@ -41,7 +45,7 @@ class TestPropertyMappings(TestCase):
|
||||||
"""Test expression error (with user and http request"""
|
"""Test expression error (with user and http request"""
|
||||||
expr = "return aaa"
|
expr = "return aaa"
|
||||||
request = self.factory.get("/")
|
request = self.factory.get("/")
|
||||||
mapping = PropertyMapping.objects.create(name="test", expression=expr)
|
mapping = PropertyMapping.objects.create(name=generate_id(), expression=expr)
|
||||||
with self.assertRaises(PropertyMappingExpressionException):
|
with self.assertRaises(PropertyMappingExpressionException):
|
||||||
mapping.evaluate(get_anonymous_user(), request)
|
mapping.evaluate(get_anonymous_user(), request)
|
||||||
events = Event.objects.filter(
|
events = Event.objects.filter(
|
||||||
|
@ -52,3 +56,23 @@ class TestPropertyMappings(TestCase):
|
||||||
event = events.first()
|
event = events.first()
|
||||||
self.assertEqual(event.user["username"], "AnonymousUser")
|
self.assertEqual(event.user["username"], "AnonymousUser")
|
||||||
self.assertEqual(event.client_ip, "127.0.0.1")
|
self.assertEqual(event.client_ip, "127.0.0.1")
|
||||||
|
|
||||||
|
def test_call_policy(self):
|
||||||
|
"""test ak_call_policy"""
|
||||||
|
expr = ExpressionPolicy.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
execution_logging=True,
|
||||||
|
expression="return request.http_request.path",
|
||||||
|
)
|
||||||
|
http_request = self.factory.get("/")
|
||||||
|
tmpl = (
|
||||||
|
"""
|
||||||
|
res = ak_call_policy('%s')
|
||||||
|
result = [request.http_request.path, res.raw_result]
|
||||||
|
return result
|
||||||
|
"""
|
||||||
|
% expr.name
|
||||||
|
)
|
||||||
|
evaluator = PropertyMapping(expression=tmpl, name=generate_id())
|
||||||
|
res = evaluator.evaluate(self.user, http_request)
|
||||||
|
self.assertEqual(res, ["/", "/"])
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Iterable, Optional
|
||||||
from cachetools import TLRUCache, cached
|
from cachetools import TLRUCache, cached
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django_otp import devices_for_user
|
from django_otp import devices_for_user
|
||||||
|
from guardian.shortcuts import get_anonymous_user
|
||||||
from rest_framework.serializers import ValidationError
|
from rest_framework.serializers import ValidationError
|
||||||
from sentry_sdk.hub import Hub
|
from sentry_sdk.hub import Hub
|
||||||
from sentry_sdk.tracing import Span
|
from sentry_sdk.tracing import Span
|
||||||
|
@ -16,7 +17,9 @@ from structlog.stdlib import get_logger
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.events.models import Event
|
from authentik.events.models import Event
|
||||||
from authentik.lib.utils.http import get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
from authentik.policies.types import PolicyRequest
|
from authentik.policies.models import Policy, PolicyBinding
|
||||||
|
from authentik.policies.process import PolicyProcess
|
||||||
|
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
@ -37,19 +40,20 @@ class BaseEvaluator:
|
||||||
# update website/docs/expressions/_objects.md
|
# update website/docs/expressions/_objects.md
|
||||||
# update website/docs/expressions/_functions.md
|
# update website/docs/expressions/_functions.md
|
||||||
self._globals = {
|
self._globals = {
|
||||||
"regex_match": BaseEvaluator.expr_regex_match,
|
"ak_call_policy": self.expr_func_call_policy,
|
||||||
"regex_replace": BaseEvaluator.expr_regex_replace,
|
"ak_create_event": self.expr_event_create,
|
||||||
"list_flatten": BaseEvaluator.expr_flatten,
|
|
||||||
"ak_is_group_member": BaseEvaluator.expr_is_group_member,
|
"ak_is_group_member": BaseEvaluator.expr_is_group_member,
|
||||||
|
"ak_logger": get_logger(self._filename).bind(),
|
||||||
"ak_user_by": BaseEvaluator.expr_user_by,
|
"ak_user_by": BaseEvaluator.expr_user_by,
|
||||||
"ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator,
|
"ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator,
|
||||||
"resolve_dns": BaseEvaluator.expr_resolve_dns,
|
|
||||||
"reverse_dns": BaseEvaluator.expr_reverse_dns,
|
|
||||||
"ak_create_event": self.expr_event_create,
|
|
||||||
"ak_logger": get_logger(self._filename).bind(),
|
|
||||||
"requests": get_http_session(),
|
|
||||||
"ip_address": ip_address,
|
"ip_address": ip_address,
|
||||||
"ip_network": ip_network,
|
"ip_network": ip_network,
|
||||||
|
"list_flatten": BaseEvaluator.expr_flatten,
|
||||||
|
"regex_match": BaseEvaluator.expr_regex_match,
|
||||||
|
"regex_replace": BaseEvaluator.expr_regex_replace,
|
||||||
|
"requests": get_http_session(),
|
||||||
|
"resolve_dns": BaseEvaluator.expr_resolve_dns,
|
||||||
|
"reverse_dns": BaseEvaluator.expr_reverse_dns,
|
||||||
}
|
}
|
||||||
self._context = {}
|
self._context = {}
|
||||||
|
|
||||||
|
@ -152,6 +156,19 @@ class BaseEvaluator:
|
||||||
return
|
return
|
||||||
event.save()
|
event.save()
|
||||||
|
|
||||||
|
def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult:
|
||||||
|
"""Call policy by name, with current request"""
|
||||||
|
policy = Policy.objects.filter(name=name).select_subclasses().first()
|
||||||
|
if not policy:
|
||||||
|
raise ValueError(f"Policy '{name}' not found.")
|
||||||
|
user = self._context.get("user", get_anonymous_user())
|
||||||
|
req = PolicyRequest(user)
|
||||||
|
if "request" in self._context:
|
||||||
|
req = self._context["request"]
|
||||||
|
req.context.update(kwargs)
|
||||||
|
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
|
||||||
|
return proc.profiling_wrapper()
|
||||||
|
|
||||||
def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
|
def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
|
||||||
"""Wrap expression in a function, call it, and save the result as `result`"""
|
"""Wrap expression in a function, call it, and save the result as `result`"""
|
||||||
handler_signature = ",".join(params)
|
handler_signature = ",".join(params)
|
||||||
|
|
|
@ -9,8 +9,6 @@ from authentik.flows.planner import PLAN_CONTEXT_SSO
|
||||||
from authentik.lib.expression.evaluator import BaseEvaluator
|
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||||
from authentik.lib.utils.http import get_client_ip
|
from authentik.lib.utils.http import get_client_ip
|
||||||
from authentik.policies.exceptions import PolicyException
|
from authentik.policies.exceptions import PolicyException
|
||||||
from authentik.policies.models import Policy, PolicyBinding
|
|
||||||
from authentik.policies.process import PolicyProcess
|
|
||||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
@ -32,22 +30,11 @@ class PolicyEvaluator(BaseEvaluator):
|
||||||
# update website/docs/expressions/_functions.md
|
# update website/docs/expressions/_functions.md
|
||||||
self._context["ak_message"] = self.expr_func_message
|
self._context["ak_message"] = self.expr_func_message
|
||||||
self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator
|
self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator
|
||||||
self._context["ak_call_policy"] = self.expr_func_call_policy
|
|
||||||
|
|
||||||
def expr_func_message(self, message: str):
|
def expr_func_message(self, message: str):
|
||||||
"""Wrapper to append to messages list, which is returned with PolicyResult"""
|
"""Wrapper to append to messages list, which is returned with PolicyResult"""
|
||||||
self._messages.append(message)
|
self._messages.append(message)
|
||||||
|
|
||||||
def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult:
|
|
||||||
"""Call policy by name, with current request"""
|
|
||||||
policy = Policy.objects.filter(name=name).select_subclasses().first()
|
|
||||||
if not policy:
|
|
||||||
raise ValueError(f"Policy '{name}' not found.")
|
|
||||||
req: PolicyRequest = self._context["request"]
|
|
||||||
req.context.update(kwargs)
|
|
||||||
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
|
|
||||||
return proc.profiling_wrapper()
|
|
||||||
|
|
||||||
def set_policy_request(self, request: PolicyRequest):
|
def set_policy_request(self, request: PolicyRequest):
|
||||||
"""Update context based on policy request (if http request is given, update that too)"""
|
"""Update context based on policy request (if http request is given, update that too)"""
|
||||||
# update website/docs/expressions/_objects.md
|
# update website/docs/expressions/_objects.md
|
||||||
|
@ -83,6 +70,7 @@ class PolicyEvaluator(BaseEvaluator):
|
||||||
return PolicyResult(False, str(exc))
|
return PolicyResult(False, str(exc))
|
||||||
else:
|
else:
|
||||||
policy_result = PolicyResult(False, *self._messages)
|
policy_result = PolicyResult(False, *self._messages)
|
||||||
|
policy_result.raw_result = result
|
||||||
if result is None:
|
if result is None:
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
"Expression policy returned None",
|
"Expression policy returned None",
|
||||||
|
|
|
@ -69,10 +69,11 @@ class PolicyRequest:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PolicyResult:
|
class PolicyResult:
|
||||||
"""Small data-class to hold policy results"""
|
"""Result from evaluating a policy."""
|
||||||
|
|
||||||
passing: bool
|
passing: bool
|
||||||
messages: tuple[str, ...]
|
messages: tuple[str, ...]
|
||||||
|
raw_result: Any
|
||||||
|
|
||||||
source_binding: Optional["PolicyBinding"]
|
source_binding: Optional["PolicyBinding"]
|
||||||
source_results: Optional[list["PolicyResult"]]
|
source_results: Optional[list["PolicyResult"]]
|
||||||
|
@ -83,6 +84,7 @@ class PolicyResult:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.passing = passing
|
self.passing = passing
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
|
self.raw_result = None
|
||||||
self.source_binding = None
|
self.source_binding = None
|
||||||
self.source_results = []
|
self.source_results = []
|
||||||
self.log_messages = []
|
self.log_messages = []
|
||||||
|
|
|
@ -29,6 +29,29 @@ user = list_flatten(["foo"])
|
||||||
# user = "foo"
|
# user = "foo"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### `ak_call_policy(name: str, **kwargs) -> PolicyResult`
|
||||||
|
|
||||||
|
:::info
|
||||||
|
Requires authentik 2021.12
|
||||||
|
:::
|
||||||
|
|
||||||
|
Call another policy with the name _name_. Current request is passed to policy. Key-word arguments
|
||||||
|
can be used to modify the request's context.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = ak_call_policy("test-policy")
|
||||||
|
# result is a PolicyResult object, so you can access `.passing` and `.messages`.
|
||||||
|
# Starting with authentik 2023.4 you can also access `.raw_result`, which is the raw value returned from the called policy
|
||||||
|
# `result.passing` will always be a boolean if the policy is passing or not.
|
||||||
|
return result.passing
|
||||||
|
|
||||||
|
result = ak_call_policy("test-policy-2", foo="bar")
|
||||||
|
# Inside the `test-policy-2` you can then use `request.context["foo"]`
|
||||||
|
return result.passing
|
||||||
|
```
|
||||||
|
|
||||||
### `ak_is_group_member(user: User, **group_filters) -> bool`
|
### `ak_is_group_member(user: User, **group_filters) -> bool`
|
||||||
|
|
||||||
Check if `user` is member of a group matching `**group_filters`.
|
Check if `user` is member of a group matching `**group_filters`.
|
||||||
|
|
|
@ -29,27 +29,6 @@ ak_message("Access denied")
|
||||||
return False
|
return False
|
||||||
```
|
```
|
||||||
|
|
||||||
### `ak_call_policy(name: str, **kwargs) -> PolicyResult`
|
|
||||||
|
|
||||||
:::info
|
|
||||||
Requires authentik 2021.12
|
|
||||||
:::
|
|
||||||
|
|
||||||
Call another policy with the name _name_. Current request is passed to policy. Key-word arguments
|
|
||||||
can be used to modify the request's context.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
result = ak_call_policy("test-policy")
|
|
||||||
# result is a PolicyResult object, so you can access `.passing` and `.messages`.
|
|
||||||
return result.passing
|
|
||||||
|
|
||||||
result = ak_call_policy("test-policy-2", foo="bar")
|
|
||||||
# Inside the `test-policy-2` you can then use `request.context["foo"]`
|
|
||||||
return result.passing
|
|
||||||
```
|
|
||||||
|
|
||||||
import Functions from "../expressions/_functions.md";
|
import Functions from "../expressions/_functions.md";
|
||||||
|
|
||||||
<Functions />
|
<Functions />
|
||||||
|
|
Reference in New Issue