providers/oauth2: fix null amr value not being removed from id_token

closes #4339

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2023-01-03 00:15:21 +01:00
parent 57400925a4
commit 4b93f40c5e
No known key found for this signature in database
5 changed files with 72 additions and 41 deletions

View file

@ -39,6 +39,11 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_):
request.session[SESSION_LOGIN_EVENT] = event request.session[SESSION_LOGIN_EVENT] = event
def get_login_event(request: HttpRequest) -> Optional[Event]:
"""Wrapper to get login event that can be mocked in tests"""
return request.session.get(SESSION_LOGIN_EVENT, None)
@receiver(user_logged_out) @receiver(user_logged_out)
# pylint: disable=unused-argument # pylint: disable=unused-argument
def on_user_logged_out(sender, request: HttpRequest, user: User, **_): def on_user_logged_out(sender, request: HttpRequest, user: User, **_):

View file

@ -22,8 +22,7 @@ from rest_framework.serializers import Serializer
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import Event from authentik.events.signals import get_login_event
from authentik.events.signals import SESSION_LOGIN_EVENT
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.utils.time import timedelta_string_validator from authentik.lib.utils.time import timedelta_string_validator
@ -419,6 +418,8 @@ class IDToken:
id_dict.pop("nonce") id_dict.pop("nonce")
if not self.c_hash: if not self.c_hash:
id_dict.pop("c_hash") id_dict.pop("c_hash")
if not self.amr:
id_dict.pop("amr")
id_dict.pop("claims") id_dict.pop("claims")
id_dict.update(self.claims) id_dict.update(self.claims)
return id_dict return id_dict
@ -503,8 +504,8 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
# Fallback in case we can't find any login events # Fallback in case we can't find any login events
auth_time = now auth_time = now
if SESSION_LOGIN_EVENT in request.session: auth_event = get_login_event(request)
auth_event: Event = request.session[SESSION_LOGIN_EVENT] if auth_event:
auth_time = auth_event.created auth_time = auth_event.created
# Also check which method was used for authentication # Also check which method was used for authentication
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "") method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
@ -526,6 +527,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
exp=exp_time, exp=exp_time,
iat=iat_time, iat=iat_time,
auth_time=auth_timestamp, auth_time=auth_timestamp,
amr=amr if amr else None,
) )
# Include (or not) user standard claims in the id_token. # Include (or not) user standard claims in the id_token.

View file

@ -1,10 +1,13 @@
"""Test authorize view""" """Test authorize view"""
from unittest.mock import MagicMock, patch
from django.test import RequestFactory from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
from django.utils.timezone import now from django.utils.timezone import now
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction
from authentik.flows.challenge import ChallengeTypes from authentik.flows.challenge import ChallengeTypes
from authentik.lib.generators import generate_id, generate_key from authentik.lib.generators import generate_id, generate_key
from authentik.lib.utils.time import timedelta_from_string from authentik.lib.utils.time import timedelta_from_string
@ -17,6 +20,7 @@ from authentik.providers.oauth2.models import (
) )
from authentik.providers.oauth2.tests.utils import OAuthTestCase from authentik.providers.oauth2.tests.utils import OAuthTestCase
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD
class TestAuthorize(OAuthTestCase): class TestAuthorize(OAuthTestCase):
@ -302,40 +306,51 @@ class TestAuthorize(OAuthTestCase):
state = generate_id() state = generate_id()
user = create_test_admin_user() user = create_test_admin_user()
self.client.force_login(user) self.client.force_login(user)
# Step 1, initiate params and get redirect to flow with patch(
self.client.get( "authentik.providers.oauth2.models.get_login_event",
reverse("authentik_providers_oauth2:authorize"), MagicMock(
data={ return_value=Event(
"response_type": "id_token", action=EventAction.LOGIN,
"client_id": "test", context={PLAN_CONTEXT_METHOD: "password"},
"state": state, created=now(),
"scope": "openid", )
"redirect_uri": "http://localhost", ),
}, ):
) # Step 1, initiate params and get redirect to flow
response = self.client.get( self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), reverse("authentik_providers_oauth2:authorize"),
) data={
token: RefreshToken = RefreshToken.objects.filter(user=user).first() "response_type": "id_token",
expires = timedelta_from_string(provider.access_code_validity).total_seconds() "client_id": "test",
self.assertJSONEqual( "state": state,
response.content.decode(), "scope": "openid",
{ "redirect_uri": "http://localhost",
"component": "xak-flow-redirect", },
"type": ChallengeTypes.REDIRECT.value, )
"to": ( response = self.client.get(
f"http://localhost#access_token={token.access_token}" reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer" )
f"&expires_in={int(expires)}&state={state}" token: RefreshToken = RefreshToken.objects.filter(user=user).first()
), expires = timedelta_from_string(provider.access_code_validity).total_seconds()
}, self.assertJSONEqual(
) response.content.decode(),
jwt = self.validate_jwt(token, provider) {
self.assertAlmostEqual( "component": "xak-flow-redirect",
jwt["exp"] - now().timestamp(), "type": ChallengeTypes.REDIRECT.value,
expires, "to": (
delta=5, f"http://localhost#access_token={token.access_token}"
) f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer"
f"&expires_in={int(expires)}&state={state}"
),
},
)
jwt = self.validate_jwt(token, provider)
self.assertEqual(jwt["amr"], ["pwd"])
self.assertAlmostEqual(
jwt["exp"] - now().timestamp(),
expires,
delta=5,
)
def test_full_form_post_id_token(self): def test_full_form_post_id_token(self):
"""Test full authorization (form_post response)""" """Test full authorization (form_post response)"""

View file

@ -27,6 +27,11 @@ class OAuthTestCase(TestCase):
cls.keypair = create_test_cert() cls.keypair = create_test_cert()
super().setUpClass() super().setUpClass()
def assert_non_none_or_unset(self, container: dict, key: str):
"""Check that a key, if set, is not none"""
if key in container:
self.assertIsNotNone(container[key])
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]: def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
"""Validate that all required fields are set""" """Validate that all required fields are set"""
key, alg = provider.jwt_key key, alg = provider.jwt_key
@ -39,6 +44,10 @@ class OAuthTestCase(TestCase):
audience=provider.client_id, audience=provider.client_id,
) )
id_token = token.id_token.to_dict() id_token = token.id_token.to_dict()
self.assert_non_none_or_unset(id_token, "at_hash")
self.assert_non_none_or_unset(id_token, "nonce")
self.assert_non_none_or_unset(id_token, "c_hash")
self.assert_non_none_or_unset(id_token, "amr")
for key in self.required_jwt_keys: for key in self.required_jwt_keys:
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token") self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token") self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")

View file

@ -10,7 +10,7 @@ from structlog.stdlib import get_logger
from authentik.core.exceptions import PropertyMappingExpressionException from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.events.signals import SESSION_LOGIN_EVENT from authentik.events.signals import get_login_event
from authentik.lib.utils.time import timedelta_from_string from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
from authentik.providers.saml.processors.request_parser import AuthNRequest from authentik.providers.saml.processors.request_parser import AuthNRequest
@ -132,8 +132,8 @@ class AssertionProcessor:
auth_n_context, f"{{{NS_SAML_ASSERTION}}}AuthnContextClassRef" auth_n_context, f"{{{NS_SAML_ASSERTION}}}AuthnContextClassRef"
) )
auth_n_context_class_ref.text = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified" auth_n_context_class_ref.text = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified"
if SESSION_LOGIN_EVENT in self.http_request.session: event = get_login_event(self.http_request)
event: Event = self.http_request.session[SESSION_LOGIN_EVENT] if event:
method = event.context.get(PLAN_CONTEXT_METHOD, "") method = event.context.get(PLAN_CONTEXT_METHOD, "")
method_args = event.context.get(PLAN_CONTEXT_METHOD_ARGS, {}) method_args = event.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
if method == "password": if method == "password":