providers/oauth2: always test JWT keys in tests

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-04-27 14:07:04 +02:00
parent 972471ce79
commit fe28d216fe
4 changed files with 51 additions and 28 deletions

View file

@ -1,8 +1,7 @@
"""Test authorize view""" """Test authorize view"""
from django.test import RequestFactory, TestCase from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
from django.utils.encoding import force_str from django.utils.encoding import force_str
from jwt import decode
from authentik.core.models import Application, User from authentik.core.models import Application, User
from authentik.flows.challenge import ChallengeTypes from authentik.flows.challenge import ChallengeTypes
@ -22,10 +21,11 @@ from authentik.providers.oauth2.models import (
OAuth2Provider, OAuth2Provider,
RefreshToken, RefreshToken,
) )
from authentik.providers.oauth2.tests.utils import OAuthTestCase
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
class TestAuthorize(TestCase): class TestAuthorize(OAuthTestCase):
"""Test authorize view""" """Test authorize view"""
def setUp(self) -> None: def setUp(self) -> None:
@ -238,23 +238,4 @@ class TestAuthorize(TestCase):
), ),
}, },
) )
jwt = decode( self.validate_jwt(token, provider)
token.access_token,
provider.client_secret,
algorithms=[provider.jwt_alg],
audience=provider.client_id,
)
self.assertIsNotNone(jwt["exp"])
self.assertIsNotNone(jwt["iat"])
self.assertIsNotNone(jwt["auth_time"])
self.assertIsNotNone(jwt["acr"])
self.assertIsNotNone(jwt["sub"])
self.assertIsNotNone(jwt["iss"])
# Check id_token
id_token = token.id_token.to_dict()
self.assertIsNotNone(id_token["exp"])
self.assertIsNotNone(id_token["iat"])
self.assertIsNotNone(id_token["auth_time"])
self.assertIsNotNone(id_token["acr"])
self.assertIsNotNone(id_token["sub"])
self.assertIsNotNone(id_token["iss"])

View file

@ -1,11 +1,11 @@
"""Test token view""" """Test token view"""
from base64 import b64encode from base64 import b64encode
from django.test import RequestFactory, TestCase from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
from django.utils.encoding import force_str from django.utils.encoding import force_str
from authentik.core.models import User from authentik.core.models import Application, User
from authentik.flows.models import Flow from authentik.flows.models import Flow
from authentik.providers.oauth2.constants import ( from authentik.providers.oauth2.constants import (
GRANT_TYPE_AUTHORIZATION_CODE, GRANT_TYPE_AUTHORIZATION_CODE,
@ -20,15 +20,17 @@ from authentik.providers.oauth2.models import (
OAuth2Provider, OAuth2Provider,
RefreshToken, RefreshToken,
) )
from authentik.providers.oauth2.tests.utils import OAuthTestCase
from authentik.providers.oauth2.views.token import TokenParams from authentik.providers.oauth2.views.token import TokenParams
class TestToken(TestCase): class TestToken(OAuthTestCase):
"""Test token view""" """Test token view"""
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self.factory = RequestFactory() self.factory = RequestFactory()
self.app = Application.objects.create(name="test", slug="test")
def test_request_auth_code(self): def test_request_auth_code(self):
"""test request param""" """test request param"""
@ -97,12 +99,15 @@ class TestToken(TestCase):
authorization_flow=Flow.objects.first(), authorization_flow=Flow.objects.first(),
redirect_uris="http://local.invalid", redirect_uris="http://local.invalid",
) )
# Needs to be assigned to an application for iss to be set
self.app.provider = provider
self.app.save()
header = b64encode( header = b64encode(
f"{provider.client_id}:{provider.client_secret}".encode() f"{provider.client_id}:{provider.client_secret}".encode()
).decode() ).decode()
user = User.objects.get(username="akadmin") user = User.objects.get(username="akadmin")
code = AuthorizationCode.objects.create( code = AuthorizationCode.objects.create(
code="foobar", provider=provider, user=user code="foobar", provider=provider, user=user, is_open_id=True
) )
response = self.client.post( response = self.client.post(
reverse("authentik_providers_oauth2:token"), reverse("authentik_providers_oauth2:token"),
@ -126,6 +131,7 @@ class TestToken(TestCase):
), ),
}, },
) )
self.validate_jwt(new_token, provider)
def test_refresh_token_view(self): def test_refresh_token_view(self):
"""test request param""" """test request param"""
@ -136,6 +142,9 @@ class TestToken(TestCase):
authorization_flow=Flow.objects.first(), authorization_flow=Flow.objects.first(),
redirect_uris="http://local.invalid", redirect_uris="http://local.invalid",
) )
# Needs to be assigned to an application for iss to be set
self.app.provider = provider
self.app.save()
header = b64encode( header = b64encode(
f"{provider.client_id}:{provider.client_secret}".encode() f"{provider.client_id}:{provider.client_secret}".encode()
).decode() ).decode()
@ -174,6 +183,7 @@ class TestToken(TestCase):
), ),
}, },
) )
self.validate_jwt(new_token, provider)
def test_refresh_token_view_invalid_origin(self): def test_refresh_token_view_invalid_origin(self):
"""test request param""" """test request param"""

View file

@ -0,0 +1,31 @@
"""OAuth test helpers"""
from django.test import TestCase
from jwt import decode
from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken
class OAuthTestCase(TestCase):
"""OAuth test helpers"""
required_jwt_keys = [
"exp",
"iat",
"auth_time",
"acr",
"sub",
"iss",
]
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
"""Validate that all required fields are set"""
jwt = decode(
token.access_token,
provider.client_secret,
algorithms=[provider.jwt_alg],
audience=provider.client_id,
)
id_token = token.id_token.to_dict()
for key in self.required_jwt_keys:
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")

View file

@ -16,6 +16,7 @@ from authentik.providers.oauth2.constants import (
from authentik.providers.oauth2.errors import TokenError, UserAuthError from authentik.providers.oauth2.errors import TokenError, UserAuthError
from authentik.providers.oauth2.models import ( from authentik.providers.oauth2.models import (
AuthorizationCode, AuthorizationCode,
ClientTypes,
OAuth2Provider, OAuth2Provider,
RefreshToken, RefreshToken,
) )
@ -75,7 +76,7 @@ class TokenParams:
LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id) LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id)
raise TokenError("invalid_client") raise TokenError("invalid_client")
if self.provider.client_type == "confidential": if self.provider.client_type == ClientTypes.CONFIDENTIAL:
if self.provider.client_secret != self.client_secret: if self.provider.client_secret != self.client_secret:
LOGGER.warning( LOGGER.warning(
"Invalid client secret: client does not have secret", "Invalid client secret: client does not have secret",