policies/event_matcher: simplify validity checking

This commit is contained in:
Jens Langhammer 2021-01-15 11:26:55 +01:00
parent f297d1256d
commit 2e42da11ea
3 changed files with 37 additions and 36 deletions

View file

@ -74,12 +74,12 @@ class EventMatcherPolicy(Policy):
if "event" not in request.context: if "event" not in request.context:
return PolicyResult(False) return PolicyResult(False)
event: Event = request.context["event"] event: Event = request.context["event"]
if event.action != self.action: if event.action == self.action:
return PolicyResult(True, "Action matchede.") return PolicyResult(True, "Action matched.")
if event.client_ip != self.client_ip: if event.client_ip == self.client_ip:
return PolicyResult(True, "Client IP matchede.") return PolicyResult(True, "Client IP matched.")
if event.app != self.app: if event.app == self.app:
return PolicyResult(True, "App matchede.") return PolicyResult(True, "App matched.")
return PolicyResult(False) return PolicyResult(False)
class Meta: class Meta:

View file

@ -10,19 +10,43 @@ from authentik.policies.types import PolicyRequest
class TestEventMatcherPolicy(TestCase): class TestEventMatcherPolicy(TestCase):
"""EventMatcherPolicy tests""" """EventMatcherPolicy tests"""
def test_drop_action(self): def test_match_action(self):
"""Test drop event""" """Test match action"""
event = Event.new(EventAction.LOGIN) event = Event.new(EventAction.LOGIN)
request = PolicyRequest(get_anonymous_user()) request = PolicyRequest(get_anonymous_user())
request.context["event"] = event request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
action=EventAction.LOGIN_FAILED action=EventAction.LOGIN
) )
response = policy.passes(request) response = policy.passes(request)
self.assertFalse(response.passing) self.assertTrue(response.passing)
self.assertTupleEqual(response.messages, ("Action did not match.",)) self.assertTupleEqual(response.messages, ("Action matched.",))
def test_drop_client_ip(self): def test_match_client_ip(self):
"""Test match client_ip"""
event = Event.new(EventAction.LOGIN)
event.client_ip = "1.2.3.4"
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.4"
)
response = policy.passes(request)
self.assertTrue(response.passing)
self.assertTupleEqual(response.messages, ("Client IP matched.",))
def test_match_app(self):
"""Test match app"""
event = Event.new(EventAction.LOGIN)
event.app = "foo"
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(app="foo")
response = policy.passes(request)
self.assertTrue(response.passing)
self.assertTupleEqual(response.messages, ("App matched.",))
def test_drop(self):
"""Test drop event""" """Test drop event"""
event = Event.new(EventAction.LOGIN) event = Event.new(EventAction.LOGIN)
event.client_ip = "1.2.3.4" event.client_ip = "1.2.3.4"
@ -33,30 +57,6 @@ class TestEventMatcherPolicy(TestCase):
) )
response = policy.passes(request) response = policy.passes(request)
self.assertFalse(response.passing) self.assertFalse(response.passing)
self.assertTupleEqual(response.messages, ("Client IP did not match.",))
def test_drop_app(self):
"""Test drop event"""
event = Event.new(EventAction.LOGIN)
event.app = "foo"
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(app="bar")
response = policy.passes(request)
self.assertFalse(response.passing)
self.assertTupleEqual(response.messages, ("App did not match.",))
def test_passing(self):
"""Test passing event"""
event = Event.new(EventAction.LOGIN)
event.client_ip = "1.2.3.4"
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.4"
)
response = policy.passes(request)
self.assertTrue(response.passing)
def test_invalid(self): def test_invalid(self):
"""Test passing event""" """Test passing event"""

View file

@ -7623,6 +7623,7 @@ definitions:
created: created:
title: Created title: Created
type: string type: string
format: date-time
readOnly: true readOnly: true
event: event:
$ref: '#/definitions/Event' $ref: '#/definitions/Event'