diff --git a/authentik/admin/tasks.py b/authentik/admin/tasks.py index 2c7e91810..e5672a132 100644 --- a/authentik/admin/tasks.py +++ b/authentik/admin/tasks.py @@ -11,12 +11,7 @@ from structlog.stdlib import get_logger from authentik import ENV_GIT_HASH_KEY, __version__ from authentik.events.models import Event, EventAction, Notification -from authentik.events.monitored_tasks import ( - MonitoredTask, - TaskResult, - TaskResultStatus, - prefill_task, -) +from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus from authentik.lib.config import CONFIG from authentik.lib.utils.http import get_http_session from authentik.root.celery import CELERY_APP @@ -53,9 +48,8 @@ def clear_update_notifications(): notification.delete() -@CELERY_APP.task(bind=True, base=MonitoredTask) -@prefill_task -def update_latest_version(self: MonitoredTask): +@CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) +def update_latest_version(self: PrefilledMonitoredTask): """Update latest version info""" if CONFIG.y_bool("disable_update_check"): cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT) diff --git a/authentik/core/tasks.py b/authentik/core/tasks.py index 3e18c50b4..79e250202 100644 --- a/authentik/core/tasks.py +++ b/authentik/core/tasks.py @@ -16,21 +16,15 @@ from kubernetes.config.incluster_config import SERVICE_HOST_ENV_NAME from structlog.stdlib import get_logger from authentik.core.models import AuthenticatedSession, ExpiringModel -from authentik.events.monitored_tasks import ( - MonitoredTask, - TaskResult, - TaskResultStatus, - prefill_task, -) +from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus from authentik.lib.config import CONFIG from authentik.root.celery import CELERY_APP LOGGER = get_logger() -@CELERY_APP.task(bind=True, base=MonitoredTask) -@prefill_task -def clean_expired_models(self: MonitoredTask): +@CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) +def clean_expired_models(self: PrefilledMonitoredTask): """Remove expired objects""" messages = [] for cls in ExpiringModel.__subclasses__(): @@ -68,9 +62,8 @@ def should_backup() -> bool: return True -@CELERY_APP.task(bind=True, base=MonitoredTask) -@prefill_task -def backup_database(self: MonitoredTask): # pragma: no cover +@CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) +def backup_database(self: PrefilledMonitoredTask): # pragma: no cover """Database backup""" self.result_timeout_hours = 25 if not should_backup(): diff --git a/authentik/events/monitored_tasks.py b/authentik/events/monitored_tasks.py index 9c71ca6f5..138b4070c 100644 --- a/authentik/events/monitored_tasks.py +++ b/authentik/events/monitored_tasks.py @@ -112,30 +112,6 @@ class TaskInfo: cache.set(key, self, timeout=timeout_hours * 60 * 60) -def prefill_task(func): - """Ensure a task's details are always in cache, so it can always be triggered via API""" - - def wrapper(*args, **kwargs): - status = TaskInfo.by_name(func.__name__) - if status: - return func(*args, **kwargs) - TaskInfo( - task_name=func.__name__, - task_description=func.__doc__, - result=TaskResult(TaskResultStatus.UNKNOWN, messages=[_("Task has not been run yet.")]), - task_call_module=func.__module__, - task_call_func=func.__name__, - # We don't have real values for these attributes but they cannot be null - start_timestamp=default_timer(), - finish_timestamp=default_timer(), - finish_time=datetime.now(), - ).save(86400) - LOGGER.debug("prefilled task", task_name=func.__name__) - return func(*args, **kwargs) - - return wrapper - - class MonitoredTask(Task): """Task which can save its state to the cache""" @@ -210,5 +186,31 @@ class MonitoredTask(Task): raise NotImplementedError +class PrefilledMonitoredTask(MonitoredTask): + """Subclass of MonitoredTask, but create entry in cache if task hasn't been run + Does not support UID""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + status = TaskInfo.by_name(self.__name__) + if status: + return + TaskInfo( + task_name=self.__name__, + task_description=self.__doc__, + result=TaskResult(TaskResultStatus.UNKNOWN, messages=[_("Task has not been run yet.")]), + task_call_module=self.__module__, + task_call_func=self.__name__, + # We don't have real values for these attributes but they cannot be null + start_timestamp=default_timer(), + finish_timestamp=default_timer(), + finish_time=datetime.now(), + ).save(86400) + LOGGER.debug("prefilled task", task_name=self.__name__) + + def run(self, *args, **kwargs): + raise NotImplementedError + + for task in TaskInfo.all().values(): task.set_prom_metrics() diff --git a/authentik/managed/tasks.py b/authentik/managed/tasks.py index 2cc9b21d2..118b9c370 100644 --- a/authentik/managed/tasks.py +++ b/authentik/managed/tasks.py @@ -2,18 +2,12 @@ from django.db import DatabaseError from authentik.core.tasks import CELERY_APP -from authentik.events.monitored_tasks import ( - MonitoredTask, - TaskResult, - TaskResultStatus, - prefill_task, -) +from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus from authentik.managed.manager import ObjectManager -@CELERY_APP.task(bind=True, base=MonitoredTask) -@prefill_task -def managed_reconcile(self: MonitoredTask): +@CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) +def managed_reconcile(self: PrefilledMonitoredTask): """Run ObjectManager to ensure objects are up-to-date""" try: ObjectManager().run() diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index 737b1ef15..820f585f6 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -19,9 +19,9 @@ from structlog.stdlib import get_logger from authentik.events.monitored_tasks import ( MonitoredTask, + PrefilledMonitoredTask, TaskResult, TaskResultStatus, - prefill_task, ) from authentik.lib.utils.reflection import path_to_class from authentik.outposts.controllers.base import BaseController, ControllerException @@ -75,9 +75,8 @@ def outpost_service_connection_state(connection_pk: Any): cache.set(connection.state_key, state, timeout=None) -@CELERY_APP.task(bind=True, base=MonitoredTask) -@prefill_task -def outpost_service_connection_monitor(self: MonitoredTask): +@CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) +def outpost_service_connection_monitor(self: PrefilledMonitoredTask): """Regularly check the state of Outpost Service Connections""" connections = OutpostServiceConnection.objects.all() for connection in connections.iterator(): @@ -125,9 +124,8 @@ def outpost_controller( self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, logs)) -@CELERY_APP.task(bind=True, base=MonitoredTask) -@prefill_task -def outpost_token_ensurer(self: MonitoredTask): +@CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) +def outpost_token_ensurer(self: PrefilledMonitoredTask): """Periodically ensure that all Outposts have valid Service Accounts and Tokens""" all_outposts = Outpost.objects.all() diff --git a/authentik/policies/reputation/tasks.py b/authentik/policies/reputation/tasks.py index 3126dfca8..49b1590d1 100644 --- a/authentik/policies/reputation/tasks.py +++ b/authentik/policies/reputation/tasks.py @@ -2,12 +2,7 @@ from django.core.cache import cache from structlog.stdlib import get_logger -from authentik.events.monitored_tasks import ( - MonitoredTask, - TaskResult, - TaskResultStatus, - prefill_task, -) +from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus from authentik.policies.reputation.models import IPReputation, UserReputation from authentik.policies.reputation.signals import CACHE_KEY_IP_PREFIX, CACHE_KEY_USER_PREFIX from authentik.root.celery import CELERY_APP @@ -15,9 +10,8 @@ from authentik.root.celery import CELERY_APP LOGGER = get_logger() -@CELERY_APP.task(bind=True, base=MonitoredTask) -@prefill_task -def save_ip_reputation(self: MonitoredTask): +@CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) +def save_ip_reputation(self: PrefilledMonitoredTask): """Save currently cached reputation to database""" objects_to_update = [] for key, score in cache.get_many(cache.keys(CACHE_KEY_IP_PREFIX + "*")).items(): @@ -29,9 +23,8 @@ def save_ip_reputation(self: MonitoredTask): self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated IP Reputation"])) -@CELERY_APP.task(bind=True, base=MonitoredTask) -@prefill_task -def save_user_reputation(self: MonitoredTask): +@CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) +def save_user_reputation(self: PrefilledMonitoredTask): """Save currently cached reputation to database""" objects_to_update = [] for key, score in cache.get_many(cache.keys(CACHE_KEY_USER_PREFIX + "*")).items(): diff --git a/authentik/sources/saml/tasks.py b/authentik/sources/saml/tasks.py index cb72d55d0..fb85b3088 100644 --- a/authentik/sources/saml/tasks.py +++ b/authentik/sources/saml/tasks.py @@ -3,12 +3,7 @@ from django.utils.timezone import now from structlog.stdlib import get_logger from authentik.core.models import AuthenticatedSession, User -from authentik.events.monitored_tasks import ( - MonitoredTask, - TaskResult, - TaskResultStatus, - prefill_task, -) +from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus from authentik.lib.utils.time import timedelta_from_string from authentik.root.celery import CELERY_APP from authentik.sources.saml.models import SAMLSource @@ -16,9 +11,8 @@ from authentik.sources.saml.models import SAMLSource LOGGER = get_logger() -@CELERY_APP.task(bind=True, base=MonitoredTask) -@prefill_task -def clean_temporary_users(self: MonitoredTask): +@CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) +def clean_temporary_users(self: PrefilledMonitoredTask): """Remove temporary users created by SAML Sources""" _now = now() messages = []