diff --git a/authentik/crypto/models.py b/authentik/crypto/models.py index 111adf079..037e8dffd 100644 --- a/authentik/crypto/models.py +++ b/authentik/crypto/models.py @@ -11,10 +11,13 @@ from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.x509 import Certificate, load_pem_x509_certificate from django.db import models from django.utils.translation import gettext_lazy as _ +from structlog.stdlib import get_logger from authentik.lib.models import CreatedUpdatedModel from authentik.managed.models import ManagedModel +LOGGER = get_logger() + class CertificateKeyPair(ManagedModel, CreatedUpdatedModel): """CertificateKeyPair that can be used for signing or encrypting if `key_data` @@ -62,7 +65,8 @@ class CertificateKeyPair(ManagedModel, CreatedUpdatedModel): password=None, backend=default_backend(), ) - except ValueError: + except ValueError as exc: + LOGGER.warning(exc) return None return self._private_key diff --git a/authentik/crypto/tasks.py b/authentik/crypto/tasks.py index 81e50219c..723a603ee 100644 --- a/authentik/crypto/tasks.py +++ b/authentik/crypto/tasks.py @@ -2,6 +2,9 @@ from glob import glob from pathlib import Path +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import load_pem_private_key +from cryptography.x509.base import load_pem_x509_certificate from django.utils.translation import gettext_lazy as _ from structlog.stdlib import get_logger @@ -20,6 +23,22 @@ LOGGER = get_logger() MANAGED_DISCOVERED = "goauthentik.io/crypto/discovered/%s" +def ensure_private_key_valid(body: str): + """Attempt loading of an RSA Private key without password""" + load_pem_private_key( + str.encode("\n".join([x.strip() for x in body.split("\n")])), + password=None, + backend=default_backend(), + ) + return body + + +def ensure_certificate_valid(body: str): + """Attempt loading of a PEM-encoded certificate""" + load_pem_x509_certificate(body.encode("utf-8"), default_backend()) + return body + + @CELERY_APP.task(bind=True, base=MonitoredTask) @prefill_task def certificate_discovery(self: MonitoredTask): @@ -42,11 +61,11 @@ def certificate_discovery(self: MonitoredTask): with open(path, "r+", encoding="utf-8") as _file: body = _file.read() if "BEGIN RSA PRIVATE KEY" in body: - private_keys[cert_name] = body + private_keys[cert_name] = ensure_private_key_valid(body) else: - certs[cert_name] = body - except OSError as exc: - LOGGER.warning("Failed to open file", exc=exc, file=path) + certs[cert_name] = ensure_certificate_valid(body) + except (OSError, ValueError) as exc: + LOGGER.warning("Failed to open file or invalid format", exc=exc, file=path) discovered += 1 for name, cert_data in certs.items(): cert = CertificateKeyPair.objects.filter(managed=MANAGED_DISCOVERED % name).first()