diff --git a/authentik/core/tests/utils.py b/authentik/core/tests/utils.py index 74fbc048a..84ef0845e 100644 --- a/authentik/core/tests/utils.py +++ b/authentik/core/tests/utils.py @@ -47,11 +47,11 @@ def create_test_tenant() -> Tenant: def create_test_cert() -> CertificateKeyPair: """Generate a certificate for testing""" - CertificateKeyPair.objects.filter(name="goauthentik.io").delete() builder = CertificateBuilder() builder.common_name = "goauthentik.io" builder.build( subject_alt_names=["goauthentik.io"], validity_days=360, ) + builder.name = generate_id() return builder.save() diff --git a/authentik/crypto/builder.py b/authentik/crypto/builder.py index 2881750eb..1316d2555 100644 --- a/authentik/crypto/builder.py +++ b/authentik/crypto/builder.py @@ -53,10 +53,7 @@ class CertificateBuilder: .subject_name( x509.Name( [ - x509.NameAttribute( - NameOID.COMMON_NAME, - self.common_name, - ), + x509.NameAttribute(NameOID.COMMON_NAME, self.common_name), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"), x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"), ] @@ -65,10 +62,7 @@ class CertificateBuilder: .issuer_name( x509.Name( [ - x509.NameAttribute( - NameOID.COMMON_NAME, - f"authentik {__version__}", - ), + x509.NameAttribute(NameOID.COMMON_NAME, f"authentik {__version__}"), ] ) ) diff --git a/authentik/providers/oauth2/tests/test_authorize.py b/authentik/providers/oauth2/tests/test_authorize.py index a5b2a8c7d..9f0aebb4e 100644 --- a/authentik/providers/oauth2/tests/test_authorize.py +++ b/authentik/providers/oauth2/tests/test_authorize.py @@ -3,7 +3,7 @@ from django.test import RequestFactory from django.urls import reverse from authentik.core.models import Application -from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow +from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.flows.challenge import ChallengeTypes from authentik.lib.generators import generate_id, generate_key from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError @@ -39,7 +39,7 @@ class TestAuthorize(OAuthTestCase): def test_request(self): """test request param""" OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id="test", authorization_flow=create_test_flow(), redirect_uris="http://local.invalid/Foo", @@ -59,7 +59,7 @@ class TestAuthorize(OAuthTestCase): def test_invalid_redirect_uri(self): """test missing/invalid redirect URI""" OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id="test", authorization_flow=create_test_flow(), redirect_uris="http://local.invalid", @@ -78,10 +78,55 @@ class TestAuthorize(OAuthTestCase): ) OAuthAuthorizationParams.from_request(request) + def test_invalid_redirect_uri_empty(self): + """test missing/invalid redirect URI""" + provider = OAuth2Provider.objects.create( + name=generate_id(), + client_id="test", + authorization_flow=create_test_flow(), + redirect_uris="", + ) + with self.assertRaises(RedirectUriError): + request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) + OAuthAuthorizationParams.from_request(request) + request = self.factory.get( + "/", + data={ + "response_type": "code", + "client_id": "test", + "redirect_uri": "+", + }, + ) + OAuthAuthorizationParams.from_request(request) + provider.refresh_from_db() + self.assertEqual(provider.redirect_uris, "+") + def test_invalid_redirect_uri_regex(self): """test missing/invalid redirect URI""" OAuth2Provider.objects.create( - name="test", + name=generate_id(), + client_id="test", + authorization_flow=create_test_flow(), + redirect_uris="http://local.invalid?", + ) + with self.assertRaises(RedirectUriError): + request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) + OAuthAuthorizationParams.from_request(request) + with self.assertRaises(RedirectUriError): + request = self.factory.get( + "/", + data={ + "response_type": "code", + "client_id": "test", + "redirect_uri": "http://localhost", + }, + ) + OAuthAuthorizationParams.from_request(request) + + def test_redirect_uri_invalid_regex(self): + """test missing/invalid redirect URI (invalid regex)""" + OAuth2Provider.objects.create( + name=generate_id(), client_id="test", authorization_flow=create_test_flow(), redirect_uris="+", @@ -103,7 +148,7 @@ class TestAuthorize(OAuthTestCase): def test_empty_redirect_uri(self): """test empty redirect URI (configure in provider)""" OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id="test", authorization_flow=create_test_flow(), ) @@ -123,7 +168,7 @@ class TestAuthorize(OAuthTestCase): def test_response_type(self): """test response_type""" OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id="test", authorization_flow=create_test_flow(), redirect_uris="http://local.invalid/Foo", @@ -201,7 +246,7 @@ class TestAuthorize(OAuthTestCase): """Test full authorization""" flow = create_test_flow() provider = OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id="test", authorization_flow=flow, redirect_uris="foo://localhost", @@ -237,12 +282,12 @@ class TestAuthorize(OAuthTestCase): """Test full authorization""" flow = create_test_flow() provider = OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id="test", client_secret=generate_key(), authorization_flow=flow, redirect_uris="http://localhost", - signing_key=create_test_cert(), + signing_key=self.keypair, ) Application.objects.create(name="app", slug="app", provider=provider) state = generate_id() @@ -281,12 +326,12 @@ class TestAuthorize(OAuthTestCase): """Test full authorization (form_post response)""" flow = create_test_flow() provider = OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id="test", client_secret=generate_key(), authorization_flow=flow, redirect_uris="http://localhost", - signing_key=create_test_cert(), + signing_key=self.keypair, ) Application.objects.create(name="app", slug="app", provider=provider) state = generate_id() diff --git a/authentik/providers/oauth2/tests/test_token.py b/authentik/providers/oauth2/tests/test_token.py index bc1dba975..0da18abcc 100644 --- a/authentik/providers/oauth2/tests/test_token.py +++ b/authentik/providers/oauth2/tests/test_token.py @@ -5,7 +5,7 @@ from django.test import RequestFactory from django.urls import reverse from authentik.core.models import Application -from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow +from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.events.models import Event, EventAction from authentik.lib.generators import generate_id, generate_key from authentik.providers.oauth2.constants import ( @@ -24,17 +24,17 @@ class TestToken(OAuthTestCase): def setUp(self) -> None: super().setUp() self.factory = RequestFactory() - self.app = Application.objects.create(name="test", slug="test") + self.app = Application.objects.create(name=generate_id(), slug="test") def test_request_auth_code(self): """test request param""" provider = OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id=generate_id(), client_secret=generate_key(), authorization_flow=create_test_flow(), redirect_uris="http://testserver", - signing_key=create_test_cert(), + signing_key=self.keypair, ) header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() user = create_test_admin_user() @@ -56,12 +56,12 @@ class TestToken(OAuthTestCase): def test_request_auth_code_invalid(self): """test request param""" provider = OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id=generate_id(), client_secret=generate_key(), authorization_flow=create_test_flow(), redirect_uris="http://testserver", - signing_key=create_test_cert(), + signing_key=self.keypair, ) header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() request = self.factory.post( @@ -79,12 +79,12 @@ class TestToken(OAuthTestCase): def test_request_refresh_token(self): """test request param""" provider = OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id=generate_id(), client_secret=generate_key(), authorization_flow=create_test_flow(), redirect_uris="http://local.invalid", - signing_key=create_test_cert(), + signing_key=self.keypair, ) header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() user = create_test_admin_user() @@ -108,12 +108,12 @@ class TestToken(OAuthTestCase): def test_auth_code_view(self): """test request param""" provider = OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id=generate_id(), client_secret=generate_key(), authorization_flow=create_test_flow(), redirect_uris="http://local.invalid", - signing_key=create_test_cert(), + signing_key=self.keypair, ) # Needs to be assigned to an application for iss to be set self.app.provider = provider @@ -150,12 +150,12 @@ class TestToken(OAuthTestCase): def test_refresh_token_view(self): """test request param""" provider = OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id=generate_id(), client_secret=generate_key(), authorization_flow=create_test_flow(), redirect_uris="http://local.invalid", - signing_key=create_test_cert(), + signing_key=self.keypair, ) # Needs to be assigned to an application for iss to be set self.app.provider = provider @@ -199,12 +199,12 @@ class TestToken(OAuthTestCase): def test_refresh_token_view_invalid_origin(self): """test request param""" provider = OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id=generate_id(), client_secret=generate_key(), authorization_flow=create_test_flow(), redirect_uris="http://local.invalid", - signing_key=create_test_cert(), + signing_key=self.keypair, ) header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() user = create_test_admin_user() @@ -244,12 +244,12 @@ class TestToken(OAuthTestCase): def test_refresh_token_revoke(self): """test request param""" provider = OAuth2Provider.objects.create( - name="test", + name=generate_id(), client_id=generate_id(), client_secret=generate_key(), authorization_flow=create_test_flow(), redirect_uris="http://testserver", - signing_key=create_test_cert(), + signing_key=self.keypair, ) # Needs to be assigned to an application for iss to be set self.app.provider = provider diff --git a/authentik/providers/oauth2/tests/utils.py b/authentik/providers/oauth2/tests/utils.py index 6c36dee17..85c1dc848 100644 --- a/authentik/providers/oauth2/tests/utils.py +++ b/authentik/providers/oauth2/tests/utils.py @@ -2,12 +2,15 @@ from django.test import TestCase from jwt import decode +from authentik.core.tests.utils import create_test_cert +from authentik.crypto.models import CertificateKeyPair from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider, RefreshToken class OAuthTestCase(TestCase): """OAuth test helpers""" + keypair: CertificateKeyPair required_jwt_keys = [ "exp", "iat", @@ -17,6 +20,11 @@ class OAuthTestCase(TestCase): "iss", ] + @classmethod + def setUpClass(cls) -> None: + cls.keypair = create_test_cert() + super().setUpClass() + def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider): """Validate that all required fields are set""" key, alg = provider.get_jwt_key() diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index 634b1abc2..aeeb6d82a 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import timedelta from re import error as RegexError -from re import escape, fullmatch +from re import fullmatch from typing import Optional from urllib.parse import parse_qs, urlencode, urlparse, urlsplit, urlunsplit from uuid import uuid4 @@ -181,7 +181,7 @@ class OAuthAuthorizationParams: if self.provider.redirect_uris == "": LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri) - self.provider.redirect_uris = escape(self.redirect_uri) + self.provider.redirect_uris = self.redirect_uri self.provider.save() allowed_redirect_urls = self.provider.redirect_uris.split() @@ -194,14 +194,20 @@ class OAuthAuthorizationParams: try: if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls): LOGGER.warning( - "Invalid redirect uri", + "Invalid redirect uri (regex comparison)", redirect_uri=self.redirect_uri, expected=allowed_redirect_urls, ) raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) except RegexError as exc: - LOGGER.warning("Invalid regular expression configured", exc=exc) - raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) + LOGGER.info("Failed to parse regular expression, checking directly", exc=exc) + if not any(x == self.redirect_uri for x in allowed_redirect_urls): + LOGGER.warning( + "Invalid redirect uri (strict comparison)", + redirect_uri=self.redirect_uri, + expected=allowed_redirect_urls, + ) + raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) if self.request: raise AuthorizeError( self.redirect_uri, "request_not_supported", self.grant_type, self.state diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index d286da06b..a3c8735e2 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -154,7 +154,7 @@ class TokenParams: try: if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls): LOGGER.warning( - "Invalid redirect uri", + "Invalid redirect uri (regex comparison)", redirect_uri=self.redirect_uri, expected=allowed_redirect_urls, ) @@ -167,13 +167,19 @@ class TokenParams: ).from_http(request) raise TokenError("invalid_client") except RegexError as exc: - LOGGER.warning("Invalid regular expression configured", exc=exc) - Event.new( - EventAction.CONFIGURATION_ERROR, - message="Invalid redirect_uri RegEx configured", - provider=self.provider, - ).from_http(request) - raise TokenError("invalid_client") + LOGGER.info("Failed to parse regular expression, checking directly", exc=exc) + if not any(x == self.redirect_uri for x in allowed_redirect_urls): + LOGGER.warning( + "Invalid redirect uri (strict comparison)", + redirect_uri=self.redirect_uri, + expected=allowed_redirect_urls, + ) + Event.new( + EventAction.CONFIGURATION_ERROR, + message="Invalid redirect_uri configured", + provider=self.provider, + ).from_http(request) + raise TokenError("invalid_client") try: self.authorization_code = AuthorizationCode.objects.get(code=raw_code)