a4fd37e429
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
215 lines
7 KiB
Python
215 lines
7 KiB
Python
"""Monitored tasks"""
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from timeit import default_timer
|
|
from typing import Any, Optional
|
|
|
|
from django.core.cache import cache
|
|
from django.utils.translation import gettext_lazy as _
|
|
from structlog.stdlib import get_logger
|
|
from tenant_schemas_celery.task import TenantTask
|
|
|
|
from authentik.events.apps import GAUGE_TASKS
|
|
from authentik.events.models import Event, EventAction
|
|
from authentik.lib.utils.errors import exception_to_string
|
|
|
|
# Module-level structured logger
LOGGER = get_logger()
# Every TaskInfo cache key starts with this prefix; wildcard lookups append a pattern to it
CACHE_KEY_PREFIX = "goauthentik.io/events/tasks/"
|
|
|
|
|
|
class TaskResultStatus(Enum):
    """States a task run can end up in.

    The lower-cased member name is what gets exported as the ``status`` metric label.
    """

    # Numeric values are kept exactly as they were; do not renumber.
    SUCCESSFUL = 1
    WARNING = 2
    ERROR = 4
    # Used as placeholder status for tasks that have not been run yet
    UNKNOWN = 8
|
|
|
|
|
|
@dataclass
class TaskResult:
    """Outcome of a single task run.

    Instances are built by the task itself and handed over through ``self.set_status``.
    """

    # Final state of the run
    status: TaskResultStatus

    # Human-readable messages collected during the run
    messages: list[str] = field(default_factory=list)

    # Optional UID used in cache for tasks that run in different instances
    uid: Optional[str] = None

    def with_error(self, exc: Exception) -> "TaskResult":
        """Attach the string form of `exc` to the messages and return this result.

        The exception is stored as text because exception objects might not
        always be pickle-able.
        """
        # TODO: Mark exception somehow so that is rendered as <pre> in frontend
        rendered = exception_to_string(exc)
        self.messages.append(rendered)
        return self
|
|
|
|
|
|
@dataclass
class TaskInfo:
    """Info about a task run.

    Instances are stored in the cache under ``CACHE_KEY_PREFIX + task_name``
    (plus ``:<uid>`` when the result carries a UID), see `full_name`.
    """

    # Name of the task; `full_name` may append ":<uid>" to it
    task_name: str
    # Timer readings used to compute the duration in `update_metrics`
    start_timestamp: float
    finish_timestamp: float
    # Wall-clock time at which the run finished
    finish_time: datetime

    result: TaskResult

    # Where the task lives and how it was called
    task_call_module: str
    task_call_func: str
    task_call_args: list[Any] = field(default_factory=list)
    task_call_kwargs: dict[str, Any] = field(default_factory=dict)

    task_description: Optional[str] = field(default=None)

    @staticmethod
    def all() -> dict[str, "TaskInfo"]:
        """Get all TaskInfo objects, keyed by their cache key"""
        # NOTE(review): `cache.keys()` is not part of Django's core cache API; this relies on
        # the configured backend providing pattern lookups (presumably django-redis).
        return cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*"))

    @staticmethod
    def by_name(name: str) -> Optional["TaskInfo"] | Optional[list["TaskInfo"]]:
        """Get TaskInfo Object by name

        If `name` contains a ``*`` it is treated as a key pattern and a list of all
        matching TaskInfo objects is returned (possibly empty). Otherwise the single
        matching object is returned, or None if nothing is cached under that name.
        """
        if "*" in name:
            matches = cache.get_many(cache.keys(CACHE_KEY_PREFIX + name))
            # Materialize the view so callers really get the annotated list
            # (a dict view can neither be indexed nor pickled).
            return list(matches.values())
        return cache.get(CACHE_KEY_PREFIX + name, None)

    @property
    def full_name(self) -> str:
        """Get the full cache key with task name and UID

        Side effect: when the result has a UID, ``:<uid>`` is also appended to
        ``self.task_name`` unless the name already ends with it.
        """
        key = CACHE_KEY_PREFIX + self.task_name
        if self.result.uid:
            uid_suffix = f":{self.result.uid}"
            # The key is built from the task name as it was before the suffix is
            # persisted on the instance below.
            key += uid_suffix
            if not self.task_name.endswith(uid_suffix):
                self.task_name += uid_suffix
        return key

    def delete(self):
        """Delete task info from cache"""
        return cache.delete(self.full_name)

    def update_metrics(self):
        """Update prometheus metrics

        Sets the task gauge to the run duration (never negative), labelled with the
        task name without UID suffix, the UID (or "") and the lower-cased status name.
        """
        # Fall back to "now" for instances that carry no start timestamp
        start = default_timer()
        if hasattr(self, "start_timestamp"):
            start = self.start_timestamp
        try:
            duration = max(self.finish_timestamp - start, 0)
        except TypeError:
            # Timestamps that cannot be subtracted (e.g. None) count as zero duration
            duration = 0
        GAUGE_TASKS.labels(
            task_name=self.task_name.split(":")[0],
            task_uid=self.result.uid or "",
            status=self.result.status.name.lower(),
        ).set(duration)

    def save(self, timeout_hours=6):
        """Save task into cache

        Metrics are updated first; the cache entry expires after `timeout_hours` hours.
        """
        self.update_metrics()
        cache.set(self.full_name, self, timeout=timeout_hours * 60 * 60)
|
|
|
|
|
|
class MonitoredTask(TenantTask):
    """Task which can save its state to the cache

    The task body reports its outcome through `set_status`; the `after_return` and
    `on_failure` handlers then turn that into a `TaskInfo` and store it.
    """

    # For tasks that should only be listed if they failed, set this to False
    save_on_success: bool

    _result: Optional[TaskResult]

    _uid: Optional[str]
    # Timer reading taken in `before_start`; None until the task has been started once
    start: Optional[float] = None

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.save_on_success = True
        self._uid = None
        self._result = None
        # Lifetime of the saved TaskInfo, passed to `TaskInfo.save`
        self.result_timeout_hours = 6

    def set_uid(self, uid: str):
        """Set UID, so in the case of an unexpected error its saved correctly"""
        self._uid = uid

    def set_status(self, result: TaskResult):
        """Set result for current run, will overwrite previous result."""
        self._result = result

    def before_start(self, task_id, args, kwargs):
        """Record the start time of the run, then defer to the parent handler."""
        # NOTE(review): `_result` and `_uid` are not reset here; if the same task instance
        # is reused for several runs, a result from an earlier run may be saved again.
        # Confirm whether that is intended.
        self.start = default_timer()
        return super().before_start(task_id, args, kwargs)

    def _build_info(self, args: list[Any], kwargs: dict[str, Any]) -> TaskInfo:
        """Build the TaskInfo for the current run from `self._result`.

        Requires `self._result` to be set. A result without UID inherits the UID
        set through `set_uid`.
        """
        if not self._result.uid:
            self._result.uid = self._uid
        return TaskInfo(
            task_name=self.__name__,
            task_description=self.__doc__,
            # Without a recorded start the duration is effectively zero
            start_timestamp=self.start or default_timer(),
            finish_timestamp=default_timer(),
            finish_time=datetime.now(),
            result=self._result,
            task_call_module=self.__module__,
            task_call_func=self.__name__,
            task_call_args=args,
            task_call_kwargs=kwargs,
        )

    # pylint: disable=too-many-arguments
    def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
        """Save the result of the run, if one was set.

        Nothing is saved when no result was set. A successful result of a task with
        `save_on_success` disabled removes the cached info instead of saving it.
        """
        super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
        if not self._result:
            return
        info = self._build_info(args, kwargs)
        if self._result.status == TaskResultStatus.SUCCESSFUL and not self.save_on_success:
            info.delete()
            return
        info.save(self.result_timeout_hours)

    # pylint: disable=too-many-arguments
    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """Save an error result and create a system task exception event.

        If the task did not set a result itself, an ERROR result containing
        ``str(exc)`` is created.
        """
        super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
        if not self._result:
            self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)])
        self._build_info(args, kwargs).save(self.result_timeout_hours)
        Event.new(
            EventAction.SYSTEM_TASK_EXCEPTION,
            message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
        ).save()

    def run(self, *args, **kwargs):
        """Task body; has to be provided by the concrete task."""
        raise NotImplementedError
|
|
|
|
|
|
def prefill_task(func):
    """Ensure a task's details are always in cache, so it can always be triggered via API

    If nothing is cached under the function's name yet, a placeholder TaskInfo with
    status UNKNOWN is saved. `func` itself is always returned unchanged.
    """
    name = func.__name__
    if TaskInfo.by_name(name):
        # Already known, keep whatever is cached
        return func

    not_run_yet = TaskResult(TaskResultStatus.UNKNOWN, messages=[_("Task has not been run yet.")])
    placeholder = TaskInfo(
        task_name=name,
        task_description=func.__doc__,
        result=not_run_yet,
        task_call_module=func.__module__,
        task_call_func=name,
        # We don't have real values for these attributes but they cannot be null
        start_timestamp=0,
        finish_timestamp=0,
        finish_time=datetime.now(),
    )
    # NOTE(review): TaskInfo.save() interprets its argument as hours, so this is
    # 86400 hours rather than 86400 seconds - confirm this is intended.
    placeholder.save(86400)
    LOGGER.debug("prefilled task", task_name=name)
    return func
|