allow fetching of blueprint instance content to return multiple contents

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-12-22 22:49:27 +01:00 committed by Jens Langhammer
parent 6b78190093
commit 9dfe06fb18
No known key found for this signature in database
7 changed files with 24 additions and 16 deletions

View file

@ -18,7 +18,7 @@ class Command(BaseCommand):
"""Apply all blueprints in order, abort when one fails to import""" """Apply all blueprints in order, abort when one fails to import"""
for blueprint_path in options.get("blueprints", []): for blueprint_path in options.get("blueprints", []):
content = BlueprintInstance(path=blueprint_path).retrieve() content = BlueprintInstance(path=blueprint_path).retrieve()
importer = Importer(content) importer = Importer(*content)
valid, logs = importer.validate() valid, logs = importer.validate()
if not valid: if not valid:
for log in logs: for log in logs:

View file

@ -70,7 +70,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
enabled = models.BooleanField(default=True) enabled = models.BooleanField(default=True)
managed_models = ArrayField(models.TextField(), default=list) managed_models = ArrayField(models.TextField(), default=list)
def retrieve_oci(self) -> str: def retrieve_oci(self) -> list[str]:
"""Get blueprint from an OCI registry""" """Get blueprint from an OCI registry"""
client = BlueprintOCIClient(self.path.replace("oci://", "https://")) client = BlueprintOCIClient(self.path.replace("oci://", "https://"))
try: try:
@ -79,16 +79,16 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
except OCIException as exc: except OCIException as exc:
raise BlueprintRetrievalFailed(exc) from exc raise BlueprintRetrievalFailed(exc) from exc
def retrieve_file(self) -> str: def retrieve_file(self) -> list[str]:
"""Get blueprint from path""" """Get blueprint from path"""
try: try:
full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(self.path)) full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(self.path))
with full_path.open("r", encoding="utf-8") as _file: with full_path.open("r", encoding="utf-8") as _file:
return _file.read() return [_file.read()]
except (IOError, OSError) as exc: except (IOError, OSError) as exc:
raise BlueprintRetrievalFailed(exc) from exc raise BlueprintRetrievalFailed(exc) from exc
def retrieve(self) -> str: def retrieve(self) -> list[str]:
"""Retrieve blueprint contents""" """Retrieve blueprint contents"""
if self.path.startswith("oci://"): if self.path.startswith("oci://"):
return self.retrieve_oci() return self.retrieve_oci()

View file

@ -21,7 +21,7 @@ def apply_blueprint(*files: str):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
for file in files: for file in files:
content = BlueprintInstance(path=file).retrieve() content = BlueprintInstance(path=file).retrieve()
Importer(content).apply() Importer(*content).apply()
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper

View file

@ -29,7 +29,7 @@ class TestBlueprintOCI(TransactionTestCase):
BlueprintInstance( BlueprintInstance(
path="oci://ghcr.io/goauthentik/blueprints/test:latest" path="oci://ghcr.io/goauthentik/blueprints/test:latest"
).retrieve(), ).retrieve(),
"foo", ["foo"],
) )
def test_manifests_error(self): def test_manifests_error(self):

View file

@ -25,7 +25,8 @@ def blueprint_tester(file_name: Path) -> Callable:
def tester(self: TestPackaged): def tester(self: TestPackaged):
base = Path("blueprints/") base = Path("blueprints/")
rel_path = Path(file_name).relative_to(base) rel_path = Path(file_name).relative_to(base)
importer = Importer(BlueprintInstance(path=str(rel_path)).retrieve()) contents = BlueprintInstance(path=str(rel_path)).retrieve()
importer = Importer(*contents)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())

View file

@ -75,22 +75,29 @@ class BlueprintOCIClient:
raise OCIException(manifest["errors"]) raise OCIException(manifest["errors"])
return manifest return manifest
def fetch_blobs(self, manifest: dict[str, Any]): def fetch_blobs(self, manifest: dict[str, Any]) -> list[str]:
"""Fetch blob based on manifest info""" """Fetch blob based on manifest info"""
blob = None blob_digests = []
for layer in manifest.get("layers", []): for layer in manifest.get("layers", []):
if layer.get("mediaType", "") == OCI_MEDIA_TYPE: if layer.get("mediaType", "") == OCI_MEDIA_TYPE:
blob = layer.get("digest") blob_digests.append(layer.get("digest"))
self.logger.debug("Found layer with matching media type", blob=blob) if not blob_digests:
if not blob:
raise OCIException("Blob not found") raise OCIException("Blob not found")
bodies = []
for blob in blob_digests:
bodies.append(self.fetch_blob(blob))
self.logger.debug("Fetched blobs", count=len(bodies))
return bodies
def fetch_blob(self, digest: str) -> str:
"""Fetch blob based on manifest info"""
blob_request = self.client.NewRequest( blob_request = self.client.NewRequest(
"GET", "GET",
"/v2/<name>/blobs/<digest>", "/v2/<name>/blobs/<digest>",
WithDigest(blob), WithDigest(digest),
) )
try: try:
self.logger.debug("Fetching blob", digest=digest)
blob_response = self.client.Do(blob_request) blob_response = self.client.Do(blob_request)
blob_response.raise_for_status() blob_response.raise_for_status()
return blob_response.text return blob_response.text

View file

@ -185,8 +185,8 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
if not instance or not instance.enabled: if not instance or not instance.enabled:
return return
blueprint_content = instance.retrieve() blueprint_content = instance.retrieve()
file_hash = sha512(blueprint_content.encode()).hexdigest() file_hash = sha512("".join(blueprint_content).encode()).hexdigest()
importer = Importer(blueprint_content, context=instance.context) importer = Importer(*blueprint_content, context=instance.context)
instance.metadata = importer.metadata instance.metadata = importer.metadata
valid, logs = importer.validate() valid, logs = importer.validate()
if not valid: if not valid: