fix tests
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
4a6f956c28
commit
c3e9d23190
|
@ -7,8 +7,6 @@ from django.urls import reverse
|
|||
from authentik import __version__
|
||||
from authentik.blueprints.tests import reconcile_app
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.tasks import clean_expired_models
|
||||
from authentik.events.monitored_tasks import TaskStatus
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
|
@ -23,53 +21,6 @@ class TestAdminAPI(TestCase):
|
|||
self.group.save()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_tasks(self):
|
||||
"""Test Task API"""
|
||||
clean_expired_models.delay()
|
||||
response = self.client.get(reverse("authentik_api:admin_system_tasks-list"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body))
|
||||
|
||||
def test_tasks_single(self):
|
||||
"""Test Task API (read single)"""
|
||||
clean_expired_models.delay()
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:admin_system_tasks-detail",
|
||||
kwargs={"pk": "clean_expired_models"},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.name)
|
||||
self.assertEqual(body["task_name"], "clean_expired_models")
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"})
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_tasks_retry(self):
|
||||
"""Test Task API (retry)"""
|
||||
clean_expired_models.delay()
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:admin_system_tasks-retry",
|
||||
kwargs={"pk": "clean_expired_models"},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
|
||||
def test_tasks_retry_404(self):
|
||||
"""Test Task API (retry, 404)"""
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:admin_system_tasks-retry",
|
||||
kwargs={"pk": "qwerqewrqrqewrqewr"},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_version(self):
|
||||
"""Test Version API"""
|
||||
response = self.client.get(reverse("authentik_api:admin_version"))
|
||||
|
|
|
@ -1,15 +1,25 @@
|
|||
"""Test Monitored tasks"""
|
||||
from django.test import TestCase
|
||||
from json import loads
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tasks import clean_expired_models
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.events.models import SystemTask, TaskStatus
|
||||
from authentik.events.monitored_tasks import MonitoredTask
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
|
||||
class TestMonitoredTasks(TestCase):
|
||||
class TestSystemTasks(APITestCase):
|
||||
"""Test Monitored tasks"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.user = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_failed_successful_remove_state(self):
|
||||
"""Test that a task with `save_on_success` set to `False` that failed saves
|
||||
a state, and upon successful completion will delete the state"""
|
||||
|
@ -28,15 +38,64 @@ class TestMonitoredTasks(TestCase):
|
|||
# First test successful run
|
||||
should_fail = False
|
||||
test_task.delay().get()
|
||||
self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid))
|
||||
self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid).first())
|
||||
|
||||
# Then test failed
|
||||
should_fail = True
|
||||
test_task.delay().get()
|
||||
info = SystemTask.objects.filter(name="test_task", uid=uid)
|
||||
self.assertEqual(info.status, TaskStatus.ERROR)
|
||||
task = SystemTask.objects.filter(name="test_task", uid=uid).first()
|
||||
self.assertEqual(task.status, TaskStatus.ERROR)
|
||||
|
||||
# Then after that, the state should be removed
|
||||
should_fail = False
|
||||
test_task.delay().get()
|
||||
self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid))
|
||||
self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid).first())
|
||||
|
||||
def test_tasks(self):
|
||||
"""Test Task API"""
|
||||
clean_expired_models.delay().get()
|
||||
response = self.client.get(reverse("authentik_api:systemtask-list"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"]))
|
||||
|
||||
def test_tasks_single(self):
|
||||
"""Test Task API (read single)"""
|
||||
clean_expired_models.delay().get()
|
||||
task = SystemTask.objects.filter(name="clean_expired_models").first()
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:systemtask-detail",
|
||||
kwargs={"pk": str(task.pk)},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value)
|
||||
self.assertEqual(body["name"], "clean_expired_models")
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"})
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_tasks_run(self):
|
||||
"""Test Task API (run)"""
|
||||
clean_expired_models.delay().get()
|
||||
task = SystemTask.objects.filter(name="clean_expired_models").first()
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:systemtask-run",
|
||||
kwargs={"pk": str(task.pk)},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
|
||||
def test_tasks_run_404(self):
|
||||
"""Test Task API (run, 404)"""
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:systemtask-run",
|
||||
kwargs={"pk": "qwerqewrqrqewrqewr"},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
|
|
@ -34,7 +34,7 @@ CACHE_KEY_STATUS = "goauthentik.io/sources/ldap/status/"
|
|||
def ldap_sync_all():
|
||||
"""Sync all sources"""
|
||||
for source in LDAPSource.objects.filter(enabled=True):
|
||||
ldap_sync_single.apply_async(args=[source.pk])
|
||||
ldap_sync_single.apply_async(args=[str(source.pk)])
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
|
@ -95,7 +95,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
|
|||
for page in sync_inst.get_objects():
|
||||
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
|
||||
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
|
||||
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
|
||||
page_sync = ldap_sync.si(str(source.pk), class_to_path(sync), page_cache_key)
|
||||
signatures.append(page_sync)
|
||||
return signatures
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ class LDAPSyncTests(TestCase):
|
|||
"""Test sync with missing page"""
|
||||
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
|
||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||
ldap_sync.delay(self.source.pk, class_to_path(UserLDAPSynchronizer), "foo").get()
|
||||
ldap_sync.delay(str(self.source.pk), class_to_path(UserLDAPSynchronizer), "foo").get()
|
||||
task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first()
|
||||
self.assertEqual(task.status, TaskStatus.ERROR)
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ def send_mails(stage: EmailStage, *messages: list[EmailMultiAlternatives]):
|
|||
"""Wrapper to convert EmailMessage to dict and send it from worker"""
|
||||
tasks = []
|
||||
for message in messages:
|
||||
tasks.append(send_mail.s(message.__dict__, stage.pk))
|
||||
tasks.append(send_mail.s(message.__dict__, str(stage.pk)))
|
||||
lazy_group = group(*tasks)
|
||||
promise = lazy_group()
|
||||
return promise
|
||||
|
@ -46,7 +46,7 @@ def get_email_body(email: EmailMultiAlternatives) -> str:
|
|||
retry_backoff=True,
|
||||
base=MonitoredTask,
|
||||
)
|
||||
def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[int] = None):
|
||||
def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[str] = None):
|
||||
"""Send Email for Email Stage. Retries are scheduled automatically."""
|
||||
self.save_on_success = False
|
||||
message_id = make_msgid(domain=DNS_NAME)
|
||||
|
|
Reference in a new issue