providers/oauth2: fix null amr value not being removed from id_token
closes #4339 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
57400925a4
commit
4b93f40c5e
|
@ -39,6 +39,11 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_):
|
|||
request.session[SESSION_LOGIN_EVENT] = event
|
||||
|
||||
|
||||
def get_login_event(request: HttpRequest) -> Optional[Event]:
|
||||
"""Wrapper to get login event that can be mocked in tests"""
|
||||
return request.session.get(SESSION_LOGIN_EVENT, None)
|
||||
|
||||
|
||||
@receiver(user_logged_out)
|
||||
# pylint: disable=unused-argument
|
||||
def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
|
||||
|
|
|
@ -22,8 +22,7 @@ from rest_framework.serializers import Serializer
|
|||
|
||||
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.events.models import Event
|
||||
from authentik.events.signals import SESSION_LOGIN_EVENT
|
||||
from authentik.events.signals import get_login_event
|
||||
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.time import timedelta_string_validator
|
||||
|
@ -419,6 +418,8 @@ class IDToken:
|
|||
id_dict.pop("nonce")
|
||||
if not self.c_hash:
|
||||
id_dict.pop("c_hash")
|
||||
if not self.amr:
|
||||
id_dict.pop("amr")
|
||||
id_dict.pop("claims")
|
||||
id_dict.update(self.claims)
|
||||
return id_dict
|
||||
|
@ -503,8 +504,8 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
|
|||
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
|
||||
# Fallback in case we can't find any login events
|
||||
auth_time = now
|
||||
if SESSION_LOGIN_EVENT in request.session:
|
||||
auth_event: Event = request.session[SESSION_LOGIN_EVENT]
|
||||
auth_event = get_login_event(request)
|
||||
if auth_event:
|
||||
auth_time = auth_event.created
|
||||
# Also check which method was used for authentication
|
||||
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
|
||||
|
@ -526,6 +527,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
|
|||
exp=exp_time,
|
||||
iat=iat_time,
|
||||
auth_time=auth_timestamp,
|
||||
amr=amr if amr else None,
|
||||
)
|
||||
|
||||
# Include (or not) user standard claims in the id_token.
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
"""Test authorize view"""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.challenge import ChallengeTypes
|
||||
from authentik.lib.generators import generate_id, generate_key
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
|
@ -17,6 +20,7 @@ from authentik.providers.oauth2.models import (
|
|||
)
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD
|
||||
|
||||
|
||||
class TestAuthorize(OAuthTestCase):
|
||||
|
@ -302,6 +306,16 @@ class TestAuthorize(OAuthTestCase):
|
|||
state = generate_id()
|
||||
user = create_test_admin_user()
|
||||
self.client.force_login(user)
|
||||
with patch(
|
||||
"authentik.providers.oauth2.models.get_login_event",
|
||||
MagicMock(
|
||||
return_value=Event(
|
||||
action=EventAction.LOGIN,
|
||||
context={PLAN_CONTEXT_METHOD: "password"},
|
||||
created=now(),
|
||||
)
|
||||
),
|
||||
):
|
||||
# Step 1, initiate params and get redirect to flow
|
||||
self.client.get(
|
||||
reverse("authentik_providers_oauth2:authorize"),
|
||||
|
@ -331,6 +345,7 @@ class TestAuthorize(OAuthTestCase):
|
|||
},
|
||||
)
|
||||
jwt = self.validate_jwt(token, provider)
|
||||
self.assertEqual(jwt["amr"], ["pwd"])
|
||||
self.assertAlmostEqual(
|
||||
jwt["exp"] - now().timestamp(),
|
||||
expires,
|
||||
|
|
|
@ -27,6 +27,11 @@ class OAuthTestCase(TestCase):
|
|||
cls.keypair = create_test_cert()
|
||||
super().setUpClass()
|
||||
|
||||
def assert_non_none_or_unset(self, container: dict, key: str):
|
||||
"""Check that a key, if set, is not none"""
|
||||
if key in container:
|
||||
self.assertIsNotNone(container[key])
|
||||
|
||||
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
|
||||
"""Validate that all required fields are set"""
|
||||
key, alg = provider.jwt_key
|
||||
|
@ -39,6 +44,10 @@ class OAuthTestCase(TestCase):
|
|||
audience=provider.client_id,
|
||||
)
|
||||
id_token = token.id_token.to_dict()
|
||||
self.assert_non_none_or_unset(id_token, "at_hash")
|
||||
self.assert_non_none_or_unset(id_token, "nonce")
|
||||
self.assert_non_none_or_unset(id_token, "c_hash")
|
||||
self.assert_non_none_or_unset(id_token, "amr")
|
||||
for key in self.required_jwt_keys:
|
||||
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
|
||||
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")
|
||||
|
|
|
@ -10,7 +10,7 @@ from structlog.stdlib import get_logger
|
|||
|
||||
from authentik.core.exceptions import PropertyMappingExpressionException
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.signals import SESSION_LOGIN_EVENT
|
||||
from authentik.events.signals import get_login_event
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
|
||||
from authentik.providers.saml.processors.request_parser import AuthNRequest
|
||||
|
@ -132,8 +132,8 @@ class AssertionProcessor:
|
|||
auth_n_context, f"{{{NS_SAML_ASSERTION}}}AuthnContextClassRef"
|
||||
)
|
||||
auth_n_context_class_ref.text = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified"
|
||||
if SESSION_LOGIN_EVENT in self.http_request.session:
|
||||
event: Event = self.http_request.session[SESSION_LOGIN_EVENT]
|
||||
event = get_login_event(self.http_request)
|
||||
if event:
|
||||
method = event.context.get(PLAN_CONTEXT_METHOD, "")
|
||||
method_args = event.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
|
||||
if method == "password":
|
||||
|
|
Reference in a new issue