blueprints: fix issue in prod setups with encoding dataclasses via celery

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-08-16 20:59:36 +02:00
parent aea0958f3f
commit ff788edd9b
2 changed files with 19 additions and 9 deletions

View file

@ -1,6 +1,4 @@
"""Serializer mixin for managed models"""
from dataclasses import asdict
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.decorators import action
from rest_framework.fields import CharField, DateTimeField, JSONField
@ -86,7 +84,7 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
def available(self, request: Request) -> Response:
"""Get blueprints"""
files: list[BlueprintFile] = blueprints_find.delay().get()
return Response([sanitize_dict(asdict(file)) for file in files])
return Response([sanitize_dict(file) for file in files])
@permission_required("authentik_blueprints.view_blueprintinstance")
@extend_schema(

View file

@ -21,6 +21,7 @@ from authentik.events.monitored_tasks import (
TaskResultStatus,
prefill_task,
)
from authentik.events.utils import sanitize_dict
from authentik.lib.config import CONFIG
from authentik.root.celery import CELERY_APP
@ -35,6 +36,17 @@ class BlueprintFile:
last_m: int
meta: Optional[BlueprintMetadata] = field(default=None)
@staticmethod
def from_raw(raw: dict, path: Path) -> "BlueprintFile":
"""Create blueprint file from raw YAML"""
root = Path(CONFIG.y("blueprints_dir"))
metadata = raw.get("metadata", None)
version = raw.get("version", 1)
file_hash = sha512(path.read_bytes()).hexdigest()
blueprint = BlueprintFile(path.relative_to(root), version, file_hash, path.stat().st_mtime)
blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None
return blueprint
@CELERY_APP.task(
throws=(DatabaseError, ProgrammingError, InternalError),
@ -52,14 +64,11 @@ def blueprints_find():
raw_blueprint = None
if not raw_blueprint:
continue
metadata = raw_blueprint.get("metadata", None)
version = raw_blueprint.get("version", 1)
if version != 1:
continue
file_hash = sha512(path.read_bytes()).hexdigest()
blueprint = BlueprintFile(path.relative_to(root), version, file_hash, path.stat().st_mtime)
blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None
blueprints.append(blueprint)
blueprint = BlueprintFile.from_raw(raw_blueprint, path)
blueprints.append(sanitize_dict(asdict(blueprint)))
return blueprints
@ -70,8 +79,11 @@ def blueprints_find():
def blueprints_discover(self: MonitoredTask):
"""Find blueprints and check if they need to be created in the database"""
count = 0
root = Path(CONFIG.y("blueprints_dir"))
for blueprint in blueprints_find():
check_blueprint_v1_file(blueprint)
check_blueprint_v1_file(
BlueprintFile.from_raw(blueprint, Path(root, blueprint.get("path")))
)
count += 1
self.set_status(
TaskResult(