"""Monitored tasks"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from timeit import default_timer
from typing import Any, Optional

from django.core.cache import cache
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import get_logger
from tenant_schemas_celery.task import TenantTask

from authentik.events.apps import GAUGE_TASKS
from authentik.events.models import Event, EventAction
from authentik.lib.utils.errors import exception_to_string

LOGGER = get_logger()
# Prefix of every cache key under which task info is stored
CACHE_KEY_PREFIX = "goauthentik.io/events/tasks/"


class TaskResultStatus(Enum):
    """Possible states of tasks"""

    SUCCESSFUL = 1
    WARNING = 2
    ERROR = 4
    UNKNOWN = 8


@dataclass
class TaskResult:
    """Result of a task run, this class is created by the task itself
    and used by self.set_status"""

    status: TaskResultStatus

    # Human-readable messages describing the outcome of the run
    messages: list[str] = field(default_factory=list)

    # Optional UID used in cache for tasks that run in different instances
    uid: Optional[str] = field(default=None)

    def with_error(self, exc: Exception) -> "TaskResult":
        """Since errors might not always be pickle-able, set the traceback

        The exception is converted to a string and appended to `messages`;
        the same instance is returned so calls can be chained."""
        # TODO: Mark exception somehow so that is rendered as preformatted text
        # (presumably `<pre>`) in frontend
        self.messages.append(exception_to_string(exc))
        return self


@dataclass
class TaskInfo:
    """Info about a task run

    Instances are stored in the cache under `CACHE_KEY_PREFIX + task_name`,
    with a `:<uid>` suffix appended when the result carries a UID."""

    task_name: str
    # update_metrics reports `finish_timestamp - start_timestamp` as the run duration
    start_timestamp: float
    finish_timestamp: float
    finish_time: datetime

    result: TaskResult

    task_call_module: str
    task_call_func: str
    task_call_args: list[Any] = field(default_factory=list)
    task_call_kwargs: dict[str, Any] = field(default_factory=dict)

    task_description: Optional[str] = field(default=None)

    @staticmethod
    def all() -> dict[str, "TaskInfo"]:
        """Get all TaskInfo objects, as a mapping of cache key to object"""
        return cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*"))

    @staticmethod
    def by_name(name: str) -> Optional["TaskInfo"] | Optional[list["TaskInfo"]]:
        """Get TaskInfo Object by name

        When `name` contains a `*` it is used as a key pattern and a list of all
        matching objects (possibly empty) is returned. Otherwise the single object
        cached under that name is returned, or None if there is none."""
        if "*" in name:
            matches = cache.get_many(cache.keys(CACHE_KEY_PREFIX + name))
            # Materialize the dict view so callers get the list the annotation promises
            return list(matches.values())
        return cache.get(CACHE_KEY_PREFIX + name, None)

    @property
    def full_name(self) -> str:
        """Get the full cache key with task name and UID

        Note that reading this property has a side effect: when the result has a
        UID, `task_name` itself is extended with the `:<uid>` suffix (once)."""
        key = CACHE_KEY_PREFIX + self.task_name
        if self.result.uid:
            uid_suffix = f":{self.result.uid}"
            key += uid_suffix
            # Only append the suffix to task_name once, the property may be read repeatedly
            if not self.task_name.endswith(uid_suffix):
                self.task_name += uid_suffix
        return key

    def delete(self):
        """Delete task info from cache"""
        return cache.delete(self.full_name)

    def update_metrics(self):
        """Update prometheus metrics

        Sets the task gauge, labelled with the task name (without UID suffix),
        the UID and the lower-cased status name, to the run duration."""
        # Fall back to "now" when the attribute is missing
        # NOTE(review): presumably guards objects cached before the field existed - confirm
        start = default_timer()
        if hasattr(self, "start_timestamp"):
            start = self.start_timestamp
        try:
            # Never report a negative duration
            duration = max(self.finish_timestamp - start, 0)
        except TypeError:
            # Timestamps that cannot be subtracted (e.g. None) count as zero duration
            duration = 0
        GAUGE_TASKS.labels(
            task_name=self.task_name.split(":")[0],
            task_uid=self.result.uid or "",
            status=self.result.status.name.lower(),
        ).set(duration)

    def save(self, timeout_hours=6):
        """Save task into cache

        `timeout_hours` is the cache lifetime in hours; it is converted to
        seconds before being passed to the cache. Metrics are updated first."""
        self.update_metrics()
        cache.set(self.full_name, self, timeout=timeout_hours * 60 * 60)


class MonitoredTask(TenantTask):
    """Task which can save its state to the cache

    Subclasses report their outcome with `set_status` (and optionally `set_uid`)
    from within `run`; the result is written to the cache as a `TaskInfo` once
    the task returns or fails."""

    # For tasks that should only be listed if they failed, set this to False
    save_on_success: bool

    _result: Optional[TaskResult]

    _uid: Optional[str]
    start: Optional[float] = None

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.save_on_success = True
        self._uid = None
        self._result = None
        self.result_timeout_hours = 6

    def set_uid(self, uid: str):
        """Set UID, so in the case of an unexpected error its saved correctly"""
        self._uid = uid

    def set_status(self, result: TaskResult):
        """Set result for current run, will overwrite previous result."""
        self._result = result

    def before_start(self, task_id, args, kwargs):
        """Reset per-run state and record the start time before the task body runs"""
        # Celery instantiates a task once per process and reuses that instance for
        # every invocation, so state left over from a previous run must be cleared.
        # Otherwise a run that fails before calling set_status/set_uid would be
        # saved with the previous run's result or UID.
        self._result = None
        self._uid = None
        self.start = default_timer()
        return super().before_start(task_id, args, kwargs)

    def _build_info(self, args: list[Any], kwargs: dict[str, Any]) -> TaskInfo:
        """Build the TaskInfo for the current run from `self._result`

        Expects `self._result` to be set. If the result has no UID of its own,
        the UID given via `set_uid` is applied to it."""
        if not self._result.uid:
            self._result.uid = self._uid
        return TaskInfo(
            task_name=self.__name__,
            task_description=self.__doc__,
            # Fall back to "now" (zero duration) if before_start never ran
            start_timestamp=self.start or default_timer(),
            finish_timestamp=default_timer(),
            finish_time=datetime.now(),
            result=self._result,
            task_call_module=self.__module__,
            task_call_func=self.__name__,
            task_call_args=args,
            task_call_kwargs=kwargs,
        )

    # pylint: disable=too-many-arguments
    def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
        """Persist the result of the run, if the task reported one

        Successful results of tasks with `save_on_success` disabled are not
        stored; any existing cache entry for the task is deleted instead."""
        super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
        if not self._result:
            # The task never called set_status, there is nothing to record
            return
        info = self._build_info(args, kwargs)
        if self._result.status == TaskResultStatus.SUCCESSFUL and not self.save_on_success:
            info.delete()
            return
        info.save(self.result_timeout_hours)

    # pylint: disable=too-many-arguments
    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """Persist a failed run and create a system task exception event

        If the task did not report a result itself, an ERROR result carrying the
        exception's message is recorded."""
        super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
        if not self._result:
            self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)])
        self._build_info(args, kwargs).save(self.result_timeout_hours)
        Event.new(
            EventAction.SYSTEM_TASK_EXCEPTION,
            message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
        ).save()

    def run(self, *args, **kwargs):
        """Task body, not implemented by this base class"""
        raise NotImplementedError


def prefill_task(func):
    """Make sure `func` has a TaskInfo entry in the cache, so that the task can
    always be triggered via API.

    If nothing is cached under the function's name yet, a placeholder entry
    with status UNKNOWN is saved. `func` itself is always returned unchanged."""
    task_name = func.__name__
    if TaskInfo.by_name(task_name):
        # Something is already cached under this name, leave it alone
        return func
    placeholder_result = TaskResult(
        TaskResultStatus.UNKNOWN,
        messages=[_("Task has not been run yet.")],
    )
    placeholder = TaskInfo(
        task_name=task_name,
        task_description=func.__doc__,
        result=placeholder_result,
        task_call_module=func.__module__,
        task_call_func=task_name,
        # No real values exist for these attributes yet, but they must not be null
        start_timestamp=0,
        finish_timestamp=0,
        finish_time=datetime.now(),
    )
    # NOTE(review): TaskInfo.save takes its timeout in hours, so this entry is kept
    # for 86400 hours rather than 86400 seconds (one day) - confirm this is intended
    placeholder.save(86400)
    LOGGER.debug("prefilled task", task_name=task_name)
    return func