providers/oauth2: fix null amr value not being removed from id_token
closes #4339 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
57400925a4
commit
4b93f40c5e
|
@ -39,6 +39,11 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_):
|
||||||
request.session[SESSION_LOGIN_EVENT] = event
|
request.session[SESSION_LOGIN_EVENT] = event
|
||||||
|
|
||||||
|
|
||||||
|
def get_login_event(request: HttpRequest) -> Optional[Event]:
|
||||||
|
"""Wrapper to get login event that can be mocked in tests"""
|
||||||
|
return request.session.get(SESSION_LOGIN_EVENT, None)
|
||||||
|
|
||||||
|
|
||||||
@receiver(user_logged_out)
|
@receiver(user_logged_out)
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
|
def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
|
||||||
|
|
|
@ -22,8 +22,7 @@ from rest_framework.serializers import Serializer
|
||||||
|
|
||||||
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
|
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
|
||||||
from authentik.crypto.models import CertificateKeyPair
|
from authentik.crypto.models import CertificateKeyPair
|
||||||
from authentik.events.models import Event
|
from authentik.events.signals import get_login_event
|
||||||
from authentik.events.signals import SESSION_LOGIN_EVENT
|
|
||||||
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
|
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
|
||||||
from authentik.lib.models import SerializerModel
|
from authentik.lib.models import SerializerModel
|
||||||
from authentik.lib.utils.time import timedelta_string_validator
|
from authentik.lib.utils.time import timedelta_string_validator
|
||||||
|
@ -419,6 +418,8 @@ class IDToken:
|
||||||
id_dict.pop("nonce")
|
id_dict.pop("nonce")
|
||||||
if not self.c_hash:
|
if not self.c_hash:
|
||||||
id_dict.pop("c_hash")
|
id_dict.pop("c_hash")
|
||||||
|
if not self.amr:
|
||||||
|
id_dict.pop("amr")
|
||||||
id_dict.pop("claims")
|
id_dict.pop("claims")
|
||||||
id_dict.update(self.claims)
|
id_dict.update(self.claims)
|
||||||
return id_dict
|
return id_dict
|
||||||
|
@ -503,8 +504,8 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
|
||||||
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
|
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
|
||||||
# Fallback in case we can't find any login events
|
# Fallback in case we can't find any login events
|
||||||
auth_time = now
|
auth_time = now
|
||||||
if SESSION_LOGIN_EVENT in request.session:
|
auth_event = get_login_event(request)
|
||||||
auth_event: Event = request.session[SESSION_LOGIN_EVENT]
|
if auth_event:
|
||||||
auth_time = auth_event.created
|
auth_time = auth_event.created
|
||||||
# Also check which method was used for authentication
|
# Also check which method was used for authentication
|
||||||
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
|
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
|
||||||
|
@ -526,6 +527,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
|
||||||
exp=exp_time,
|
exp=exp_time,
|
||||||
iat=iat_time,
|
iat=iat_time,
|
||||||
auth_time=auth_timestamp,
|
auth_time=auth_timestamp,
|
||||||
|
amr=amr if amr else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Include (or not) user standard claims in the id_token.
|
# Include (or not) user standard claims in the id_token.
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
"""Test authorize view"""
|
"""Test authorize view"""
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from django.test import RequestFactory
|
from django.test import RequestFactory
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
|
|
||||||
from authentik.core.models import Application
|
from authentik.core.models import Application
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||||
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.flows.challenge import ChallengeTypes
|
from authentik.flows.challenge import ChallengeTypes
|
||||||
from authentik.lib.generators import generate_id, generate_key
|
from authentik.lib.generators import generate_id, generate_key
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
|
@ -17,6 +20,7 @@ from authentik.providers.oauth2.models import (
|
||||||
)
|
)
|
||||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||||
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
|
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
|
||||||
|
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD
|
||||||
|
|
||||||
|
|
||||||
class TestAuthorize(OAuthTestCase):
|
class TestAuthorize(OAuthTestCase):
|
||||||
|
@ -302,40 +306,51 @@ class TestAuthorize(OAuthTestCase):
|
||||||
state = generate_id()
|
state = generate_id()
|
||||||
user = create_test_admin_user()
|
user = create_test_admin_user()
|
||||||
self.client.force_login(user)
|
self.client.force_login(user)
|
||||||
# Step 1, initiate params and get redirect to flow
|
with patch(
|
||||||
self.client.get(
|
"authentik.providers.oauth2.models.get_login_event",
|
||||||
reverse("authentik_providers_oauth2:authorize"),
|
MagicMock(
|
||||||
data={
|
return_value=Event(
|
||||||
"response_type": "id_token",
|
action=EventAction.LOGIN,
|
||||||
"client_id": "test",
|
context={PLAN_CONTEXT_METHOD: "password"},
|
||||||
"state": state,
|
created=now(),
|
||||||
"scope": "openid",
|
)
|
||||||
"redirect_uri": "http://localhost",
|
),
|
||||||
},
|
):
|
||||||
)
|
# Step 1, initiate params and get redirect to flow
|
||||||
response = self.client.get(
|
self.client.get(
|
||||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
reverse("authentik_providers_oauth2:authorize"),
|
||||||
)
|
data={
|
||||||
token: RefreshToken = RefreshToken.objects.filter(user=user).first()
|
"response_type": "id_token",
|
||||||
expires = timedelta_from_string(provider.access_code_validity).total_seconds()
|
"client_id": "test",
|
||||||
self.assertJSONEqual(
|
"state": state,
|
||||||
response.content.decode(),
|
"scope": "openid",
|
||||||
{
|
"redirect_uri": "http://localhost",
|
||||||
"component": "xak-flow-redirect",
|
},
|
||||||
"type": ChallengeTypes.REDIRECT.value,
|
)
|
||||||
"to": (
|
response = self.client.get(
|
||||||
f"http://localhost#access_token={token.access_token}"
|
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||||
f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer"
|
)
|
||||||
f"&expires_in={int(expires)}&state={state}"
|
token: RefreshToken = RefreshToken.objects.filter(user=user).first()
|
||||||
),
|
expires = timedelta_from_string(provider.access_code_validity).total_seconds()
|
||||||
},
|
self.assertJSONEqual(
|
||||||
)
|
response.content.decode(),
|
||||||
jwt = self.validate_jwt(token, provider)
|
{
|
||||||
self.assertAlmostEqual(
|
"component": "xak-flow-redirect",
|
||||||
jwt["exp"] - now().timestamp(),
|
"type": ChallengeTypes.REDIRECT.value,
|
||||||
expires,
|
"to": (
|
||||||
delta=5,
|
f"http://localhost#access_token={token.access_token}"
|
||||||
)
|
f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer"
|
||||||
|
f"&expires_in={int(expires)}&state={state}"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
jwt = self.validate_jwt(token, provider)
|
||||||
|
self.assertEqual(jwt["amr"], ["pwd"])
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
jwt["exp"] - now().timestamp(),
|
||||||
|
expires,
|
||||||
|
delta=5,
|
||||||
|
)
|
||||||
|
|
||||||
def test_full_form_post_id_token(self):
|
def test_full_form_post_id_token(self):
|
||||||
"""Test full authorization (form_post response)"""
|
"""Test full authorization (form_post response)"""
|
||||||
|
|
|
@ -27,6 +27,11 @@ class OAuthTestCase(TestCase):
|
||||||
cls.keypair = create_test_cert()
|
cls.keypair = create_test_cert()
|
||||||
super().setUpClass()
|
super().setUpClass()
|
||||||
|
|
||||||
|
def assert_non_none_or_unset(self, container: dict, key: str):
|
||||||
|
"""Check that a key, if set, is not none"""
|
||||||
|
if key in container:
|
||||||
|
self.assertIsNotNone(container[key])
|
||||||
|
|
||||||
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
|
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
|
||||||
"""Validate that all required fields are set"""
|
"""Validate that all required fields are set"""
|
||||||
key, alg = provider.jwt_key
|
key, alg = provider.jwt_key
|
||||||
|
@ -39,6 +44,10 @@ class OAuthTestCase(TestCase):
|
||||||
audience=provider.client_id,
|
audience=provider.client_id,
|
||||||
)
|
)
|
||||||
id_token = token.id_token.to_dict()
|
id_token = token.id_token.to_dict()
|
||||||
|
self.assert_non_none_or_unset(id_token, "at_hash")
|
||||||
|
self.assert_non_none_or_unset(id_token, "nonce")
|
||||||
|
self.assert_non_none_or_unset(id_token, "c_hash")
|
||||||
|
self.assert_non_none_or_unset(id_token, "amr")
|
||||||
for key in self.required_jwt_keys:
|
for key in self.required_jwt_keys:
|
||||||
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
|
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
|
||||||
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")
|
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")
|
||||||
|
|
|
@ -10,7 +10,7 @@ from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.exceptions import PropertyMappingExpressionException
|
from authentik.core.exceptions import PropertyMappingExpressionException
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.events.signals import SESSION_LOGIN_EVENT
|
from authentik.events.signals import get_login_event
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
|
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
|
||||||
from authentik.providers.saml.processors.request_parser import AuthNRequest
|
from authentik.providers.saml.processors.request_parser import AuthNRequest
|
||||||
|
@ -132,8 +132,8 @@ class AssertionProcessor:
|
||||||
auth_n_context, f"{{{NS_SAML_ASSERTION}}}AuthnContextClassRef"
|
auth_n_context, f"{{{NS_SAML_ASSERTION}}}AuthnContextClassRef"
|
||||||
)
|
)
|
||||||
auth_n_context_class_ref.text = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified"
|
auth_n_context_class_ref.text = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified"
|
||||||
if SESSION_LOGIN_EVENT in self.http_request.session:
|
event = get_login_event(self.http_request)
|
||||||
event: Event = self.http_request.session[SESSION_LOGIN_EVENT]
|
if event:
|
||||||
method = event.context.get(PLAN_CONTEXT_METHOD, "")
|
method = event.context.get(PLAN_CONTEXT_METHOD, "")
|
||||||
method_args = event.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
|
method_args = event.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
|
||||||
if method == "password":
|
if method == "password":
|
||||||
|
|
Reference in a new issue