diff --git a/authentik/providers/oauth2/models.py b/authentik/providers/oauth2/models.py index 2b6e77a0c..a4e402cbc 100644 --- a/authentik/providers/oauth2/models.py +++ b/authentik/providers/oauth2/models.py @@ -4,12 +4,14 @@ import binascii import json from dataclasses import asdict, dataclass, field from datetime import datetime, timedelta +from functools import cached_property from hashlib import sha256 from typing import Any, Optional from urllib.parse import urlparse, urlunparse from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES from dacite.core import from_dict from django.db import models from django.http import HttpRequest @@ -259,7 +261,8 @@ class OAuth2Provider(Provider): token.access_token = token.create_access_token(user, request) return token - def get_jwt_key(self) -> tuple[str, str]: + @cached_property + def jwt_key(self) -> tuple[str | PRIVATE_KEY_TYPES, str]: """Get either the configured certificate or the client secret""" if not self.signing_key: # No Certificate at all, assume HS256 @@ -267,9 +270,9 @@ class OAuth2Provider(Provider): key: CertificateKeyPair = self.signing_key private_key = key.private_key if isinstance(private_key, RSAPrivateKey): - return key.key_data, JWTAlgorithms.RS256 + return private_key, JWTAlgorithms.RS256 if isinstance(private_key, EllipticCurvePrivateKey): - return key.key_data, JWTAlgorithms.ES256 + return private_key, JWTAlgorithms.ES256 raise Exception(f"Invalid private key type: {type(private_key)}") def get_issuer(self, request: HttpRequest) -> Optional[str]: @@ -312,10 +315,9 @@ class OAuth2Provider(Provider): headers = {} if self.signing_key: headers["kid"] = self.signing_key.kid - key, alg = self.get_jwt_key() + key, alg = self.jwt_key # If the provider does not have an RSA Key assigned, it was switched to Symmetric self.refresh_from_db() - # pyright: reportGeneralTypeIssues=false return encode(payload, key, algorithm=alg, headers=headers) class Meta: diff --git a/authentik/providers/oauth2/tests/test_token_cc.py b/authentik/providers/oauth2/tests/test_token_cc.py index fc7f00d04..6f9f449b4 100644 --- a/authentik/providers/oauth2/tests/test_token_cc.py +++ b/authentik/providers/oauth2/tests/test_token_cc.py @@ -143,7 +143,7 @@ class TestTokenClientCredentials(OAuthTestCase): self.assertEqual(response.status_code, 200) body = loads(response.content.decode()) self.assertEqual(body["token_type"], "bearer") - _, alg = self.provider.get_jwt_key() + _, alg = self.provider.jwt_key jwt = decode( body["access_token"], key=self.provider.signing_key.public_key, diff --git a/authentik/providers/oauth2/tests/test_token_cc_jwt_source.py b/authentik/providers/oauth2/tests/test_token_cc_jwt_source.py index 7890bcfdb..3c69a6054 100644 --- a/authentik/providers/oauth2/tests/test_token_cc_jwt_source.py +++ b/authentik/providers/oauth2/tests/test_token_cc_jwt_source.py @@ -210,7 +210,7 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase): self.assertEqual(response.status_code, 200) body = loads(response.content.decode()) self.assertEqual(body["token_type"], "bearer") - _, alg = self.provider.get_jwt_key() + _, alg = self.provider.jwt_key jwt = decode( body["access_token"], key=self.provider.signing_key.public_key, diff --git a/authentik/providers/oauth2/tests/utils.py b/authentik/providers/oauth2/tests/utils.py index a44142dd9..bfa45324e 100644 --- a/authentik/providers/oauth2/tests/utils.py +++ b/authentik/providers/oauth2/tests/utils.py @@ -29,7 +29,7 @@ class OAuthTestCase(TestCase): def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]: """Validate that all required fields are set""" - key, alg = provider.get_jwt_key() + key, alg = provider.jwt_key if alg != JWTAlgorithms.HS256: key = provider.signing_key.public_key jwt = decode( diff --git a/authentik/providers/oauth2/views/provider.py b/authentik/providers/oauth2/views/provider.py index a23b14f31..7cdd0b781 100644 --- a/authentik/providers/oauth2/views/provider.py +++ b/authentik/providers/oauth2/views/provider.py @@ -38,7 +38,7 @@ class ProviderInfoView(View): ) if SCOPE_OPENID not in scopes: scopes.append(SCOPE_OPENID) - _, supported_alg = provider.get_jwt_key() + _, supported_alg = provider.jwt_key return { "issuer": provider.get_issuer(self.request), "authorization_endpoint": self.request.build_absolute_uri(