task fixes, creation of tenant now works by cloning a template schema, some other small stuff

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt 2023-11-15 22:03:46 +01:00
parent d11721805a
commit a4fd37e429
No known key found for this signature in database
GPG Key ID: 9C3FA22FABF1AA8D
9 changed files with 3440 additions and 37 deletions

View File

@ -14,24 +14,23 @@ class ManagedAppConfig(AppConfig):
_logger: BoundLogger _logger: BoundLogger
RECONCILE_PREFIX: str = "reconcile_"
RECONCILE_TENANT_PREFIX: str = "reconcile_tenant_"
def __init__(self, app_name: str, *args, **kwargs) -> None: def __init__(self, app_name: str, *args, **kwargs) -> None:
super().__init__(app_name, *args, **kwargs) super().__init__(app_name, *args, **kwargs)
self._logger = get_logger().bind(app_name=app_name) self._logger = get_logger().bind(app_name=app_name)
def ready(self) -> None: def ready(self) -> None:
self.reconcile() self.reconcile()
self.reconcile_tenant()
return super().ready() return super().ready()
def import_module(self, path: str): def import_module(self, path: str):
"""Load module""" """Load module"""
import_module(path) import_module(path)
def reconcile(self) -> None: def _reconcile(self, prefix: str) -> None:
"""reconcile ourselves"""
from authentik.tenants.models import Tenant
prefix = "reconcile_"
tenant_prefix = "reconcile_tenant_"
for meth_name in dir(self): for meth_name in dir(self):
meth = getattr(self, meth_name) meth = getattr(self, meth_name)
if not ismethod(meth): if not ismethod(meth):
@ -39,16 +38,6 @@ class ManagedAppConfig(AppConfig):
if not meth_name.startswith(prefix): if not meth_name.startswith(prefix):
continue continue
name = meth_name.replace(prefix, "") name = meth_name.replace(prefix, "")
tenants = Tenant.objects.filter(ready=True)
if not meth_name.startswith(tenant_prefix):
tenants = Tenant.objects.filter(schema_name=get_public_schema_name())
try:
tenants = list(tenants)
except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.debug("Failed to get tenants to run reconcile", name=name, exc=exc)
continue
for tenant in tenants:
with tenant:
try: try:
self._logger.debug("Starting reconciler", name=name) self._logger.debug("Starting reconciler", name=name)
meth() meth()
@ -56,6 +45,31 @@ class ManagedAppConfig(AppConfig):
except (DatabaseError, ProgrammingError, InternalError) as exc: except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.debug("Failed to run reconcile", name=name, exc=exc) self._logger.debug("Failed to run reconcile", name=name, exc=exc)
def reconcile_tenant(self) -> None:
"""reconcile ourselves for tenanted methods"""
from authentik.tenants.models import Tenant
try:
tenants = list(Tenant.objects.filter(ready=True))
except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.debug("Failed to get tenants to run reconcile", exc=exc)
return
for tenant in tenants:
with tenant:
self._reconcile(self.RECONCILE_TENANT_PREFIX)
def reconcile(self) -> None:
"""reconcile ourselves"""
from authentik.tenants.models import Tenant
try:
default_tenant = Tenant.objects.get(schema_name=get_public_schema_name())
except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.debug("Failed to get default tenant to run reconcile", exc=exc)
return
with default_tenant:
self._reconcile(self.RECONCILE_PREFIX)
class AuthentikBlueprintsConfig(ManagedAppConfig): class AuthentikBlueprintsConfig(ManagedAppConfig):
"""authentik Blueprints app""" """authentik Blueprints app"""

View File

@ -14,7 +14,7 @@ from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_SYSTEM
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
def check_blueprint_v1_file(BlueprintInstance: type, path: Path): def check_blueprint_v1_file(BlueprintInstance: type, db_alias, path: Path):
"""Check if blueprint should be imported""" """Check if blueprint should be imported"""
from authentik.blueprints.models import BlueprintInstanceStatus from authentik.blueprints.models import BlueprintInstanceStatus
from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata
@ -29,7 +29,9 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
if version != 1: if version != 1:
return return
blueprint_file.seek(0) blueprint_file.seek(0)
instance: BlueprintInstance = BlueprintInstance.objects.filter(path=path).first() instance: BlueprintInstance = (
BlueprintInstance.objects.using(db_alias).filter(path=path).first()
)
rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir"))) rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir")))
meta = None meta = None
if metadata: if metadata:
@ -37,7 +39,7 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
if meta.labels.get(LABEL_AUTHENTIK_INSTANTIATE, "").lower() == "false": if meta.labels.get(LABEL_AUTHENTIK_INSTANTIATE, "").lower() == "false":
return return
if not instance: if not instance:
instance = BlueprintInstance( BlueprintInstance.objects.using(db_alias).create(
name=meta.name if meta else str(rel_path), name=meta.name if meta else str(rel_path),
path=str(rel_path), path=str(rel_path),
context={}, context={},
@ -47,7 +49,6 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
last_applied_hash="", last_applied_hash="",
metadata=metadata or {}, metadata=metadata or {},
) )
instance.save()
def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
@ -56,7 +57,7 @@ def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
db_alias = schema_editor.connection.alias db_alias = schema_editor.connection.alias
for file in glob(f"{CONFIG.get('blueprints_dir')}/**/*.yaml", recursive=True): for file in glob(f"{CONFIG.get('blueprints_dir')}/**/*.yaml", recursive=True):
check_blueprint_v1_file(BlueprintInstance, Path(file)) check_blueprint_v1_file(BlueprintInstance, db_alias, Path(file))
for blueprint in BlueprintInstance.objects.using(db_alias).all(): for blueprint in BlueprintInstance.objects.using(db_alias).all():
# If we already have flows (and we should always run before flow migrations) # If we already have flows (and we should always run before flow migrations)

View File

@ -5,10 +5,10 @@ from enum import Enum
from timeit import default_timer from timeit import default_timer
from typing import Any, Optional from typing import Any, Optional
from celery import Task
from django.core.cache import cache from django.core.cache import cache
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from tenant_schemas_celery.task import TenantTask
from authentik.events.apps import GAUGE_TASKS from authentik.events.apps import GAUGE_TASKS
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
@ -112,7 +112,7 @@ class TaskInfo:
cache.set(self.full_name, self, timeout=timeout_hours * 60 * 60) cache.set(self.full_name, self, timeout=timeout_hours * 60 * 60)
class MonitoredTask(Task): class MonitoredTask(TenantTask):
"""Task which can save its state to the cache""" """Task which can save its state to the cache"""
# For tasks that should only be listed if they failed, set this to False # For tasks that should only be listed if they failed, set this to False

View File

@ -117,7 +117,8 @@ def add_process_id(logger: Logger, method_name: str, event_dict):
def add_tenant_information(logger: Logger, method_name: str, event_dict): def add_tenant_information(logger: Logger, method_name: str, event_dict):
event_dict["schema_name"] = connection.tenant.schema_name tenant = getattr(connection, "tenant", None)
event_dict["domain_url"] = getattr(connection.tenant, "domain_url", None) if tenant is not None:
event_dict["schema_name"] = tenant.schema_name
event_dict["domain_url"] = getattr(tenant, "domain_url", None)
return event_dict return event_dict

View File

@ -121,6 +121,9 @@ TENANT_APPS = [
TENANT_MODEL = "authentik_tenants.Tenant" TENANT_MODEL = "authentik_tenants.Tenant"
TENANT_DOMAIN_MODEL = "authentik_tenants.Domain" TENANT_DOMAIN_MODEL = "authentik_tenants.Domain"
TENANT_CREATION_FAKES_MIGRATIONS = True
TENANT_BASE_SCHEMA = "template"
GUARDIAN_MONKEY_PATCH = False GUARDIAN_MONKEY_PATCH = False
SPECTACULAR_SETTINGS = { SPECTACULAR_SETTINGS = {

3300
authentik/tenants/clone.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -143,4 +143,8 @@ class Migration(migrations.Migration):
}, },
), ),
migrations.RunPython(code=create_default_tenant, reverse_code=migrations.RunPython.noop), migrations.RunPython(code=create_default_tenant, reverse_code=migrations.RunPython.noop),
migrations.RunSQL(
sql="CREATE SCHEMA IF NOT EXISTS template;",
reverse_sql="DROP SCHEMA IF EXISTS template;",
),
] ]

View File

@ -1,16 +1,32 @@
"""Tenant models""" """Tenant models"""
from uuid import uuid4 from uuid import uuid4
from django.db import models from django.apps import apps
from django.db.models.deletion import ProtectedError from django.conf import settings
from django.db.models.signals import pre_delete from django.core.management import call_command
from django.db import connections, models
from django.db.models.base import ValidationError
from django.dispatch import receiver from django.dispatch import receiver
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_tenants.models import DomainMixin, TenantMixin, post_schema_sync from django_tenants.models import (
DomainMixin,
TenantMixin,
post_schema_sync,
schema_needs_to_be_sync,
)
from django_tenants.postgresql_backend.base import _check_schema_name
from django_tenants.utils import (
get_creation_fakes_migrations,
get_tenant_base_schema,
get_tenant_database_alias,
schema_exists,
)
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.tenants.clone import CloneSchema
LOGGER = get_logger() LOGGER = get_logger()
@ -57,6 +73,61 @@ class Tenant(TenantMixin, SerializerModel):
default=86400, default=86400,
) )
def create_schema(self, check_if_exists=False, sync_schema=True, verbosity=1):
"""
Creates the schema 'schema_name' for this tenant. Optionally checks if
the schema already exists before creating it. Returns true if the
schema was created, false otherwise.
"""
# safety check
connection = connections[get_tenant_database_alias()]
_check_schema_name(self.schema_name)
cursor = connection.cursor()
if check_if_exists and schema_exists(self.schema_name):
return False
fake_migrations = get_creation_fakes_migrations()
if sync_schema:
if fake_migrations:
# copy tables and data from provided model schema
base_schema = get_tenant_base_schema()
clone_schema = CloneSchema()
clone_schema.clone_schema(base_schema, self.schema_name)
call_command(
"migrate_schemas",
tenant=True,
fake=True,
schema_name=self.schema_name,
interactive=False,
verbosity=verbosity,
)
else:
# create the schema
cursor.execute('CREATE SCHEMA "%s"' % self.schema_name)
call_command(
"migrate_schemas",
tenant=True,
schema_name=self.schema_name,
interactive=False,
verbosity=verbosity,
)
connection.set_schema_to_public()
def save(self, *args, **kwargs):
if self.schema_name == "template":
raise Exception("Cannot create schema named template")
super().save(*args, **kwargs)
def delete(self, *args, **kwargs):
if self.schema_name in ("public", "template"):
raise Exception("Cannot delete schema public or template")
super().delete(*args, **kwargs)
@property @property
def serializer(self) -> Serializer: def serializer(self) -> Serializer:
from authentik.tenants.api import TenantSerializer from authentik.tenants.api import TenantSerializer
@ -64,7 +135,7 @@ class Tenant(TenantMixin, SerializerModel):
return TenantSerializer return TenantSerializer
def __str__(self) -> str: def __str__(self) -> str:
return f"Tenant {self.domain_regex}" return f"Tenant {self.name}"
class Meta: class Meta:
verbose_name = _("Tenant") verbose_name = _("Tenant")
@ -87,6 +158,14 @@ class Domain(DomainMixin, SerializerModel):
@receiver(post_schema_sync, sender=TenantMixin) @receiver(post_schema_sync, sender=TenantMixin)
def tenant_ready(sender, tenant, **kwargs): def tenant_needs_sync(sender, tenant, **kwargs):
if tenant.ready:
return
with tenant:
for app in apps.get_app_configs():
if isinstance(app, ManagedAppConfig):
app._reconcile(ManagedAppConfig.RECONCILE_TENANT_PREFIX)
tenant.ready = True tenant.ready = True
tenant.save() tenant.save()

View File

@ -109,6 +109,7 @@ if __name__ == "__main__":
"available on your PYTHONPATH environment variable? Did you " "available on your PYTHONPATH environment variable? Did you "
"forget to activate a virtual environment?" "forget to activate a virtual environment?"
) from exc ) from exc
execute_from_command_line(["", "migrate"]) execute_from_command_line(["", "migrate_schemas"])
execute_from_command_line(["", "migrate_schemas", "--schema", "template", "--tenant"])
finally: finally:
release_lock(curr) release_lock(curr)