diff --git a/authentik/blueprints/tests/fixtures/test.json b/authentik/blueprints/tests/fixtures/test.json new file mode 100644 index 000000000..809793285 --- /dev/null +++ b/authentik/blueprints/tests/fixtures/test.json @@ -0,0 +1,44 @@ +{ + "$schema": "https://goauthentik.io/blueprints/schema.json", + "version": 1, + "metadata": { + "name": "test-json" + }, + "entries": [ + { + "model": "authentik_providers_oauth2.oauth2provider", + "id": "provider", + "identifiers": { + "name": "grafana-json" + }, + "attrs": { + "authorization_flow": { + "goauthentik.io/yaml-key": "!Find", + "args": [ + "authentik_flows.flow", + [ + "pk", + { + "goauthentik.io/yaml-key": "!Context", + "args": "flow" + } + ] + ] + } + } + }, + { + "model": "authentik_core.application", + "identifiers": { + "slug": "test-json" + }, + "attrs": { + "name": "test-json", + "provider": { + "goauthentik.io/yaml-key": "!KeyOf", + "args": "provider" + } + } + } + ] +} diff --git a/authentik/blueprints/tests/test_v1_json.py b/authentik/blueprints/tests/test_v1_json.py new file mode 100644 index 000000000..d34a21bb8 --- /dev/null +++ b/authentik/blueprints/tests/test_v1_json.py @@ -0,0 +1,22 @@ +"""Test blueprints v1 JSON""" +from django.test import TransactionTestCase + +from authentik.blueprints.v1.importer import JSONStringImporter +from authentik.core.tests.utils import create_test_flow +from authentik.lib.tests.utils import load_fixture + + +class TestBlueprintsV1JSON(TransactionTestCase): + """Test Blueprints""" + + def test_import(self): + """Test JSON Import""" + test_flow = create_test_flow() + importer = JSONStringImporter( + load_fixture("fixtures/test.json"), + { + "flow": str(test_flow.pk), + }, + ) + self.assertTrue(importer.validate()[0]) + self.assertTrue(importer.apply()) diff --git a/authentik/blueprints/v1/importer.py b/authentik/blueprints/v1/importer.py index 628eb6fc1..7b043c749 100644 --- a/authentik/blueprints/v1/importer.py +++ b/authentik/blueprints/v1/importer.py @@ -1,7 +1,6 @@ """Blueprint importer""" from contextlib import contextmanager from copy import deepcopy -from json import loads from typing import Any, Optional from dacite.config import Config @@ -350,7 +349,7 @@ class JSONStringImporter(Importer): """Importer that also parses from JSON string""" def __init__(self, json_import: str, context: dict | None = None): - import_dict = loads(json_import, cls=BlueprintJSONDecoder) + import_dict = load(json_import, BlueprintJSONDecoder) try: _import = from_dict( Blueprint, import_dict, config=Config(cast=[BlueprintEntryDesiredState]) diff --git a/authentik/blueprints/v1/json_parser.py b/authentik/blueprints/v1/json_parser.py index 38685ae4f..1ca670e65 100644 --- a/authentik/blueprints/v1/json_parser.py +++ b/authentik/blueprints/v1/json_parser.py @@ -1,9 +1,9 @@ """Blueprint JSON decoder""" -from rest_framework.parsers import JSONParser -from json import JSONDecoder +from collections.abc import Hashable from typing import Any -from yaml import ScalarNode, SequenceNode +from rest_framework.parsers import JSONParser +from yaml.nodes import MappingNode from authentik.blueprints.v1.common import BlueprintLoader, YAMLTag, yaml_key_map @@ -11,17 +11,46 @@ TAG_KEY = "goauthentik.io/yaml-key" ARGS_KEY = "args" -class BlueprintJSONDecoder(JSONDecoder): - """Blueprint JSON decoder, allows using tag logic - when using JSON data (e.g. through the API)""" +class BlueprintJSONDecoder(BlueprintLoader): + """Blueprint JSON decoder, allows using tag logic when using JSON data (e.g. through the API, + when YAML tags are not available). + + This is still based on a YAML Loader, since all the YAML Tag constructors expect *Node objects + from YAML, this makes things a lot easier.""" - dummy_loader: BlueprintLoader tag_map: dict[str, type[YAMLTag]] def __init__(self, *args, **kwargs): - super().__init__(*args, object_hook=self.object_hook, **kwargs) - self.dummy_loader = BlueprintLoader("") + super().__init__(*args, **kwargs) self.tag_map = yaml_key_map() + self.add_constructor("tag:yaml.org,2002:map", BlueprintJSONDecoder.construct_yaml_map) + + def construct_yaml_map(self, node): + """The original construct_yaml_map creates a dict, yields it, then updates it, + which is probably some sort of performance optimisation, however it breaks here + when we don't return a dict from the `construct_mapping` function""" + value = self.construct_mapping(node) + yield value + + def construct_mapping(self, node: MappingNode, deep: bool = False) -> dict[Hashable, Any]: + """Check if the mapping has a special key and create an in-place YAML tag for it, + and return that instead of the actual dict""" + parsed = super().construct_mapping(node, deep=deep) + if TAG_KEY not in parsed: + return parsed + tag_cls = self.parse_yaml_tag(parsed) + if not tag_cls: + return parsed + # MappingNode's value is a list of tuples where the tuples + # consist of (KeyNode, ValueNode) + # so this filters out the value node for `args` + raw_args_pair = [x for x in node.value if x[0].value == ARGS_KEY] + if len(raw_args_pair) < 1: + return parsed + # Get the value of the first Node in the pair we get from above + # where the value isn't `args`, i.e. the actual argument data + raw_args_data = [x for x in raw_args_pair[0] if x.value != ARGS_KEY][0] + return tag_cls(self, raw_args_data) def parse_yaml_tag(self, data: dict) -> YAMLTag | None: """parse the tag""" @@ -31,30 +60,6 @@ class BlueprintJSONDecoder(JSONDecoder): return None return tag_cls - def parse_yaml_tag_args(self, data: Any) -> Any: - """Parse args into their yaml equivalent""" - if data: - if isinstance(data, list): - return SequenceNode( - "tag:yaml.org,2002:seq", [self.parse_yaml_tag_args(x) for x in data] - ) - if isinstance(data, str): - return ScalarNode("tag:yaml.org,2002:str", data) - if isinstance(data, int): - return ScalarNode("tag:yaml.org,2002:int", data) - if isinstance(data, float): - return ScalarNode("tag:yaml.org,2002:float", data) - return None - - def object_hook(self, data: dict) -> dict | Any: - if TAG_KEY not in data: - return data - tag_cls = self.parse_yaml_tag(data) - if not tag_cls: - return data - tag_args = self.parse_yaml_tag_args(data.get(ARGS_KEY, [])) - return tag_cls(self.dummy_loader, tag_args) - class BlueprintJSONParser(JSONParser): """Wrapper around the rest_framework JSON parser that uses the `BlueprintJSONDecoder`"""