diff --git a/authentik/blueprints/apps.py b/authentik/blueprints/apps.py index 90df91c00..e63967e16 100644 --- a/authentik/blueprints/apps.py +++ b/authentik/blueprints/apps.py @@ -5,6 +5,7 @@ from inspect import ismethod from django.apps import AppConfig from django.db import DatabaseError, InternalError, ProgrammingError +from django_tenants.utils import get_public_schema_name from structlog.stdlib import BoundLogger, get_logger @@ -27,7 +28,10 @@ class ManagedAppConfig(AppConfig): def reconcile(self) -> None: """reconcile ourselves""" + from authentik.tenants.models import Tenant + prefix = "reconcile_" + tenant_prefix = "reconcile_tenant_" for meth_name in dir(self): meth = getattr(self, meth_name) if not ismethod(meth): @@ -35,12 +39,17 @@ class ManagedAppConfig(AppConfig): if not meth_name.startswith(prefix): continue name = meth_name.replace(prefix, "") - try: - self._logger.debug("Starting reconciler", name=name) - meth() - self._logger.debug("Successfully reconciled", name=name) - except (DatabaseError, ProgrammingError, InternalError) as exc: - self._logger.debug("Failed to run reconcile", name=name, exc=exc) + tenants = Tenant.objects.all() + if meth_name.startswith(tenant_prefix): + tenants = Tenant.objects.get(schema_name=get_public_schema_name()) + for tenant in tenants: + with tenant: + try: + self._logger.debug("Starting reconciler", name=name) + meth() + self._logger.debug("Successfully reconciled", name=name) + except (DatabaseError, ProgrammingError, InternalError) as exc: + self._logger.debug("Failed to run reconcile", name=name, exc=exc) class AuthentikBlueprintsConfig(ManagedAppConfig): @@ -55,7 +64,7 @@ class AuthentikBlueprintsConfig(ManagedAppConfig): """Load v1 tasks""" self.import_module("authentik.blueprints.v1.tasks") - def reconcile_blueprints_discovery(self): + def reconcile_tenant_blueprints_discovery(self): """Run blueprint discovery""" from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints diff --git a/authentik/blueprints/v1/tasks.py b/authentik/blueprints/v1/tasks.py index 8ff86c996..aad3a1ca8 100644 --- a/authentik/blueprints/v1/tasks.py +++ b/authentik/blueprints/v1/tasks.py @@ -38,6 +38,7 @@ from authentik.events.monitored_tasks import ( from authentik.events.utils import sanitize_dict from authentik.lib.config import CONFIG from authentik.root.celery import CELERY_APP +from authentik.tenants.models import Tenant LOGGER = get_logger() _file_watcher_started = False @@ -75,16 +76,17 @@ class BlueprintEventHandler(FileSystemEventHandler): return if event.is_directory: return - if isinstance(event, FileCreatedEvent): - LOGGER.debug("new blueprint file created, starting discovery") - blueprints_discovery.delay() - if isinstance(event, FileModifiedEvent): - path = Path(event.src_path) - root = Path(CONFIG.get("blueprints_dir")).absolute() - rel_path = str(path.relative_to(root)) - for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True): - LOGGER.debug("modified blueprint file, starting apply", instance=instance) - apply_blueprint.delay(instance.pk.hex) + for tenant in Tenant.objects.all(): + with tenant: + if isinstance(event, FileCreatedEvent): + LOGGER.debug("new blueprint file created, starting discovery") + blueprints_discovery.delay() + if isinstance(event, FileModifiedEvent): + path = Path(event.src_path) + root = Path(CONFIG.get("blueprints_dir")).absolute() + rel_path = str(path.relative_to(root)) + for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True): + LOGGER.debug("modified blueprint file, starting apply", instance=instance) @CELERY_APP.task( diff --git a/authentik/core/apps.py b/authentik/core/apps.py index 719b2abb1..c3f8c8686 100644 --- a/authentik/core/apps.py +++ b/authentik/core/apps.py @@ -24,7 +24,7 @@ class AuthentikCoreConfig(ManagedAppConfig): worker_ready_hook() - def reconcile_source_inbuilt(self): + def reconcile_tenant_source_inbuilt(self): """Reconcile inbuilt source""" from authentik.core.models import Source diff --git a/authentik/crypto/apps.py b/authentik/crypto/apps.py index bed1ab811..0a64c736d 100644 --- a/authentik/crypto/apps.py +++ b/authentik/crypto/apps.py @@ -39,7 +39,7 @@ class AuthentikCryptoConfig(ManagedAppConfig): }, ) - def reconcile_managed_jwt_cert(self): + def reconcile_tenant_managed_jwt_cert(self): """Ensure managed JWT certificate""" from authentik.crypto.models import CertificateKeyPair @@ -52,7 +52,7 @@ class AuthentikCryptoConfig(ManagedAppConfig): ): self._create_update_cert() - def reconcile_self_signed(self): + def reconcile_tenant_self_signed(self): """Create self-signed keypair""" from authentik.crypto.builder import CertificateBuilder from authentik.crypto.models import CertificateKeyPair diff --git a/authentik/outposts/apps.py b/authentik/outposts/apps.py index 6898a170a..b4302ba71 100644 --- a/authentik/outposts/apps.py +++ b/authentik/outposts/apps.py @@ -29,7 +29,7 @@ class AuthentikOutpostConfig(ManagedAppConfig): """Load outposts signals""" self.import_module("authentik.outposts.signals") - def reconcile_embedded_outpost(self): + def reconcile_tenant_embedded_outpost(self): """Ensure embedded outpost""" from authentik.outposts.models import ( DockerServiceConnection, diff --git a/authentik/root/celery.py b/authentik/root/celery.py index 2747bae45..e49d88a8d 100644 --- a/authentik/root/celery.py +++ b/authentik/root/celery.py @@ -6,7 +6,7 @@ from pathlib import Path from tempfile import gettempdir from typing import Callable -from celery import Celery, bootsteps +from celery import bootsteps from celery.apps.worker import Worker from celery.signals import ( after_task_publish, @@ -19,8 +19,10 @@ from celery.signals import ( ) from django.conf import settings from django.db import ProgrammingError +from django_tenants.utils import get_public_schema_name from structlog.contextvars import STRUCTLOG_KEY_PREFIX from structlog.stdlib import get_logger +from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp from authentik.lib.sentry import before_send from authentik.lib.utils.errors import exception_to_string @@ -29,7 +31,7 @@ from authentik.lib.utils.errors import exception_to_string os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings") LOGGER = get_logger() -CELERY_APP = Celery("authentik") +CELERY_APP = TenantAwareCeleryApp("authentik") CTX_TASK_ID = ContextVar(STRUCTLOG_KEY_PREFIX + "task_id", default=Ellipsis) HEARTBEAT_FILE = Path(gettempdir() + "/authentik-worker") @@ -80,8 +82,13 @@ def task_error_hook(task_id, exception: Exception, traceback, *args, **kwargs): Event.new(EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception)).save() -def _get_startup_tasks() -> list[Callable]: - """Get all tasks to be run on startup""" +def _get_startup_tasks_default_tenant() -> list[Callable]: + """Get all tasks to be run on startup for the default tenant""" + return [] + + +def _get_startup_tasks_all_tenants() -> list[Callable]: + """Get all tasks to be run on startup for all tenants""" from authentik.admin.tasks import clear_update_notifications from authentik.outposts.tasks import outpost_connection_discovery, outpost_controller_all from authentik.providers.proxy.tasks import proxy_set_defaults @@ -97,13 +104,25 @@ def _get_startup_tasks() -> list[Callable]: @worker_ready.connect def worker_ready_hook(*args, **kwargs): """Run certain tasks on worker start""" + from authentik.tenants.models import Tenant LOGGER.info("Dispatching startup tasks...") - for task in _get_startup_tasks(): + + def _run_task(task: Callable): try: task.delay() except ProgrammingError as exc: LOGGER.warning("Startup task failed", task=task, exc=exc) + + for task in _get_startup_tasks_default_tenant(): + with Tenant.objects.get(schema_name=get_public_schema_name()): + _run_task(task) + + for task in _get_startup_tasks_all_tenants(): + for tenant in Tenant.objects.all(): + with tenant: + _run_task(task) + from authentik.blueprints.v1.tasks import start_blueprint_watcher start_blueprint_watcher() diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 85859ac2e..1cd85df6e 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -207,6 +207,8 @@ CACHES = { "TIMEOUT": CONFIG.get_int("cache.timeout", 300), "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, "KEY_PREFIX": "authentik_cache", + "KEY_FUNCTION": "django_tenants.cache.make_key", + "REVERSE_KEY_FUNCTION": "django_tenants.cache.reverse_key", } } DJANGO_REDIS_SCAN_ITERSIZE = 1000 @@ -360,6 +362,7 @@ CELERY = { "options": {"queue": "authentik_scheduled"}, }, }, + "beat_scheduler": "authentik.tenants.scheduler:TenantAwarePersistentScheduler", "task_create_missing_queues": True, "task_default_queue": "authentik", "broker_url": CONFIG.get("broker.url") diff --git a/authentik/tenants/models.py b/authentik/tenants/models.py index c8a135522..aa36f75c1 100644 --- a/authentik/tenants/models.py +++ b/authentik/tenants/models.py @@ -6,7 +6,7 @@ from django.db.models.deletion import ProtectedError from django.db.models.signals import pre_delete from django.dispatch import receiver from django.utils.translation import gettext_lazy as _ -from django_tenants.models import DomainMixin, TenantMixin +from django_tenants.models import DomainMixin, TenantMixin, post_schema_sync from rest_framework.serializers import Serializer from structlog.stdlib import get_logger @@ -23,6 +23,7 @@ class Tenant(TenantMixin, SerializerModel): auto_create_schema = True auto_drop_schema = True + ready = models.BooleanField(default=False) avatars = models.TextField( help_text=_("Configure how authentik should show avatars for users."), @@ -83,3 +84,9 @@ class Domain(DomainMixin, SerializerModel): class Meta: verbose_name = _("Domain") verbose_name_plural = _("Domains") + + +@receiver(post_schema_sync, sender=TenantMixin) +def tenant_ready(sender, tenant, **kwargs): + tenant.ready = True + tenant.save() diff --git a/authentik/tenants/scheduler.py b/authentik/tenants/scheduler.py new file mode 100644 index 000000000..02a2aabf6 --- /dev/null +++ b/authentik/tenants/scheduler.py @@ -0,0 +1,9 @@ +from tenant_schemas_celery.scheduler import ( + TenantAwarePersistentScheduler as BaseTenantAwarePersistentScheduler, +) + + +class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler): + @classmethod + def get_queryset(cls): + return super().get_queryset().filter(ready=True) diff --git a/poetry.lock b/poetry.lock index f130d52be..d2365a294 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3820,6 +3820,20 @@ jsonschema = "*" pyyaml = "*" typing-extensions = "*" +[[package]] +name = "tenant-schemas-celery" +version = "2.2.0" +description = "Celery integration for django-tenant-schemas and django-tenants" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tenant-schemas-celery-2.2.0.tar.gz", hash = "sha256:b4fc16959cb98597591afb30f07256f70d8470d97c22c62e3d3af344868cdd6f"}, +] + +[package.dependencies] +celery = "*" + [[package]] name = "tomlkit" version = "0.12.3" @@ -4648,4 +4662,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "~3.11" -content-hash = "f6f26ceab7fcd5b1d614069cd5423aa8e90d3892d03f302907129fbff3ff63a0" +content-hash = "7c8fba15c50cc7ff0341b6654444d49d69e4e99659ea1cccc411bef9efdcef82" diff --git a/pyproject.toml b/pyproject.toml index 6c7d9865c..e80952066 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,6 +133,7 @@ django-guardian = "*" django-model-utils = "*" django-prometheus = "*" django-redis = "*" +# See https://github.com/django-tenants/django-tenants/pull/959 django-tenants = { git = "https://github.com/hho6643/django-tenants.git", branch="hho6643-psycopg3_fixes" } djangorestframework = "*" djangorestframework-guardian = "*" @@ -162,6 +163,7 @@ sentry-sdk = "*" service_identity = "*" structlog = "*" swagger-spec-validator = "*" +tenant-schemas-celery = "*" twilio = "*" twisted = "*" ua-parser = "*"