providers/oauth2: fix inconsistent expiry encoded in JWT

- access token validity is used for JWTs issues in implicit flows
- general cleanup of how times are set
closes #2581

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-11-10 20:23:24 +01:00
parent bdf50a35cd
commit 3306003f0e
5 changed files with 49 additions and 20 deletions

View file

@ -2,9 +2,8 @@
import base64 import base64
import binascii import binascii
import json import json
import time
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime, timedelta
from hashlib import sha256 from hashlib import sha256
from typing import Any, Optional from typing import Any, Optional
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
@ -14,7 +13,7 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from dacite.core import from_dict from dacite.core import from_dict
from django.db import models from django.db import models
from django.http import HttpRequest from django.http import HttpRequest
from django.utils import dateformat, timezone from django.utils import timezone
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from jwt import encode from jwt import encode
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
@ -25,7 +24,7 @@ from authentik.events.models import Event, EventAction
from authentik.events.utils import get_user from authentik.events.utils import get_user
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator from authentik.lib.utils.time import timedelta_string_validator
from authentik.providers.oauth2.apps import AuthentikProviderOAuth2Config from authentik.providers.oauth2.apps import AuthentikProviderOAuth2Config
from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
@ -237,14 +236,18 @@ class OAuth2Provider(Provider):
) )
def create_refresh_token( def create_refresh_token(
self, user: User, scope: list[str], request: HttpRequest self,
user: User,
scope: list[str],
request: HttpRequest,
expiry: timedelta,
) -> "RefreshToken": ) -> "RefreshToken":
"""Create and populate a RefreshToken object.""" """Create and populate a RefreshToken object."""
token = RefreshToken( token = RefreshToken(
user=user, user=user,
provider=self, provider=self,
refresh_token=base64.urlsafe_b64encode(generate_key().encode()).decode(), refresh_token=base64.urlsafe_b64encode(generate_key().encode()).decode(),
expires=timezone.now() + timedelta_from_string(self.token_validity), expires=timezone.now() + expiry,
scope=scope, scope=scope,
) )
token.access_token = token.create_access_token(user, request) token.access_token = token.create_access_token(user, request)
@ -484,18 +487,21 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
) )
# Convert datetimes into timestamps. # Convert datetimes into timestamps.
now = int(time.time()) now = datetime.now()
iat_time = now iat_time = int(now.timestamp())
exp_time = int(dateformat.format(self.expires, "U")) exp_time = int(self.expires.timestamp())
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
auth_events = Event.objects.filter(action=EventAction.LOGIN, user=get_user(user)).order_by( auth_event = (
"-created" Event.objects.filter(action=EventAction.LOGIN, user=get_user(user))
.order_by("-created")
.first()
) )
# Fallback in case we can't find any login events # Fallback in case we can't find any login events
auth_time = datetime.now() auth_time = now
if auth_events.exists(): if auth_event:
auth_time = auth_events.first().created auth_time = auth_event.created
auth_time = int(dateformat.format(auth_time, "U"))
auth_timestamp = int(auth_time.timestamp())
token = IDToken( token = IDToken(
iss=self.provider.get_issuer(request), iss=self.provider.get_issuer(request),
@ -503,7 +509,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
aud=self.provider.client_id, aud=self.provider.client_id,
exp=exp_time, exp=exp_time,
iat=iat_time, iat=iat_time,
auth_time=auth_time, auth_time=auth_timestamp,
) )
# Include (or not) user standard claims in the id_token. # Include (or not) user standard claims in the id_token.

View file

@ -1,11 +1,13 @@
"""Test authorize view""" """Test authorize view"""
from django.test import RequestFactory from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
from django.utils.timezone import now
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.challenge import ChallengeTypes from authentik.flows.challenge import ChallengeTypes
from authentik.lib.generators import generate_id, generate_key from authentik.lib.generators import generate_id, generate_key
from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
from authentik.providers.oauth2.models import ( from authentik.providers.oauth2.models import (
AuthorizationCode, AuthorizationCode,
@ -250,6 +252,7 @@ class TestAuthorize(OAuthTestCase):
client_id="test", client_id="test",
authorization_flow=flow, authorization_flow=flow,
redirect_uris="foo://localhost", redirect_uris="foo://localhost",
access_code_validity="seconds=100",
) )
Application.objects.create(name="app", slug="app", provider=provider) Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id() state = generate_id()
@ -277,6 +280,11 @@ class TestAuthorize(OAuthTestCase):
"to": f"foo://localhost?code={code.code}&state={state}", "to": f"foo://localhost?code={code.code}&state={state}",
}, },
) )
self.assertAlmostEqual(
code.expires.timestamp() - now().timestamp(),
timedelta_from_string(provider.access_code_validity).total_seconds(),
delta=5,
)
def test_full_implicit(self): def test_full_implicit(self):
"""Test full authorization""" """Test full authorization"""
@ -288,6 +296,7 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=flow, authorization_flow=flow,
redirect_uris="http://localhost", redirect_uris="http://localhost",
signing_key=self.keypair, signing_key=self.keypair,
access_code_validity="seconds=100",
) )
Application.objects.create(name="app", slug="app", provider=provider) Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id() state = generate_id()
@ -308,6 +317,7 @@ class TestAuthorize(OAuthTestCase):
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
) )
token: RefreshToken = RefreshToken.objects.filter(user=user).first() token: RefreshToken = RefreshToken.objects.filter(user=user).first()
expires = timedelta_from_string(provider.access_code_validity).total_seconds()
self.assertJSONEqual( self.assertJSONEqual(
response.content.decode(), response.content.decode(),
{ {
@ -316,11 +326,16 @@ class TestAuthorize(OAuthTestCase):
"to": ( "to": (
f"http://localhost#access_token={token.access_token}" f"http://localhost#access_token={token.access_token}"
f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer" f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer"
f"&expires_in=60&state={state}" f"&expires_in={int(expires)}&state={state}"
), ),
}, },
) )
self.validate_jwt(token, provider) jwt = self.validate_jwt(token, provider)
self.assertAlmostEqual(
jwt["exp"] - now().timestamp(),
expires,
delta=5,
)
def test_full_form_post_id_token(self): def test_full_form_post_id_token(self):
"""Test full authorization (form_post response)""" """Test full authorization (form_post response)"""

View file

@ -1,4 +1,6 @@
"""OAuth test helpers""" """OAuth test helpers"""
from typing import Any
from django.test import TestCase from django.test import TestCase
from jwt import decode from jwt import decode
@ -25,7 +27,7 @@ class OAuthTestCase(TestCase):
cls.keypair = create_test_cert() cls.keypair = create_test_cert()
super().setUpClass() super().setUpClass()
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider): def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
"""Validate that all required fields are set""" """Validate that all required fields are set"""
key, alg = provider.get_jwt_key() key, alg = provider.get_jwt_key()
if alg != JWTAlgorithms.HS256: if alg != JWTAlgorithms.HS256:
@ -40,3 +42,4 @@ class OAuthTestCase(TestCase):
for key in self.required_jwt_keys: for key in self.required_jwt_keys:
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token") self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token") self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")
return jwt

View file

@ -261,7 +261,7 @@ class OAuthAuthorizationParams:
code.code_challenge = self.code_challenge code.code_challenge = self.code_challenge
code.code_challenge_method = self.code_challenge_method code.code_challenge_method = self.code_challenge_method
code.expires_at = timezone.now() + timedelta_from_string(self.provider.access_code_validity) code.expires = timezone.now() + timedelta_from_string(self.provider.access_code_validity)
code.scope = self.scope code.scope = self.scope
code.nonce = self.nonce code.nonce = self.nonce
code.is_open_id = SCOPE_OPENID in self.scope code.is_open_id = SCOPE_OPENID in self.scope
@ -525,6 +525,7 @@ class OAuthFulfillmentStage(StageView):
user=self.request.user, user=self.request.user,
scope=self.params.scope, scope=self.params.scope,
request=self.request, request=self.request,
expiry=timedelta_from_string(self.provider.access_code_validity),
) )
# Check if response_type must include access_token in the response. # Check if response_type must include access_token in the response.

View file

@ -443,6 +443,7 @@ class TokenView(View):
user=self.params.authorization_code.user, user=self.params.authorization_code.user,
scope=self.params.authorization_code.scope, scope=self.params.authorization_code.scope,
request=self.request, request=self.request,
expiry=timedelta_from_string(self.provider.token_validity),
) )
if self.params.authorization_code.is_open_id: if self.params.authorization_code.is_open_id:
@ -478,6 +479,7 @@ class TokenView(View):
user=self.params.refresh_token.user, user=self.params.refresh_token.user,
scope=self.params.scope, scope=self.params.scope,
request=self.request, request=self.request,
expiry=timedelta_from_string(self.provider.token_validity),
) )
# If the Token has an id_token it's an Authentication request. # If the Token has an id_token it's an Authentication request.
@ -509,6 +511,7 @@ class TokenView(View):
user=self.params.user, user=self.params.user,
scope=self.params.scope, scope=self.params.scope,
request=self.request, request=self.request,
expiry=timedelta_from_string(self.provider.token_validity),
) )
refresh_token.id_token = refresh_token.create_id_token( refresh_token.id_token = refresh_token.create_id_token(
user=self.params.user, user=self.params.user,
@ -535,6 +538,7 @@ class TokenView(View):
user=self.params.device_code.user, user=self.params.device_code.user,
scope=self.params.device_code.scope, scope=self.params.device_code.scope,
request=self.request, request=self.request,
expiry=timedelta_from_string(self.provider.token_validity),
) )
refresh_token.id_token = refresh_token.create_id_token( refresh_token.id_token = refresh_token.create_id_token(
user=self.params.device_code.user, user=self.params.device_code.user,