diff --git a/authentik/providers/oauth2/tests/test_authorize.py b/authentik/providers/oauth2/tests/test_authorize.py index 9407b13e9..da169ed4e 100644 --- a/authentik/providers/oauth2/tests/test_authorize.py +++ b/authentik/providers/oauth2/tests/test_authorize.py @@ -1,8 +1,7 @@ """Test authorize view""" -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from django.utils.encoding import force_str -from jwt import decode from authentik.core.models import Application, User from authentik.flows.challenge import ChallengeTypes @@ -22,10 +21,11 @@ from authentik.providers.oauth2.models import ( OAuth2Provider, RefreshToken, ) +from authentik.providers.oauth2.tests.utils import OAuthTestCase from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams -class TestAuthorize(TestCase): +class TestAuthorize(OAuthTestCase): """Test authorize view""" def setUp(self) -> None: @@ -238,23 +238,4 @@ class TestAuthorize(TestCase): ), }, ) - jwt = decode( - token.access_token, - provider.client_secret, - algorithms=[provider.jwt_alg], - audience=provider.client_id, - ) - self.assertIsNotNone(jwt["exp"]) - self.assertIsNotNone(jwt["iat"]) - self.assertIsNotNone(jwt["auth_time"]) - self.assertIsNotNone(jwt["acr"]) - self.assertIsNotNone(jwt["sub"]) - self.assertIsNotNone(jwt["iss"]) - # Check id_token - id_token = token.id_token.to_dict() - self.assertIsNotNone(id_token["exp"]) - self.assertIsNotNone(id_token["iat"]) - self.assertIsNotNone(id_token["auth_time"]) - self.assertIsNotNone(id_token["acr"]) - self.assertIsNotNone(id_token["sub"]) - self.assertIsNotNone(id_token["iss"]) + self.validate_jwt(token, provider) diff --git a/authentik/providers/oauth2/tests/test_token.py b/authentik/providers/oauth2/tests/test_token.py index dccd9c10c..fc9c39c67 100644 --- a/authentik/providers/oauth2/tests/test_token.py +++ b/authentik/providers/oauth2/tests/test_token.py @@ -1,11 +1,11 @@ """Test token view""" from base64 import b64encode -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from django.utils.encoding import force_str -from authentik.core.models import User +from authentik.core.models import Application, User from authentik.flows.models import Flow from authentik.providers.oauth2.constants import ( GRANT_TYPE_AUTHORIZATION_CODE, @@ -20,15 +20,17 @@ from authentik.providers.oauth2.models import ( OAuth2Provider, RefreshToken, ) +from authentik.providers.oauth2.tests.utils import OAuthTestCase from authentik.providers.oauth2.views.token import TokenParams -class TestToken(TestCase): +class TestToken(OAuthTestCase): """Test token view""" def setUp(self) -> None: super().setUp() self.factory = RequestFactory() + self.app = Application.objects.create(name="test", slug="test") def test_request_auth_code(self): """test request param""" @@ -97,12 +99,15 @@ class TestToken(TestCase): authorization_flow=Flow.objects.first(), redirect_uris="http://local.invalid", ) + # Needs to be assigned to an application for iss to be set + self.app.provider = provider + self.app.save() header = b64encode( f"{provider.client_id}:{provider.client_secret}".encode() ).decode() user = User.objects.get(username="akadmin") code = AuthorizationCode.objects.create( - code="foobar", provider=provider, user=user + code="foobar", provider=provider, user=user, is_open_id=True ) response = self.client.post( reverse("authentik_providers_oauth2:token"), @@ -126,6 +131,7 @@ class TestToken(TestCase): ), }, ) + self.validate_jwt(new_token, provider) def test_refresh_token_view(self): """test request param""" @@ -136,6 +142,9 @@ class TestToken(TestCase): authorization_flow=Flow.objects.first(), redirect_uris="http://local.invalid", ) + # Needs to be assigned to an application for iss to be set + self.app.provider = provider + self.app.save() header = b64encode( f"{provider.client_id}:{provider.client_secret}".encode() ).decode() @@ -174,6 +183,7 @@ class TestToken(TestCase): ), }, ) + self.validate_jwt(new_token, provider) def test_refresh_token_view_invalid_origin(self): """test request param""" diff --git a/authentik/providers/oauth2/tests/utils.py b/authentik/providers/oauth2/tests/utils.py new file mode 100644 index 000000000..0f1264ebf --- /dev/null +++ b/authentik/providers/oauth2/tests/utils.py @@ -0,0 +1,31 @@ +"""OAuth test helpers""" +from django.test import TestCase +from jwt import decode + +from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken + + +class OAuthTestCase(TestCase): + """OAuth test helpers""" + + required_jwt_keys = [ + "exp", + "iat", + "auth_time", + "acr", + "sub", + "iss", + ] + + def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider): + """Validate that all required fields are set""" + jwt = decode( + token.access_token, + provider.client_secret, + algorithms=[provider.jwt_alg], + audience=provider.client_id, + ) + id_token = token.id_token.to_dict() + for key in self.required_jwt_keys: + self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token") + self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token") diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index dc71beec9..c0d85a345 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -16,6 +16,7 @@ from authentik.providers.oauth2.constants import ( from authentik.providers.oauth2.errors import TokenError, UserAuthError from authentik.providers.oauth2.models import ( AuthorizationCode, + ClientTypes, OAuth2Provider, RefreshToken, ) @@ -75,7 +76,7 @@ class TokenParams: LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id) raise TokenError("invalid_client") - if self.provider.client_type == "confidential": + if self.provider.client_type == ClientTypes.CONFIDENTIAL: if self.provider.client_secret != self.client_secret: LOGGER.warning( "Invalid client secret: client does not have secret",