diff --git a/authentik/providers/saml/processors/metadata_parser.py b/authentik/providers/saml/processors/metadata_parser.py index e2168501d..cd7797ab2 100644 --- a/authentik/providers/saml/processors/metadata_parser.py +++ b/authentik/providers/saml/processors/metadata_parser.py @@ -10,7 +10,12 @@ from lxml import etree # nosec from structlog.stdlib import get_logger from authentik.crypto.models import CertificateKeyPair -from authentik.providers.saml.models import SAMLBindings, SAMLProvider +from authentik.flows.models import Flow +from authentik.providers.saml.models import ( + SAMLBindings, + SAMLPropertyMapping, + SAMLProvider, +) from authentik.providers.saml.utils.encoding import PEM_FOOTER, PEM_HEADER from authentik.sources.saml.processors.constants import ( NS_MAP, @@ -48,10 +53,13 @@ class ServiceProviderMetadata: signing_keypair: Optional[CertificateKeyPair] = None - def to_provider(self, name: str) -> SAMLProvider: + def to_provider(self, name: str, authorization_flow: Flow) -> SAMLProvider: """Create a SAMLProvider instance from the details. `name` is required, as depending on the metadata CertificateKeypairs might have to be created.""" - provider = SAMLProvider(name=name) + provider = SAMLProvider.objects.create( + name=name, + authorization_flow=authorization_flow, + ) provider.issuer = self.entity_id provider.sp_binding = self.acs_binding provider.acs_url = self.acs_location @@ -63,6 +71,11 @@ class ServiceProviderMetadata: provider.signing_kp = CertificateKeyPair.objects.exclude( key_data__iexact="" ).first() + # Set all auto-generated Property-mappings as defaults + # They should provide a sane default for most applications: + provider.property_mappings.set( + SAMLPropertyMapping.objects.filter(name__startswith="Autogenerated") + ) return provider diff --git a/authentik/providers/saml/tests/test_metadata.py b/authentik/providers/saml/tests/test_metadata.py index bb50901cb..396f3b463 100644 --- a/authentik/providers/saml/tests/test_metadata.py +++ b/authentik/providers/saml/tests/test_metadata.py @@ -3,7 +3,8 @@ from django.test import TestCase -from authentik.providers.saml.models import SAMLBindings +from authentik.flows.models import Flow +from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping from authentik.providers.saml.processors.metadata_parser import ( ServiceProviderMetadataParser, ) @@ -65,18 +66,25 @@ bHlUY7ytSUTowXA= class TestServiceProviderMetadataParser(TestCase): """Test ServiceProviderMetadataParser parsing and creation of SAML Provider""" + def setUp(self) -> None: + self.flow = Flow.objects.first() + def test_simple(self): """Test simple metadata without Singing""" metadata = ServiceProviderMetadataParser().parse(METADATA_SIMPLE) - provider = metadata.to_provider("test") + provider = metadata.to_provider("test", self.flow) self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs") self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata") self.assertEqual(provider.sp_binding, SAMLBindings.POST) + self.assertQuerysetEqual( + provider.property_mappings, + SAMLPropertyMapping.objects.filter(name__startswith="Autogenerated"), + ) def test_with_signing_cert(self): """Test Metadata with signing cert""" metadata = ServiceProviderMetadataParser().parse(METADATA_CERT) - provider = metadata.to_provider("test") + provider = metadata.to_provider("test", self.flow) self.assertEqual( provider.acs_url, "http://localhost:8080/apps/user_saml/saml/acs" ) diff --git a/authentik/providers/saml/views.py b/authentik/providers/saml/views.py index ad1f4f50f..98bb446f8 100644 --- a/authentik/providers/saml/views.py +++ b/authentik/providers/saml/views.py @@ -269,9 +269,10 @@ class MetadataImportView(LoginRequiredMixin, FormView): metadata = ServiceProviderMetadataParser().parse( form.cleaned_data["metadata"].read().decode() ) - provider = metadata.to_provider(form.cleaned_data["provider_name"]) - provider.authorization_flow = form.cleaned_data["authorization_flow"] - provider.save() + metadata.to_provider( + form.cleaned_data["provider_name"], + form.cleaned_data["authorization_flow"], + ) messages.success(self.request, _("Successfully created Provider")) except ValueError as exc: LOGGER.warning(str(exc))