add importer wrapper that supports multiple yaml documents

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-12-22 21:39:25 +01:00 committed by Jens Langhammer
parent 47e663f48c
commit 6b78190093
No known key found for this signature in database
3 changed files with 57 additions and 20 deletions

View file

@ -130,7 +130,7 @@ class TestBlueprintsV1(TransactionTestCase):
ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").delete() ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").delete()
Group.objects.filter(name="test").delete() Group.objects.filter(name="test").delete()
environ["foo"] = generate_id() environ["foo"] = generate_id()
importer = Importer(load_yaml_fixture("fixtures/tags.yaml"), {"bar": "baz"}) importer = Importer(load_yaml_fixture("fixtures/tags.yaml"), context={"bar": "baz"})
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
policy = ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").first() policy = ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").first()

View file

@ -1,6 +1,7 @@
"""Blueprint importer""" """Blueprint importer"""
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict
from typing import Any, Optional from typing import Any, Optional
from dacite.config import Config from dacite.config import Config
@ -17,7 +18,7 @@ from rest_framework.serializers import BaseSerializer, Serializer
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger
from structlog.testing import capture_logs from structlog.testing import capture_logs
from structlog.types import EventDict from structlog.types import EventDict
from yaml import load from yaml import load_all
from authentik.blueprints.v1.common import ( from authentik.blueprints.v1.common import (
Blueprint, Blueprint,
@ -77,31 +78,31 @@ def transaction_rollback():
atomic.__exit__(IntegrityError, None, None) atomic.__exit__(IntegrityError, None, None)
class Importer: class SingleDocumentImporter:
"""Import Blueprint from YAML""" """Import Blueprint from YAML"""
logger: BoundLogger logger: BoundLogger
__import: Blueprint
def __init__(self, yaml_input: str, context: Optional[dict] = None): def __init__(self, raw_blueprint: dict, context: Optional[dict] = None):
self.__pk_map: dict[Any, Model] = {} self.__pk_map: dict[Any, Model] = {}
self.logger = get_logger() self.logger = get_logger()
import_dict = load(yaml_input, BlueprintLoader)
try: try:
self.__import = from_dict( self._import = from_dict(
Blueprint, import_dict, config=Config(cast=[BlueprintEntryDesiredState]) Blueprint, raw_blueprint, config=Config(cast=[BlueprintEntryDesiredState])
) )
except DaciteError as exc: except DaciteError as exc:
raise EntryInvalidError from exc raise EntryInvalidError from exc
ctx = {} ctx = {}
always_merger.merge(ctx, self.__import.context) always_merger.merge(ctx, self._import.context)
if context: if context:
always_merger.merge(ctx, context) always_merger.merge(ctx, context)
self.__import.context = ctx self._import.context = ctx
@property @property
def blueprint(self) -> Blueprint: def blueprint(self) -> Blueprint:
"""Get imported blueprint""" """Get imported blueprint"""
return self.__import return self._import
def __update_pks_for_attrs(self, attrs: dict[str, Any]) -> dict[str, Any]: def __update_pks_for_attrs(self, attrs: dict[str, Any]) -> dict[str, Any]:
"""Replace any value if it is a known primary key of an other object""" """Replace any value if it is a known primary key of an other object"""
@ -147,7 +148,7 @@ class Importer:
# pylint: disable-msg=too-many-locals # pylint: disable-msg=too-many-locals
def _validate_single(self, entry: BlueprintEntry) -> Optional[BaseSerializer]: def _validate_single(self, entry: BlueprintEntry) -> Optional[BaseSerializer]:
"""Validate a single entry""" """Validate a single entry"""
if not entry.check_all_conditions_match(self.__import): if not entry.check_all_conditions_match(self._import):
self.logger.debug("One or more conditions of this entry are not fulfilled, skipping") self.logger.debug("One or more conditions of this entry are not fulfilled, skipping")
return None return None
@ -158,7 +159,7 @@ class Importer:
raise EntryInvalidError(f"Model {model} not allowed") raise EntryInvalidError(f"Model {model} not allowed")
if issubclass(model, BaseMetaModel): if issubclass(model, BaseMetaModel):
serializer_class: type[Serializer] = model.serializer() serializer_class: type[Serializer] = model.serializer()
serializer = serializer_class(data=entry.get_attrs(self.__import)) serializer = serializer_class(data=entry.get_attrs(self._import))
try: try:
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
except ValidationError as exc: except ValidationError as exc:
@ -172,7 +173,7 @@ class Importer:
# the full serializer for later usage # the full serializer for later usage
# Because a model might have multiple unique columns, we chain all identifiers together # Because a model might have multiple unique columns, we chain all identifiers together
# to create an OR query. # to create an OR query.
updated_identifiers = self.__update_pks_for_attrs(entry.get_identifiers(self.__import)) updated_identifiers = self.__update_pks_for_attrs(entry.get_identifiers(self._import))
for key, value in list(updated_identifiers.items()): for key, value in list(updated_identifiers.items()):
if isinstance(value, dict) and "pk" in value: if isinstance(value, dict) and "pk" in value:
del updated_identifiers[key] del updated_identifiers[key]
@ -211,7 +212,7 @@ class Importer:
model_instance.pk = updated_identifiers["pk"] model_instance.pk = updated_identifiers["pk"]
serializer_kwargs["instance"] = model_instance serializer_kwargs["instance"] = model_instance
try: try:
full_data = self.__update_pks_for_attrs(entry.get_attrs(self.__import)) full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import))
except ValueError as exc: except ValueError as exc:
raise EntryInvalidError(exc) from exc raise EntryInvalidError(exc) from exc
always_merger.merge(full_data, updated_identifiers) always_merger.merge(full_data, updated_identifiers)
@ -282,8 +283,8 @@ class Importer:
"""Validate loaded blueprint export, ensure all models are allowed """Validate loaded blueprint export, ensure all models are allowed
and serializers have no errors""" and serializers have no errors"""
self.logger.debug("Starting blueprint import validation") self.logger.debug("Starting blueprint import validation")
orig_import = deepcopy(self.__import) orig_import = deepcopy(self._import)
if self.__import.version != 1: if self._import.version != 1:
self.logger.warning("Invalid blueprint version") self.logger.warning("Invalid blueprint version")
return False, [] return False, []
with ( with (
@ -295,5 +296,42 @@ class Importer:
self.logger.debug("Blueprint validation failed") self.logger.debug("Blueprint validation failed")
for log in logs: for log in logs:
getattr(self.logger, log.get("log_level"))(**log) getattr(self.logger, log.get("log_level"))(**log)
self.__import = orig_import self._import = orig_import
return successful, logs return successful, logs
class Importer:
"""Importer capable of importing multi-document YAML"""
_importers: list[SingleDocumentImporter]
def __init__(self, *yaml_input: str, context: Optional[dict] = None):
docs = []
for doc in yaml_input:
docs += load_all(doc, BlueprintLoader)
self._importers = []
for doc in docs:
self._importers.append(SingleDocumentImporter(doc, context))
@property
def metadata(self) -> dict:
"""Get the merged metadata of all blueprints"""
metadata = {}
for importer in self._importers:
if importer._import.metadata:
always_merger.merge(metadata, asdict(importer._import.metadata))
return metadata
def apply(self) -> bool:
"""Apply all importers"""
return all(x.apply() for x in self._importers)
def validate(self) -> tuple[bool, list[EventDict]]:
"""Validate all importers"""
valid = []
events = []
for importer in self._importers:
_valid, _events = importer.validate()
valid.append(_valid)
events += _events
return all(valid), events

View file

@ -186,9 +186,8 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
return return
blueprint_content = instance.retrieve() blueprint_content = instance.retrieve()
file_hash = sha512(blueprint_content.encode()).hexdigest() file_hash = sha512(blueprint_content.encode()).hexdigest()
importer = Importer(blueprint_content, instance.context) importer = Importer(blueprint_content, context=instance.context)
if importer.blueprint.metadata: instance.metadata = importer.metadata
instance.metadata = asdict(importer.blueprint.metadata)
valid, logs = importer.validate() valid, logs = importer.validate()
if not valid: if not valid:
instance.status = BlueprintInstanceStatus.ERROR instance.status = BlueprintInstanceStatus.ERROR