providers/saml: import SAML Provider with all autogenerated mappings
This commit is contained in:
parent
188ef0f58f
commit
239af7048a
|
@ -10,7 +10,12 @@ from lxml import etree # nosec
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.crypto.models import CertificateKeyPair
|
from authentik.crypto.models import CertificateKeyPair
|
||||||
from authentik.providers.saml.models import SAMLBindings, SAMLProvider
|
from authentik.flows.models import Flow
|
||||||
|
from authentik.providers.saml.models import (
|
||||||
|
SAMLBindings,
|
||||||
|
SAMLPropertyMapping,
|
||||||
|
SAMLProvider,
|
||||||
|
)
|
||||||
from authentik.providers.saml.utils.encoding import PEM_FOOTER, PEM_HEADER
|
from authentik.providers.saml.utils.encoding import PEM_FOOTER, PEM_HEADER
|
||||||
from authentik.sources.saml.processors.constants import (
|
from authentik.sources.saml.processors.constants import (
|
||||||
NS_MAP,
|
NS_MAP,
|
||||||
|
@ -48,10 +53,13 @@ class ServiceProviderMetadata:
|
||||||
|
|
||||||
signing_keypair: Optional[CertificateKeyPair] = None
|
signing_keypair: Optional[CertificateKeyPair] = None
|
||||||
|
|
||||||
def to_provider(self, name: str) -> SAMLProvider:
|
def to_provider(self, name: str, authorization_flow: Flow) -> SAMLProvider:
|
||||||
"""Create a SAMLProvider instance from the details. `name` is required,
|
"""Create a SAMLProvider instance from the details. `name` is required,
|
||||||
as depending on the metadata CertificateKeypairs might have to be created."""
|
as depending on the metadata CertificateKeypairs might have to be created."""
|
||||||
provider = SAMLProvider(name=name)
|
provider = SAMLProvider.objects.create(
|
||||||
|
name=name,
|
||||||
|
authorization_flow=authorization_flow,
|
||||||
|
)
|
||||||
provider.issuer = self.entity_id
|
provider.issuer = self.entity_id
|
||||||
provider.sp_binding = self.acs_binding
|
provider.sp_binding = self.acs_binding
|
||||||
provider.acs_url = self.acs_location
|
provider.acs_url = self.acs_location
|
||||||
|
@ -63,6 +71,11 @@ class ServiceProviderMetadata:
|
||||||
provider.signing_kp = CertificateKeyPair.objects.exclude(
|
provider.signing_kp = CertificateKeyPair.objects.exclude(
|
||||||
key_data__iexact=""
|
key_data__iexact=""
|
||||||
).first()
|
).first()
|
||||||
|
# Set all auto-generated Property-mappings as defaults
|
||||||
|
# They should provide a sane default for most applications:
|
||||||
|
provider.property_mappings.set(
|
||||||
|
SAMLPropertyMapping.objects.filter(name__startswith="Autogenerated")
|
||||||
|
)
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,8 @@
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from authentik.providers.saml.models import SAMLBindings
|
from authentik.flows.models import Flow
|
||||||
|
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping
|
||||||
from authentik.providers.saml.processors.metadata_parser import (
|
from authentik.providers.saml.processors.metadata_parser import (
|
||||||
ServiceProviderMetadataParser,
|
ServiceProviderMetadataParser,
|
||||||
)
|
)
|
||||||
|
@ -65,18 +66,25 @@ bHlUY7ytSUTowXA=
|
||||||
class TestServiceProviderMetadataParser(TestCase):
|
class TestServiceProviderMetadataParser(TestCase):
|
||||||
"""Test ServiceProviderMetadataParser parsing and creation of SAML Provider"""
|
"""Test ServiceProviderMetadataParser parsing and creation of SAML Provider"""
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.flow = Flow.objects.first()
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
"""Test simple metadata without Singing"""
|
"""Test simple metadata without Singing"""
|
||||||
metadata = ServiceProviderMetadataParser().parse(METADATA_SIMPLE)
|
metadata = ServiceProviderMetadataParser().parse(METADATA_SIMPLE)
|
||||||
provider = metadata.to_provider("test")
|
provider = metadata.to_provider("test", self.flow)
|
||||||
self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs")
|
self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs")
|
||||||
self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata")
|
self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata")
|
||||||
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
|
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
provider.property_mappings,
|
||||||
|
SAMLPropertyMapping.objects.filter(name__startswith="Autogenerated"),
|
||||||
|
)
|
||||||
|
|
||||||
def test_with_signing_cert(self):
|
def test_with_signing_cert(self):
|
||||||
"""Test Metadata with signing cert"""
|
"""Test Metadata with signing cert"""
|
||||||
metadata = ServiceProviderMetadataParser().parse(METADATA_CERT)
|
metadata = ServiceProviderMetadataParser().parse(METADATA_CERT)
|
||||||
provider = metadata.to_provider("test")
|
provider = metadata.to_provider("test", self.flow)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
provider.acs_url, "http://localhost:8080/apps/user_saml/saml/acs"
|
provider.acs_url, "http://localhost:8080/apps/user_saml/saml/acs"
|
||||||
)
|
)
|
||||||
|
|
|
@ -269,9 +269,10 @@ class MetadataImportView(LoginRequiredMixin, FormView):
|
||||||
metadata = ServiceProviderMetadataParser().parse(
|
metadata = ServiceProviderMetadataParser().parse(
|
||||||
form.cleaned_data["metadata"].read().decode()
|
form.cleaned_data["metadata"].read().decode()
|
||||||
)
|
)
|
||||||
provider = metadata.to_provider(form.cleaned_data["provider_name"])
|
metadata.to_provider(
|
||||||
provider.authorization_flow = form.cleaned_data["authorization_flow"]
|
form.cleaned_data["provider_name"],
|
||||||
provider.save()
|
form.cleaned_data["authorization_flow"],
|
||||||
|
)
|
||||||
messages.success(self.request, _("Successfully created Provider"))
|
messages.success(self.request, _("Successfully created Provider"))
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
LOGGER.warning(str(exc))
|
LOGGER.warning(str(exc))
|
||||||
|
|
Reference in New Issue