diff --git a/authentik/events/signals.py b/authentik/events/signals.py index 1d20a1805..712b63e04 100644 --- a/authentik/events/signals.py +++ b/authentik/events/signals.py @@ -39,6 +39,11 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_): request.session[SESSION_LOGIN_EVENT] = event +def get_login_event(request: HttpRequest) -> Optional[Event]: + """Wrapper to get login event that can be mocked in tests""" + return request.session.get(SESSION_LOGIN_EVENT, None) + + @receiver(user_logged_out) # pylint: disable=unused-argument def on_user_logged_out(sender, request: HttpRequest, user: User, **_): diff --git a/authentik/providers/oauth2/models.py b/authentik/providers/oauth2/models.py index a4e402cbc..76a9bd0ac 100644 --- a/authentik/providers/oauth2/models.py +++ b/authentik/providers/oauth2/models.py @@ -22,8 +22,7 @@ from rest_framework.serializers import Serializer from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User from authentik.crypto.models import CertificateKeyPair -from authentik.events.models import Event -from authentik.events.signals import SESSION_LOGIN_EVENT +from authentik.events.signals import get_login_event from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key from authentik.lib.models import SerializerModel from authentik.lib.utils.time import timedelta_string_validator @@ -419,6 +418,8 @@ class IDToken: id_dict.pop("nonce") if not self.c_hash: id_dict.pop("c_hash") + if not self.amr: + id_dict.pop("amr") id_dict.pop("claims") id_dict.update(self.claims) return id_dict @@ -503,8 +504,8 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time # Fallback in case we can't find any login events auth_time = now - if SESSION_LOGIN_EVENT in request.session: - auth_event: Event = request.session[SESSION_LOGIN_EVENT] + auth_event = get_login_event(request) + if auth_event: auth_time = auth_event.created # Also check which method was used for authentication method = auth_event.context.get(PLAN_CONTEXT_METHOD, "") @@ -526,6 +527,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): exp=exp_time, iat=iat_time, auth_time=auth_timestamp, + amr=amr if amr else None, ) # Include (or not) user standard claims in the id_token. diff --git a/authentik/providers/oauth2/tests/test_authorize.py b/authentik/providers/oauth2/tests/test_authorize.py index c06861ebd..89c3f93ea 100644 --- a/authentik/providers/oauth2/tests/test_authorize.py +++ b/authentik/providers/oauth2/tests/test_authorize.py @@ -1,10 +1,13 @@ """Test authorize view""" +from unittest.mock import MagicMock, patch + from django.test import RequestFactory from django.urls import reverse from django.utils.timezone import now from authentik.core.models import Application from authentik.core.tests.utils import create_test_admin_user, create_test_flow +from authentik.events.models import Event, EventAction from authentik.flows.challenge import ChallengeTypes from authentik.lib.generators import generate_id, generate_key from authentik.lib.utils.time import timedelta_from_string @@ -17,6 +20,7 @@ from authentik.providers.oauth2.models import ( ) from authentik.providers.oauth2.tests.utils import OAuthTestCase from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams +from authentik.stages.password.stage import PLAN_CONTEXT_METHOD class TestAuthorize(OAuthTestCase): @@ -302,40 +306,51 @@ class TestAuthorize(OAuthTestCase): state = generate_id() user = create_test_admin_user() self.client.force_login(user) - # Step 1, initiate params and get redirect to flow - self.client.get( - reverse("authentik_providers_oauth2:authorize"), - data={ - "response_type": "id_token", - "client_id": "test", - "state": state, - "scope": "openid", - "redirect_uri": "http://localhost", - }, - ) - response = self.client.get( - reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), - ) - token: RefreshToken = RefreshToken.objects.filter(user=user).first() - expires = timedelta_from_string(provider.access_code_validity).total_seconds() - self.assertJSONEqual( - response.content.decode(), - { - "component": "xak-flow-redirect", - "type": ChallengeTypes.REDIRECT.value, - "to": ( - f"http://localhost#access_token={token.access_token}" - f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer" - f"&expires_in={int(expires)}&state={state}" - ), - }, - ) - jwt = self.validate_jwt(token, provider) - self.assertAlmostEqual( - jwt["exp"] - now().timestamp(), - expires, - delta=5, - ) + with patch( + "authentik.providers.oauth2.models.get_login_event", + MagicMock( + return_value=Event( + action=EventAction.LOGIN, + context={PLAN_CONTEXT_METHOD: "password"}, + created=now(), + ) + ), + ): + # Step 1, initiate params and get redirect to flow + self.client.get( + reverse("authentik_providers_oauth2:authorize"), + data={ + "response_type": "id_token", + "client_id": "test", + "state": state, + "scope": "openid", + "redirect_uri": "http://localhost", + }, + ) + response = self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), + ) + token: RefreshToken = RefreshToken.objects.filter(user=user).first() + expires = timedelta_from_string(provider.access_code_validity).total_seconds() + self.assertJSONEqual( + response.content.decode(), + { + "component": "xak-flow-redirect", + "type": ChallengeTypes.REDIRECT.value, + "to": ( + f"http://localhost#access_token={token.access_token}" + f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer" + f"&expires_in={int(expires)}&state={state}" + ), + }, + ) + jwt = self.validate_jwt(token, provider) + self.assertEqual(jwt["amr"], ["pwd"]) + self.assertAlmostEqual( + jwt["exp"] - now().timestamp(), + expires, + delta=5, + ) def test_full_form_post_id_token(self): """Test full authorization (form_post response)""" diff --git a/authentik/providers/oauth2/tests/utils.py b/authentik/providers/oauth2/tests/utils.py index bfa45324e..7801ddbe6 100644 --- a/authentik/providers/oauth2/tests/utils.py +++ b/authentik/providers/oauth2/tests/utils.py @@ -27,6 +27,11 @@ class OAuthTestCase(TestCase): cls.keypair = create_test_cert() super().setUpClass() + def assert_non_none_or_unset(self, container: dict, key: str): + """Check that a key, if set, is not none""" + if key in container: + self.assertIsNotNone(container[key]) + def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]: """Validate that all required fields are set""" key, alg = provider.jwt_key @@ -39,6 +44,10 @@ class OAuthTestCase(TestCase): audience=provider.client_id, ) id_token = token.id_token.to_dict() + self.assert_non_none_or_unset(id_token, "at_hash") + self.assert_non_none_or_unset(id_token, "nonce") + self.assert_non_none_or_unset(id_token, "c_hash") + self.assert_non_none_or_unset(id_token, "amr") for key in self.required_jwt_keys: self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token") self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token") diff --git a/authentik/providers/saml/processors/assertion.py b/authentik/providers/saml/processors/assertion.py index 05f337ec4..a0f33c382 100644 --- a/authentik/providers/saml/processors/assertion.py +++ b/authentik/providers/saml/processors/assertion.py @@ -10,7 +10,7 @@ from structlog.stdlib import get_logger from authentik.core.exceptions import PropertyMappingExpressionException from authentik.events.models import Event, EventAction -from authentik.events.signals import SESSION_LOGIN_EVENT +from authentik.events.signals import get_login_event from authentik.lib.utils.time import timedelta_from_string from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider from authentik.providers.saml.processors.request_parser import AuthNRequest @@ -132,8 +132,8 @@ class AssertionProcessor: auth_n_context, f"{{{NS_SAML_ASSERTION}}}AuthnContextClassRef" ) auth_n_context_class_ref.text = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified" - if SESSION_LOGIN_EVENT in self.http_request.session: - event: Event = self.http_request.session[SESSION_LOGIN_EVENT] + event = get_login_event(self.http_request) + if event: method = event.context.get(PLAN_CONTEXT_METHOD, "") method_args = event.context.get(PLAN_CONTEXT_METHOD_ARGS, {}) if method == "password":