diff --git a/authentik/providers/oauth2/models.py b/authentik/providers/oauth2/models.py index ea40afd4a..9d22aac6b 100644 --- a/authentik/providers/oauth2/models.py +++ b/authentik/providers/oauth2/models.py @@ -2,9 +2,8 @@ import base64 import binascii import json -import time from dataclasses import asdict, dataclass, field -from datetime import datetime +from datetime import datetime, timedelta from hashlib import sha256 from typing import Any, Optional from urllib.parse import urlparse, urlunparse @@ -14,7 +13,7 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey from dacite.core import from_dict from django.db import models from django.http import HttpRequest -from django.utils import dateformat, timezone +from django.utils import timezone from django.utils.translation import gettext_lazy as _ from jwt import encode from rest_framework.serializers import Serializer @@ -25,7 +24,7 @@ from authentik.events.models import Event, EventAction from authentik.events.utils import get_user from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key from authentik.lib.models import SerializerModel -from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator +from authentik.lib.utils.time import timedelta_string_validator from authentik.providers.oauth2.apps import AuthentikProviderOAuth2Config from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT from authentik.sources.oauth.models import OAuthSource @@ -237,14 +236,18 @@ class OAuth2Provider(Provider): ) def create_refresh_token( - self, user: User, scope: list[str], request: HttpRequest + self, + user: User, + scope: list[str], + request: HttpRequest, + expiry: timedelta, ) -> "RefreshToken": """Create and populate a RefreshToken object.""" token = RefreshToken( user=user, provider=self, refresh_token=base64.urlsafe_b64encode(generate_key().encode()).decode(), - expires=timezone.now() + timedelta_from_string(self.token_validity), + expires=timezone.now() + expiry, scope=scope, ) token.access_token = token.create_access_token(user, request) @@ -484,18 +487,21 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): ) # Convert datetimes into timestamps. - now = int(time.time()) - iat_time = now - exp_time = int(dateformat.format(self.expires, "U")) + now = datetime.now() + iat_time = int(now.timestamp()) + exp_time = int(self.expires.timestamp()) # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time - auth_events = Event.objects.filter(action=EventAction.LOGIN, user=get_user(user)).order_by( - "-created" + auth_event = ( + Event.objects.filter(action=EventAction.LOGIN, user=get_user(user)) + .order_by("-created") + .first() ) # Fallback in case we can't find any login events - auth_time = datetime.now() - if auth_events.exists(): - auth_time = auth_events.first().created - auth_time = int(dateformat.format(auth_time, "U")) + auth_time = now + if auth_event: + auth_time = auth_event.created + + auth_timestamp = int(auth_time.timestamp()) token = IDToken( iss=self.provider.get_issuer(request), @@ -503,7 +509,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): aud=self.provider.client_id, exp=exp_time, iat=iat_time, - auth_time=auth_time, + auth_time=auth_timestamp, ) # Include (or not) user standard claims in the id_token. diff --git a/authentik/providers/oauth2/tests/test_authorize.py b/authentik/providers/oauth2/tests/test_authorize.py index 128f3fa15..c06861ebd 100644 --- a/authentik/providers/oauth2/tests/test_authorize.py +++ b/authentik/providers/oauth2/tests/test_authorize.py @@ -1,11 +1,13 @@ """Test authorize view""" from django.test import RequestFactory from django.urls import reverse +from django.utils.timezone import now from authentik.core.models import Application from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.flows.challenge import ChallengeTypes from authentik.lib.generators import generate_id, generate_key +from authentik.lib.utils.time import timedelta_from_string from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError from authentik.providers.oauth2.models import ( AuthorizationCode, @@ -250,6 +252,7 @@ class TestAuthorize(OAuthTestCase): client_id="test", authorization_flow=flow, redirect_uris="foo://localhost", + access_code_validity="seconds=100", ) Application.objects.create(name="app", slug="app", provider=provider) state = generate_id() @@ -277,6 +280,11 @@ class TestAuthorize(OAuthTestCase): "to": f"foo://localhost?code={code.code}&state={state}", }, ) + self.assertAlmostEqual( + code.expires.timestamp() - now().timestamp(), + timedelta_from_string(provider.access_code_validity).total_seconds(), + delta=5, + ) def test_full_implicit(self): """Test full authorization""" @@ -288,6 +296,7 @@ class TestAuthorize(OAuthTestCase): authorization_flow=flow, redirect_uris="http://localhost", signing_key=self.keypair, + access_code_validity="seconds=100", ) Application.objects.create(name="app", slug="app", provider=provider) state = generate_id() @@ -308,6 +317,7 @@ class TestAuthorize(OAuthTestCase): reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), ) token: RefreshToken = RefreshToken.objects.filter(user=user).first() + expires = timedelta_from_string(provider.access_code_validity).total_seconds() self.assertJSONEqual( response.content.decode(), { @@ -316,11 +326,16 @@ class TestAuthorize(OAuthTestCase): "to": ( f"http://localhost#access_token={token.access_token}" f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer" - f"&expires_in=60&state={state}" + f"&expires_in={int(expires)}&state={state}" ), }, ) - self.validate_jwt(token, provider) + jwt = self.validate_jwt(token, provider) + self.assertAlmostEqual( + jwt["exp"] - now().timestamp(), + expires, + delta=5, + ) def test_full_form_post_id_token(self): """Test full authorization (form_post response)""" diff --git a/authentik/providers/oauth2/tests/utils.py b/authentik/providers/oauth2/tests/utils.py index 85c1dc848..a44142dd9 100644 --- a/authentik/providers/oauth2/tests/utils.py +++ b/authentik/providers/oauth2/tests/utils.py @@ -1,4 +1,6 @@ """OAuth test helpers""" +from typing import Any + from django.test import TestCase from jwt import decode @@ -25,7 +27,7 @@ class OAuthTestCase(TestCase): cls.keypair = create_test_cert() super().setUpClass() - def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider): + def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]: """Validate that all required fields are set""" key, alg = provider.get_jwt_key() if alg != JWTAlgorithms.HS256: @@ -40,3 +42,4 @@ class OAuthTestCase(TestCase): for key in self.required_jwt_keys: self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token") self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token") + return jwt diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index 552acba50..ec8fe9ca0 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -261,7 +261,7 @@ class OAuthAuthorizationParams: code.code_challenge = self.code_challenge code.code_challenge_method = self.code_challenge_method - code.expires_at = timezone.now() + timedelta_from_string(self.provider.access_code_validity) + code.expires = timezone.now() + timedelta_from_string(self.provider.access_code_validity) code.scope = self.scope code.nonce = self.nonce code.is_open_id = SCOPE_OPENID in self.scope @@ -525,6 +525,7 @@ class OAuthFulfillmentStage(StageView): user=self.request.user, scope=self.params.scope, request=self.request, + expiry=timedelta_from_string(self.provider.access_code_validity), ) # Check if response_type must include access_token in the response. diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index 0a7fdeb61..5abeb2ab0 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -443,6 +443,7 @@ class TokenView(View): user=self.params.authorization_code.user, scope=self.params.authorization_code.scope, request=self.request, + expiry=timedelta_from_string(self.provider.token_validity), ) if self.params.authorization_code.is_open_id: @@ -478,6 +479,7 @@ class TokenView(View): user=self.params.refresh_token.user, scope=self.params.scope, request=self.request, + expiry=timedelta_from_string(self.provider.token_validity), ) # If the Token has an id_token it's an Authentication request. @@ -509,6 +511,7 @@ class TokenView(View): user=self.params.user, scope=self.params.scope, request=self.request, + expiry=timedelta_from_string(self.provider.token_validity), ) refresh_token.id_token = refresh_token.create_id_token( user=self.params.user, @@ -535,6 +538,7 @@ class TokenView(View): user=self.params.device_code.user, scope=self.params.device_code.scope, request=self.request, + expiry=timedelta_from_string(self.provider.token_validity), ) refresh_token.id_token = refresh_token.create_id_token( user=self.params.device_code.user,