providers/oauth2: fix inconsistent expiry encoded in JWT

- access token validity is used for JWTs issues in implicit flows
- general cleanup of how times are set
closes #2581

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-11-10 20:23:24 +01:00
parent bdf50a35cd
commit 3306003f0e
5 changed files with 49 additions and 20 deletions

View file

@ -2,9 +2,8 @@
import base64
import binascii
import json
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime
from datetime import datetime, timedelta
from hashlib import sha256
from typing import Any, Optional
from urllib.parse import urlparse, urlunparse
@ -14,7 +13,7 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from dacite.core import from_dict
from django.db import models
from django.http import HttpRequest
from django.utils import dateformat, timezone
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from jwt import encode
from rest_framework.serializers import Serializer
@ -25,7 +24,7 @@ from authentik.events.models import Event, EventAction
from authentik.events.utils import get_user
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
from authentik.lib.models import SerializerModel
from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator
from authentik.lib.utils.time import timedelta_string_validator
from authentik.providers.oauth2.apps import AuthentikProviderOAuth2Config
from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT
from authentik.sources.oauth.models import OAuthSource
@ -237,14 +236,18 @@ class OAuth2Provider(Provider):
)
def create_refresh_token(
self, user: User, scope: list[str], request: HttpRequest
self,
user: User,
scope: list[str],
request: HttpRequest,
expiry: timedelta,
) -> "RefreshToken":
"""Create and populate a RefreshToken object."""
token = RefreshToken(
user=user,
provider=self,
refresh_token=base64.urlsafe_b64encode(generate_key().encode()).decode(),
expires=timezone.now() + timedelta_from_string(self.token_validity),
expires=timezone.now() + expiry,
scope=scope,
)
token.access_token = token.create_access_token(user, request)
@ -484,18 +487,21 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
)
# Convert datetimes into timestamps.
now = int(time.time())
iat_time = now
exp_time = int(dateformat.format(self.expires, "U"))
now = datetime.now()
iat_time = int(now.timestamp())
exp_time = int(self.expires.timestamp())
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
auth_events = Event.objects.filter(action=EventAction.LOGIN, user=get_user(user)).order_by(
"-created"
auth_event = (
Event.objects.filter(action=EventAction.LOGIN, user=get_user(user))
.order_by("-created")
.first()
)
# Fallback in case we can't find any login events
auth_time = datetime.now()
if auth_events.exists():
auth_time = auth_events.first().created
auth_time = int(dateformat.format(auth_time, "U"))
auth_time = now
if auth_event:
auth_time = auth_event.created
auth_timestamp = int(auth_time.timestamp())
token = IDToken(
iss=self.provider.get_issuer(request),
@ -503,7 +509,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
aud=self.provider.client_id,
exp=exp_time,
iat=iat_time,
auth_time=auth_time,
auth_time=auth_timestamp,
)
# Include (or not) user standard claims in the id_token.

View file

@ -1,11 +1,13 @@
"""Test authorize view"""
from django.test import RequestFactory
from django.urls import reverse
from django.utils.timezone import now
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.challenge import ChallengeTypes
from authentik.lib.generators import generate_id, generate_key
from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
from authentik.providers.oauth2.models import (
AuthorizationCode,
@ -250,6 +252,7 @@ class TestAuthorize(OAuthTestCase):
client_id="test",
authorization_flow=flow,
redirect_uris="foo://localhost",
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@ -277,6 +280,11 @@ class TestAuthorize(OAuthTestCase):
"to": f"foo://localhost?code={code.code}&state={state}",
},
)
self.assertAlmostEqual(
code.expires.timestamp() - now().timestamp(),
timedelta_from_string(provider.access_code_validity).total_seconds(),
delta=5,
)
def test_full_implicit(self):
"""Test full authorization"""
@ -288,6 +296,7 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=flow,
redirect_uris="http://localhost",
signing_key=self.keypair,
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@ -308,6 +317,7 @@ class TestAuthorize(OAuthTestCase):
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
token: RefreshToken = RefreshToken.objects.filter(user=user).first()
expires = timedelta_from_string(provider.access_code_validity).total_seconds()
self.assertJSONEqual(
response.content.decode(),
{
@ -316,11 +326,16 @@ class TestAuthorize(OAuthTestCase):
"to": (
f"http://localhost#access_token={token.access_token}"
f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer"
f"&expires_in=60&state={state}"
f"&expires_in={int(expires)}&state={state}"
),
},
)
self.validate_jwt(token, provider)
jwt = self.validate_jwt(token, provider)
self.assertAlmostEqual(
jwt["exp"] - now().timestamp(),
expires,
delta=5,
)
def test_full_form_post_id_token(self):
"""Test full authorization (form_post response)"""

View file

@ -1,4 +1,6 @@
"""OAuth test helpers"""
from typing import Any
from django.test import TestCase
from jwt import decode
@ -25,7 +27,7 @@ class OAuthTestCase(TestCase):
cls.keypair = create_test_cert()
super().setUpClass()
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
"""Validate that all required fields are set"""
key, alg = provider.get_jwt_key()
if alg != JWTAlgorithms.HS256:
@ -40,3 +42,4 @@ class OAuthTestCase(TestCase):
for key in self.required_jwt_keys:
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")
return jwt

View file

@ -261,7 +261,7 @@ class OAuthAuthorizationParams:
code.code_challenge = self.code_challenge
code.code_challenge_method = self.code_challenge_method
code.expires_at = timezone.now() + timedelta_from_string(self.provider.access_code_validity)
code.expires = timezone.now() + timedelta_from_string(self.provider.access_code_validity)
code.scope = self.scope
code.nonce = self.nonce
code.is_open_id = SCOPE_OPENID in self.scope
@ -525,6 +525,7 @@ class OAuthFulfillmentStage(StageView):
user=self.request.user,
scope=self.params.scope,
request=self.request,
expiry=timedelta_from_string(self.provider.access_code_validity),
)
# Check if response_type must include access_token in the response.

View file

@ -443,6 +443,7 @@ class TokenView(View):
user=self.params.authorization_code.user,
scope=self.params.authorization_code.scope,
request=self.request,
expiry=timedelta_from_string(self.provider.token_validity),
)
if self.params.authorization_code.is_open_id:
@ -478,6 +479,7 @@ class TokenView(View):
user=self.params.refresh_token.user,
scope=self.params.scope,
request=self.request,
expiry=timedelta_from_string(self.provider.token_validity),
)
# If the Token has an id_token it's an Authentication request.
@ -509,6 +511,7 @@ class TokenView(View):
user=self.params.user,
scope=self.params.scope,
request=self.request,
expiry=timedelta_from_string(self.provider.token_validity),
)
refresh_token.id_token = refresh_token.create_id_token(
user=self.params.user,
@ -535,6 +538,7 @@ class TokenView(View):
user=self.params.device_code.user,
scope=self.params.device_code.scope,
request=self.request,
expiry=timedelta_from_string(self.provider.token_validity),
)
refresh_token.id_token = refresh_token.create_id_token(
user=self.params.device_code.user,