providers/oauth2: fix null amr value not being removed from id_token

closes #4339

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2023-01-03 00:15:21 +01:00
parent 57400925a4
commit 4b93f40c5e
No known key found for this signature in database
5 changed files with 72 additions and 41 deletions

View file

@ -39,6 +39,11 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_):
request.session[SESSION_LOGIN_EVENT] = event
def get_login_event(request: HttpRequest) -> Optional[Event]:
"""Wrapper to get login event that can be mocked in tests"""
return request.session.get(SESSION_LOGIN_EVENT, None)
@receiver(user_logged_out)
# pylint: disable=unused-argument
def on_user_logged_out(sender, request: HttpRequest, user: User, **_):

View file

@ -22,8 +22,7 @@ from rest_framework.serializers import Serializer
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import Event
from authentik.events.signals import SESSION_LOGIN_EVENT
from authentik.events.signals import get_login_event
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
from authentik.lib.models import SerializerModel
from authentik.lib.utils.time import timedelta_string_validator
@ -419,6 +418,8 @@ class IDToken:
id_dict.pop("nonce")
if not self.c_hash:
id_dict.pop("c_hash")
if not self.amr:
id_dict.pop("amr")
id_dict.pop("claims")
id_dict.update(self.claims)
return id_dict
@ -503,8 +504,8 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
# Fallback in case we can't find any login events
auth_time = now
if SESSION_LOGIN_EVENT in request.session:
auth_event: Event = request.session[SESSION_LOGIN_EVENT]
auth_event = get_login_event(request)
if auth_event:
auth_time = auth_event.created
# Also check which method was used for authentication
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
@ -526,6 +527,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
exp=exp_time,
iat=iat_time,
auth_time=auth_timestamp,
amr=amr if amr else None,
)
# Include (or not) user standard claims in the id_token.

View file

@ -1,10 +1,13 @@
"""Test authorize view"""
from unittest.mock import MagicMock, patch
from django.test import RequestFactory
from django.urls import reverse
from django.utils.timezone import now
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction
from authentik.flows.challenge import ChallengeTypes
from authentik.lib.generators import generate_id, generate_key
from authentik.lib.utils.time import timedelta_from_string
@ -17,6 +20,7 @@ from authentik.providers.oauth2.models import (
)
from authentik.providers.oauth2.tests.utils import OAuthTestCase
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD
class TestAuthorize(OAuthTestCase):
@ -302,40 +306,51 @@ class TestAuthorize(OAuthTestCase):
state = generate_id()
user = create_test_admin_user()
self.client.force_login(user)
# Step 1, initiate params and get redirect to flow
self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "id_token",
"client_id": "test",
"state": state,
"scope": "openid",
"redirect_uri": "http://localhost",
},
)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
token: RefreshToken = RefreshToken.objects.filter(user=user).first()
expires = timedelta_from_string(provider.access_code_validity).total_seconds()
self.assertJSONEqual(
response.content.decode(),
{
"component": "xak-flow-redirect",
"type": ChallengeTypes.REDIRECT.value,
"to": (
f"http://localhost#access_token={token.access_token}"
f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer"
f"&expires_in={int(expires)}&state={state}"
),
},
)
jwt = self.validate_jwt(token, provider)
self.assertAlmostEqual(
jwt["exp"] - now().timestamp(),
expires,
delta=5,
)
with patch(
"authentik.providers.oauth2.models.get_login_event",
MagicMock(
return_value=Event(
action=EventAction.LOGIN,
context={PLAN_CONTEXT_METHOD: "password"},
created=now(),
)
),
):
# Step 1, initiate params and get redirect to flow
self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "id_token",
"client_id": "test",
"state": state,
"scope": "openid",
"redirect_uri": "http://localhost",
},
)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
token: RefreshToken = RefreshToken.objects.filter(user=user).first()
expires = timedelta_from_string(provider.access_code_validity).total_seconds()
self.assertJSONEqual(
response.content.decode(),
{
"component": "xak-flow-redirect",
"type": ChallengeTypes.REDIRECT.value,
"to": (
f"http://localhost#access_token={token.access_token}"
f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer"
f"&expires_in={int(expires)}&state={state}"
),
},
)
jwt = self.validate_jwt(token, provider)
self.assertEqual(jwt["amr"], ["pwd"])
self.assertAlmostEqual(
jwt["exp"] - now().timestamp(),
expires,
delta=5,
)
def test_full_form_post_id_token(self):
"""Test full authorization (form_post response)"""

View file

@ -27,6 +27,11 @@ class OAuthTestCase(TestCase):
cls.keypair = create_test_cert()
super().setUpClass()
def assert_non_none_or_unset(self, container: dict, key: str):
"""Check that a key, if set, is not none"""
if key in container:
self.assertIsNotNone(container[key])
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
"""Validate that all required fields are set"""
key, alg = provider.jwt_key
@ -39,6 +44,10 @@ class OAuthTestCase(TestCase):
audience=provider.client_id,
)
id_token = token.id_token.to_dict()
self.assert_non_none_or_unset(id_token, "at_hash")
self.assert_non_none_or_unset(id_token, "nonce")
self.assert_non_none_or_unset(id_token, "c_hash")
self.assert_non_none_or_unset(id_token, "amr")
for key in self.required_jwt_keys:
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")

View file

@ -10,7 +10,7 @@ from structlog.stdlib import get_logger
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.events.models import Event, EventAction
from authentik.events.signals import SESSION_LOGIN_EVENT
from authentik.events.signals import get_login_event
from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
from authentik.providers.saml.processors.request_parser import AuthNRequest
@ -132,8 +132,8 @@ class AssertionProcessor:
auth_n_context, f"{{{NS_SAML_ASSERTION}}}AuthnContextClassRef"
)
auth_n_context_class_ref.text = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified"
if SESSION_LOGIN_EVENT in self.http_request.session:
event: Event = self.http_request.session[SESSION_LOGIN_EVENT]
event = get_login_event(self.http_request)
if event:
method = event.context.get(PLAN_CONTEXT_METHOD, "")
method_args = event.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
if method == "password":