providers/saml: import SAML Provider with all autogenerated mappings

This commit is contained in:
Jens Langhammer 2021-01-28 23:32:36 +01:00
parent 188ef0f58f
commit 239af7048a
3 changed files with 31 additions and 9 deletions

View File

@ -10,7 +10,12 @@ from lxml import etree # nosec
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.providers.saml.models import SAMLBindings, SAMLProvider from authentik.flows.models import Flow
from authentik.providers.saml.models import (
SAMLBindings,
SAMLPropertyMapping,
SAMLProvider,
)
from authentik.providers.saml.utils.encoding import PEM_FOOTER, PEM_HEADER from authentik.providers.saml.utils.encoding import PEM_FOOTER, PEM_HEADER
from authentik.sources.saml.processors.constants import ( from authentik.sources.saml.processors.constants import (
NS_MAP, NS_MAP,
@ -48,10 +53,13 @@ class ServiceProviderMetadata:
signing_keypair: Optional[CertificateKeyPair] = None signing_keypair: Optional[CertificateKeyPair] = None
def to_provider(self, name: str) -> SAMLProvider: def to_provider(self, name: str, authorization_flow: Flow) -> SAMLProvider:
"""Create a SAMLProvider instance from the details. `name` is required, """Create a SAMLProvider instance from the details. `name` is required,
as depending on the metadata CertificateKeypairs might have to be created.""" as depending on the metadata CertificateKeypairs might have to be created."""
provider = SAMLProvider(name=name) provider = SAMLProvider.objects.create(
name=name,
authorization_flow=authorization_flow,
)
provider.issuer = self.entity_id provider.issuer = self.entity_id
provider.sp_binding = self.acs_binding provider.sp_binding = self.acs_binding
provider.acs_url = self.acs_location provider.acs_url = self.acs_location
@ -63,6 +71,11 @@ class ServiceProviderMetadata:
provider.signing_kp = CertificateKeyPair.objects.exclude( provider.signing_kp = CertificateKeyPair.objects.exclude(
key_data__iexact="" key_data__iexact=""
).first() ).first()
# Set all auto-generated Property-mappings as defaults
# They should provide a sane default for most applications:
provider.property_mappings.set(
SAMLPropertyMapping.objects.filter(name__startswith="Autogenerated")
)
return provider return provider

View File

@ -3,7 +3,8 @@
from django.test import TestCase from django.test import TestCase
from authentik.providers.saml.models import SAMLBindings from authentik.flows.models import Flow
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping
from authentik.providers.saml.processors.metadata_parser import ( from authentik.providers.saml.processors.metadata_parser import (
ServiceProviderMetadataParser, ServiceProviderMetadataParser,
) )
@ -65,18 +66,25 @@ bHlUY7ytSUTowXA=
class TestServiceProviderMetadataParser(TestCase): class TestServiceProviderMetadataParser(TestCase):
"""Test ServiceProviderMetadataParser parsing and creation of SAML Provider""" """Test ServiceProviderMetadataParser parsing and creation of SAML Provider"""
def setUp(self) -> None:
self.flow = Flow.objects.first()
def test_simple(self): def test_simple(self):
"""Test simple metadata without Singing""" """Test simple metadata without Singing"""
metadata = ServiceProviderMetadataParser().parse(METADATA_SIMPLE) metadata = ServiceProviderMetadataParser().parse(METADATA_SIMPLE)
provider = metadata.to_provider("test") provider = metadata.to_provider("test", self.flow)
self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs") self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs")
self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata") self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata")
self.assertEqual(provider.sp_binding, SAMLBindings.POST) self.assertEqual(provider.sp_binding, SAMLBindings.POST)
self.assertQuerysetEqual(
provider.property_mappings,
SAMLPropertyMapping.objects.filter(name__startswith="Autogenerated"),
)
def test_with_signing_cert(self): def test_with_signing_cert(self):
"""Test Metadata with signing cert""" """Test Metadata with signing cert"""
metadata = ServiceProviderMetadataParser().parse(METADATA_CERT) metadata = ServiceProviderMetadataParser().parse(METADATA_CERT)
provider = metadata.to_provider("test") provider = metadata.to_provider("test", self.flow)
self.assertEqual( self.assertEqual(
provider.acs_url, "http://localhost:8080/apps/user_saml/saml/acs" provider.acs_url, "http://localhost:8080/apps/user_saml/saml/acs"
) )

View File

@ -269,9 +269,10 @@ class MetadataImportView(LoginRequiredMixin, FormView):
metadata = ServiceProviderMetadataParser().parse( metadata = ServiceProviderMetadataParser().parse(
form.cleaned_data["metadata"].read().decode() form.cleaned_data["metadata"].read().decode()
) )
provider = metadata.to_provider(form.cleaned_data["provider_name"]) metadata.to_provider(
provider.authorization_flow = form.cleaned_data["authorization_flow"] form.cleaned_data["provider_name"],
provider.save() form.cleaned_data["authorization_flow"],
)
messages.success(self.request, _("Successfully created Provider")) messages.success(self.request, _("Successfully created Provider"))
except ValueError as exc: except ValueError as exc:
LOGGER.warning(str(exc)) LOGGER.warning(str(exc))