This repository has been archived on 2024-05-31. You can view files and clone it, but cannot push or open issues or pull requests.
authentik/authentik/crypto/tests.py
Jens L 967a38b7ac
crypto: make name field unique to prevent double certs (#5406)
* crypto: make name field unique to prevent double certs

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2023-04-28 14:35:59 +03:00

261 lines
9.8 KiB
Python

"""Crypto tests"""
import datetime
from json import loads
from os import makedirs
from tempfile import TemporaryDirectory
from cryptography.x509.extensions import SubjectAlternativeName
from cryptography.x509.general_name import DNSName
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.api.used_by import DeleteAction
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
from authentik.crypto.api import CertificateKeyPairSerializer
from authentik.crypto.builder import CertificateBuilder
from authentik.crypto.models import CertificateKeyPair
from authentik.crypto.tasks import MANAGED_DISCOVERED, certificate_discovery
from authentik.lib.config import CONFIG
from authentik.lib.generators import generate_id, generate_key
from authentik.providers.oauth2.models import OAuth2Provider
class TestCrypto(APITestCase):
    """Test Crypto validation"""

    def test_model_private(self):
        """Test model private key"""
        # "foo" is placeholder (non-PEM) key material; the model is expected to
        # report no private key for it instead of raising.
        cert = CertificateKeyPair.objects.create(
            name="test",
            certificate_data="foo",
            key_data="foo",
        )
        self.assertIsNone(cert.private_key)

    def test_serializer(self):
        """Test API Validation

        A keypair's own certificate/key data must validate, while placeholder
        strings in place of certificate and key data must be rejected.
        """
        keypair = create_test_cert()
        self.assertTrue(
            CertificateKeyPairSerializer(
                instance=keypair,
                data={
                    "name": keypair.name,
                    "certificate_data": keypair.certificate_data,
                    "key_data": keypair.key_data,
                },
            ).is_valid()
        )
        self.assertFalse(
            CertificateKeyPairSerializer(
                instance=keypair,
                data={
                    "name": keypair.name,
                    "certificate_data": "test",
                    "key_data": "test",
                },
            ).is_valid()
        )

    def test_builder(self):
        """Test Builder"""
        name = generate_id()
        builder = CertificateBuilder(name)
        # Saving before build() has produced anything must fail.
        with self.assertRaises(ValueError):
            builder.save()
        builder.build(
            subject_alt_names=[],
            validity_days=3,
        )
        instance = builder.save()
        now = datetime.datetime.today()
        self.assertEqual(instance.name, name)
        # `now` is sampled after the certificate was built, so the remaining
        # validity is presumably just under 3 days; timedelta.days floors to 2.
        self.assertEqual((instance.certificate.not_valid_after - now).days, 2)

    def test_builder_api(self):
        """Test Builder (via API) with a comma-separated subject_alt_name"""
        self.client.force_login(create_test_admin_user())
        name = generate_id()
        self.client.post(
            reverse("authentik_api:certificatekeypair-generate"),
            data={"common_name": name, "subject_alt_name": "bar,baz", "validity_days": 3},
        )
        key = CertificateKeyPair.objects.filter(name=name).first()
        self.assertIsNotNone(key)
        # Look the SAN extension up by type rather than assuming it is the
        # first extension present on the certificate.
        ext: SubjectAlternativeName = key.certificate.extensions.get_extension_for_class(
            SubjectAlternativeName
        ).value
        self.assertIsInstance(ext, SubjectAlternativeName)
        self.assertIsInstance(ext[0], DNSName)
        self.assertEqual(ext[0].value, "bar")
        self.assertIsInstance(ext[1], DNSName)
        self.assertEqual(ext[1].value, "baz")

    def test_builder_api_empty_san(self):
        """Test Builder (via API) with an empty subject_alt_name"""
        self.client.force_login(create_test_admin_user())
        name = generate_id()
        self.client.post(
            reverse("authentik_api:certificatekeypair-generate"),
            data={"common_name": name, "subject_alt_name": "", "validity_days": 3},
        )
        key = CertificateKeyPair.objects.filter(name=name).first()
        self.assertIsNotNone(key)
        # No SAN entries means no extensions at all on the generated cert.
        self.assertEqual(len(key.certificate.extensions), 0)

    def test_builder_api_empty_san_multiple(self):
        """Test Builder (via API) with a subject_alt_name of only separators"""
        self.client.force_login(create_test_admin_user())
        name = generate_id()
        self.client.post(
            reverse("authentik_api:certificatekeypair-generate"),
            data={"common_name": name, "subject_alt_name": ", ", "validity_days": 3},
        )
        key = CertificateKeyPair.objects.filter(name=name).first()
        self.assertIsNotNone(key)
        # ", " contains no real names, so it must be treated like an empty SAN.
        self.assertEqual(len(key.certificate.extensions), 0)

    def test_builder_api_invalid(self):
        """Test Builder (via API) (invalid)"""
        self.client.force_login(create_test_admin_user())
        response = self.client.post(
            reverse("authentik_api:certificatekeypair-generate"),
            data={},
        )
        self.assertEqual(response.status_code, 400)

    def test_list(self):
        """Test API List"""
        cert = create_test_cert()
        self.client.force_login(create_test_admin_user())
        response = self.client.get(
            reverse(
                "authentik_api:certificatekeypair-list",
            )
            + f"?name={cert.name}"
        )
        self.assertEqual(200, response.status_code)
        body = loads(response.content.decode())
        # Fail with an assertion (not an IndexError) if the cert is missing.
        api_cert = next((x for x in body["results"] if x["name"] == cert.name), None)
        self.assertIsNotNone(api_cert)
        self.assertEqual(api_cert["fingerprint_sha1"], cert.fingerprint_sha1)
        self.assertEqual(api_cert["fingerprint_sha256"], cert.fingerprint_sha256)

    def test_list_without_details(self):
        """Test API List (no details)"""
        cert = create_test_cert()
        self.client.force_login(create_test_admin_user())
        response = self.client.get(
            reverse(
                "authentik_api:certificatekeypair-list",
            )
            + f"?name={cert.name}&include_details=false"
        )
        self.assertEqual(200, response.status_code)
        body = loads(response.content.decode())
        # Fail with an assertion (not an IndexError) if the cert is missing.
        api_cert = next((x for x in body["results"] if x["name"] == cert.name), None)
        self.assertIsNotNone(api_cert)
        # With include_details=false the fingerprints are returned as null.
        self.assertEqual(api_cert["fingerprint_sha1"], None)
        self.assertEqual(api_cert["fingerprint_sha256"], None)

    def test_certificate_download(self):
        """Test certificate export (download)"""
        self.client.force_login(create_test_admin_user())
        keypair = create_test_cert()
        response = self.client.get(
            reverse(
                "authentik_api:certificatekeypair-view-certificate",
                kwargs={"pk": keypair.pk},
            )
        )
        self.assertEqual(200, response.status_code)
        # The "?download" flag must add a Content-Disposition header.
        response = self.client.get(
            reverse(
                "authentik_api:certificatekeypair-view-certificate",
                kwargs={"pk": keypair.pk},
            )
            + "?download",
        )
        self.assertEqual(200, response.status_code)
        self.assertIn("Content-Disposition", response)

    def test_private_key_download(self):
        """Test private_key export (download)"""
        self.client.force_login(create_test_admin_user())
        keypair = create_test_cert()
        response = self.client.get(
            reverse(
                "authentik_api:certificatekeypair-view-private-key",
                kwargs={"pk": keypair.pk},
            )
        )
        self.assertEqual(200, response.status_code)
        # The "?download" flag must add a Content-Disposition header.
        response = self.client.get(
            reverse(
                "authentik_api:certificatekeypair-view-private-key",
                kwargs={"pk": keypair.pk},
            )
            + "?download",
        )
        self.assertEqual(200, response.status_code)
        self.assertIn("Content-Disposition", response)

    def test_used_by(self):
        """Test used_by endpoint"""
        self.client.force_login(create_test_admin_user())
        keypair = create_test_cert()
        # An OAuth2 provider signing with the keypair is the single expected user.
        provider = OAuth2Provider.objects.create(
            name="test",
            client_id="test",
            client_secret=generate_key(),
            authorization_flow=create_test_flow(),
            redirect_uris="http://localhost",
            signing_key=keypair,
        )
        response = self.client.get(
            reverse(
                "authentik_api:certificatekeypair-used-by",
                kwargs={"pk": keypair.pk},
            )
        )
        self.assertEqual(200, response.status_code)
        self.assertJSONEqual(
            response.content.decode(),
            [
                {
                    "app": "authentik_providers_oauth2",
                    "model_name": "oauth2provider",
                    "pk": str(provider.pk),
                    "name": str(provider),
                    "action": DeleteAction.SET_NULL.name,
                }
            ],
        )

    def test_discovery(self):
        """Test certificate discovery"""
        name = generate_id()
        builder = CertificateBuilder(name)
        with self.assertRaises(ValueError):
            builder.save()
        builder.build(
            subject_alt_names=[],
            validity_days=3,
        )
        with TemporaryDirectory() as temp_dir:
            # Two discovery sources: a flat "foo.pem"/"foo.key" pair, and a
            # "foo.bar" directory containing fullchain.pem/privkey.pem.
            with open(f"{temp_dir}/foo.pem", "w+", encoding="utf-8") as _cert:
                _cert.write(builder.certificate)
            with open(f"{temp_dir}/foo.key", "w+", encoding="utf-8") as _key:
                _key.write(builder.private_key)
            makedirs(f"{temp_dir}/foo.bar", exist_ok=True)
            with open(f"{temp_dir}/foo.bar/fullchain.pem", "w+", encoding="utf-8") as _cert:
                _cert.write(builder.certificate)
            with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key:
                _key.write(builder.private_key)
            with CONFIG.patch("cert_discovery_dir", temp_dir):
                certificate_discovery()  # pylint: disable=no-value-for-parameter
        # Each source yields a managed keypair keyed by file stem / directory name.
        keypair: CertificateKeyPair = CertificateKeyPair.objects.filter(
            managed=MANAGED_DISCOVERED % "foo"
        ).first()
        self.assertIsNotNone(keypair)
        self.assertIsNotNone(keypair.certificate)
        self.assertIsNotNone(keypair.private_key)
        self.assertTrue(
            CertificateKeyPair.objects.filter(managed=MANAGED_DISCOVERED % "foo.bar").exists()
        )