providers/oauth2: always test JWT keys in tests
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
972471ce79
commit
fe28d216fe
|
@ -1,8 +1,7 @@
|
||||||
"""Test authorize view"""
|
"""Test authorize view"""
|
||||||
from django.test import RequestFactory, TestCase
|
from django.test import RequestFactory
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
from jwt import decode
|
|
||||||
|
|
||||||
from authentik.core.models import Application, User
|
from authentik.core.models import Application, User
|
||||||
from authentik.flows.challenge import ChallengeTypes
|
from authentik.flows.challenge import ChallengeTypes
|
||||||
|
@ -22,10 +21,11 @@ from authentik.providers.oauth2.models import (
|
||||||
OAuth2Provider,
|
OAuth2Provider,
|
||||||
RefreshToken,
|
RefreshToken,
|
||||||
)
|
)
|
||||||
|
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||||
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
|
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
|
||||||
|
|
||||||
|
|
||||||
class TestAuthorize(TestCase):
|
class TestAuthorize(OAuthTestCase):
|
||||||
"""Test authorize view"""
|
"""Test authorize view"""
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
|
@ -238,23 +238,4 @@ class TestAuthorize(TestCase):
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
jwt = decode(
|
self.validate_jwt(token, provider)
|
||||||
token.access_token,
|
|
||||||
provider.client_secret,
|
|
||||||
algorithms=[provider.jwt_alg],
|
|
||||||
audience=provider.client_id,
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(jwt["exp"])
|
|
||||||
self.assertIsNotNone(jwt["iat"])
|
|
||||||
self.assertIsNotNone(jwt["auth_time"])
|
|
||||||
self.assertIsNotNone(jwt["acr"])
|
|
||||||
self.assertIsNotNone(jwt["sub"])
|
|
||||||
self.assertIsNotNone(jwt["iss"])
|
|
||||||
# Check id_token
|
|
||||||
id_token = token.id_token.to_dict()
|
|
||||||
self.assertIsNotNone(id_token["exp"])
|
|
||||||
self.assertIsNotNone(id_token["iat"])
|
|
||||||
self.assertIsNotNone(id_token["auth_time"])
|
|
||||||
self.assertIsNotNone(id_token["acr"])
|
|
||||||
self.assertIsNotNone(id_token["sub"])
|
|
||||||
self.assertIsNotNone(id_token["iss"])
|
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
"""Test token view"""
|
"""Test token view"""
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
|
|
||||||
from django.test import RequestFactory, TestCase
|
from django.test import RequestFactory
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
|
|
||||||
from authentik.core.models import User
|
from authentik.core.models import Application, User
|
||||||
from authentik.flows.models import Flow
|
from authentik.flows.models import Flow
|
||||||
from authentik.providers.oauth2.constants import (
|
from authentik.providers.oauth2.constants import (
|
||||||
GRANT_TYPE_AUTHORIZATION_CODE,
|
GRANT_TYPE_AUTHORIZATION_CODE,
|
||||||
|
@ -20,15 +20,17 @@ from authentik.providers.oauth2.models import (
|
||||||
OAuth2Provider,
|
OAuth2Provider,
|
||||||
RefreshToken,
|
RefreshToken,
|
||||||
)
|
)
|
||||||
|
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||||
from authentik.providers.oauth2.views.token import TokenParams
|
from authentik.providers.oauth2.views.token import TokenParams
|
||||||
|
|
||||||
|
|
||||||
class TestToken(TestCase):
|
class TestToken(OAuthTestCase):
|
||||||
"""Test token view"""
|
"""Test token view"""
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.factory = RequestFactory()
|
self.factory = RequestFactory()
|
||||||
|
self.app = Application.objects.create(name="test", slug="test")
|
||||||
|
|
||||||
def test_request_auth_code(self):
|
def test_request_auth_code(self):
|
||||||
"""test request param"""
|
"""test request param"""
|
||||||
|
@ -97,12 +99,15 @@ class TestToken(TestCase):
|
||||||
authorization_flow=Flow.objects.first(),
|
authorization_flow=Flow.objects.first(),
|
||||||
redirect_uris="http://local.invalid",
|
redirect_uris="http://local.invalid",
|
||||||
)
|
)
|
||||||
|
# Needs to be assigned to an application for iss to be set
|
||||||
|
self.app.provider = provider
|
||||||
|
self.app.save()
|
||||||
header = b64encode(
|
header = b64encode(
|
||||||
f"{provider.client_id}:{provider.client_secret}".encode()
|
f"{provider.client_id}:{provider.client_secret}".encode()
|
||||||
).decode()
|
).decode()
|
||||||
user = User.objects.get(username="akadmin")
|
user = User.objects.get(username="akadmin")
|
||||||
code = AuthorizationCode.objects.create(
|
code = AuthorizationCode.objects.create(
|
||||||
code="foobar", provider=provider, user=user
|
code="foobar", provider=provider, user=user, is_open_id=True
|
||||||
)
|
)
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
reverse("authentik_providers_oauth2:token"),
|
reverse("authentik_providers_oauth2:token"),
|
||||||
|
@ -126,6 +131,7 @@ class TestToken(TestCase):
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_jwt(new_token, provider)
|
||||||
|
|
||||||
def test_refresh_token_view(self):
|
def test_refresh_token_view(self):
|
||||||
"""test request param"""
|
"""test request param"""
|
||||||
|
@ -136,6 +142,9 @@ class TestToken(TestCase):
|
||||||
authorization_flow=Flow.objects.first(),
|
authorization_flow=Flow.objects.first(),
|
||||||
redirect_uris="http://local.invalid",
|
redirect_uris="http://local.invalid",
|
||||||
)
|
)
|
||||||
|
# Needs to be assigned to an application for iss to be set
|
||||||
|
self.app.provider = provider
|
||||||
|
self.app.save()
|
||||||
header = b64encode(
|
header = b64encode(
|
||||||
f"{provider.client_id}:{provider.client_secret}".encode()
|
f"{provider.client_id}:{provider.client_secret}".encode()
|
||||||
).decode()
|
).decode()
|
||||||
|
@ -174,6 +183,7 @@ class TestToken(TestCase):
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_jwt(new_token, provider)
|
||||||
|
|
||||||
def test_refresh_token_view_invalid_origin(self):
|
def test_refresh_token_view_invalid_origin(self):
|
||||||
"""test request param"""
|
"""test request param"""
|
||||||
|
|
31
authentik/providers/oauth2/tests/utils.py
Normal file
31
authentik/providers/oauth2/tests/utils.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
"""OAuth test helpers"""
|
||||||
|
from django.test import TestCase
|
||||||
|
from jwt import decode
|
||||||
|
|
||||||
|
from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthTestCase(TestCase):
|
||||||
|
"""OAuth test helpers"""
|
||||||
|
|
||||||
|
required_jwt_keys = [
|
||||||
|
"exp",
|
||||||
|
"iat",
|
||||||
|
"auth_time",
|
||||||
|
"acr",
|
||||||
|
"sub",
|
||||||
|
"iss",
|
||||||
|
]
|
||||||
|
|
||||||
|
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
|
||||||
|
"""Validate that all required fields are set"""
|
||||||
|
jwt = decode(
|
||||||
|
token.access_token,
|
||||||
|
provider.client_secret,
|
||||||
|
algorithms=[provider.jwt_alg],
|
||||||
|
audience=provider.client_id,
|
||||||
|
)
|
||||||
|
id_token = token.id_token.to_dict()
|
||||||
|
for key in self.required_jwt_keys:
|
||||||
|
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
|
||||||
|
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")
|
|
@ -16,6 +16,7 @@ from authentik.providers.oauth2.constants import (
|
||||||
from authentik.providers.oauth2.errors import TokenError, UserAuthError
|
from authentik.providers.oauth2.errors import TokenError, UserAuthError
|
||||||
from authentik.providers.oauth2.models import (
|
from authentik.providers.oauth2.models import (
|
||||||
AuthorizationCode,
|
AuthorizationCode,
|
||||||
|
ClientTypes,
|
||||||
OAuth2Provider,
|
OAuth2Provider,
|
||||||
RefreshToken,
|
RefreshToken,
|
||||||
)
|
)
|
||||||
|
@ -75,7 +76,7 @@ class TokenParams:
|
||||||
LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id)
|
LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id)
|
||||||
raise TokenError("invalid_client")
|
raise TokenError("invalid_client")
|
||||||
|
|
||||||
if self.provider.client_type == "confidential":
|
if self.provider.client_type == ClientTypes.CONFIDENTIAL:
|
||||||
if self.provider.client_secret != self.client_secret:
|
if self.provider.client_secret != self.client_secret:
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
"Invalid client secret: client does not have secret",
|
"Invalid client secret: client does not have secret",
|
||||||
|
|
Reference in a new issue