diff --git a/authentik/providers/saml/api.py b/authentik/providers/saml/api.py index ba2ff33e6..510f2a159 100644 --- a/authentik/providers/saml/api.py +++ b/authentik/providers/saml/api.py @@ -2,7 +2,8 @@ from xml.etree.ElementTree import ParseError # nosec from defusedxml.ElementTree import fromstring -from django.http.response import HttpResponse +from django.http.response import Http404, HttpResponse +from django.shortcuts import get_object_or_404 from django.urls import reverse from django.utils.translation import gettext_lazy as _ from drf_spectacular.types import OpenApiTypes @@ -114,7 +115,11 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet): # pylint: disable=invalid-name, unused-argument def metadata(self, request: Request, pk: int) -> Response: """Return metadata as XML string""" - provider = self.get_object() + # We don't use self.get_object() on purpose as this view is un-authenticated + try: + provider = get_object_or_404(SAMLProvider, pk=pk) + except ValueError: + raise Http404 try: metadata = MetadataProcessor(provider, request).build_entity_descriptor() if "download" in request._request.GET: diff --git a/authentik/providers/saml/tests/test_api.py b/authentik/providers/saml/tests/test_api.py index 3e34edacc..ad7b0dacd 100644 --- a/authentik/providers/saml/tests/test_api.py +++ b/authentik/providers/saml/tests/test_api.py @@ -20,6 +20,7 @@ class TestSAMLProviderAPI(APITestCase): def test_metadata(self): """Test metadata export (normal)""" + self.client.logout() provider = SAMLProvider.objects.create( name="test", authorization_flow=Flow.objects.get( @@ -34,6 +35,7 @@ class TestSAMLProviderAPI(APITestCase): def test_metadata_download(self): """Test metadata export (download)""" + self.client.logout() provider = SAMLProvider.objects.create( name="test", authorization_flow=Flow.objects.get( @@ -50,6 +52,7 @@ class TestSAMLProviderAPI(APITestCase): def test_metadata_invalid(self): """Test metadata export (invalid)""" + self.client.logout() # Provider without application provider = SAMLProvider.objects.create( name="test", @@ -61,6 +64,10 @@ class TestSAMLProviderAPI(APITestCase): reverse("authentik_api:samlprovider-metadata", kwargs={"pk": provider.pk}), ) self.assertEqual(200, response.status_code) + response = self.client.get( + reverse("authentik_api:samlprovider-metadata", kwargs={"pk": "abc"}), + ) + self.assertEqual(404, response.status_code) def test_import_success(self): """Test metadata import (success case)"""