providers/oauth2: optimise and cache signing key, prevent key being loaded multiple times
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
6a3a3e5f8d
commit
01da8e1792
|
@ -4,12 +4,14 @@ import binascii
|
||||||
import json
|
import json
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from functools import cached_property
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
|
||||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
|
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
|
||||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
|
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
|
||||||
|
from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES
|
||||||
from dacite.core import from_dict
|
from dacite.core import from_dict
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
|
@ -259,7 +261,8 @@ class OAuth2Provider(Provider):
|
||||||
token.access_token = token.create_access_token(user, request)
|
token.access_token = token.create_access_token(user, request)
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def get_jwt_key(self) -> tuple[str, str]:
|
@cached_property
|
||||||
|
def jwt_key(self) -> tuple[str | PRIVATE_KEY_TYPES, str]:
|
||||||
"""Get either the configured certificate or the client secret"""
|
"""Get either the configured certificate or the client secret"""
|
||||||
if not self.signing_key:
|
if not self.signing_key:
|
||||||
# No Certificate at all, assume HS256
|
# No Certificate at all, assume HS256
|
||||||
|
@ -267,9 +270,9 @@ class OAuth2Provider(Provider):
|
||||||
key: CertificateKeyPair = self.signing_key
|
key: CertificateKeyPair = self.signing_key
|
||||||
private_key = key.private_key
|
private_key = key.private_key
|
||||||
if isinstance(private_key, RSAPrivateKey):
|
if isinstance(private_key, RSAPrivateKey):
|
||||||
return key.key_data, JWTAlgorithms.RS256
|
return private_key, JWTAlgorithms.RS256
|
||||||
if isinstance(private_key, EllipticCurvePrivateKey):
|
if isinstance(private_key, EllipticCurvePrivateKey):
|
||||||
return key.key_data, JWTAlgorithms.ES256
|
return private_key, JWTAlgorithms.ES256
|
||||||
raise Exception(f"Invalid private key type: {type(private_key)}")
|
raise Exception(f"Invalid private key type: {type(private_key)}")
|
||||||
|
|
||||||
def get_issuer(self, request: HttpRequest) -> Optional[str]:
|
def get_issuer(self, request: HttpRequest) -> Optional[str]:
|
||||||
|
@ -312,10 +315,9 @@ class OAuth2Provider(Provider):
|
||||||
headers = {}
|
headers = {}
|
||||||
if self.signing_key:
|
if self.signing_key:
|
||||||
headers["kid"] = self.signing_key.kid
|
headers["kid"] = self.signing_key.kid
|
||||||
key, alg = self.get_jwt_key()
|
key, alg = self.jwt_key
|
||||||
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
|
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
|
||||||
self.refresh_from_db()
|
self.refresh_from_db()
|
||||||
# pyright: reportGeneralTypeIssues=false
|
|
||||||
return encode(payload, key, algorithm=alg, headers=headers)
|
return encode(payload, key, algorithm=alg, headers=headers)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
|
@ -143,7 +143,7 @@ class TestTokenClientCredentials(OAuthTestCase):
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
body = loads(response.content.decode())
|
body = loads(response.content.decode())
|
||||||
self.assertEqual(body["token_type"], "bearer")
|
self.assertEqual(body["token_type"], "bearer")
|
||||||
_, alg = self.provider.get_jwt_key()
|
_, alg = self.provider.jwt_key
|
||||||
jwt = decode(
|
jwt = decode(
|
||||||
body["access_token"],
|
body["access_token"],
|
||||||
key=self.provider.signing_key.public_key,
|
key=self.provider.signing_key.public_key,
|
||||||
|
|
|
@ -210,7 +210,7 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
body = loads(response.content.decode())
|
body = loads(response.content.decode())
|
||||||
self.assertEqual(body["token_type"], "bearer")
|
self.assertEqual(body["token_type"], "bearer")
|
||||||
_, alg = self.provider.get_jwt_key()
|
_, alg = self.provider.jwt_key
|
||||||
jwt = decode(
|
jwt = decode(
|
||||||
body["access_token"],
|
body["access_token"],
|
||||||
key=self.provider.signing_key.public_key,
|
key=self.provider.signing_key.public_key,
|
||||||
|
|
|
@ -29,7 +29,7 @@ class OAuthTestCase(TestCase):
|
||||||
|
|
||||||
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
|
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
|
||||||
"""Validate that all required fields are set"""
|
"""Validate that all required fields are set"""
|
||||||
key, alg = provider.get_jwt_key()
|
key, alg = provider.jwt_key
|
||||||
if alg != JWTAlgorithms.HS256:
|
if alg != JWTAlgorithms.HS256:
|
||||||
key = provider.signing_key.public_key
|
key = provider.signing_key.public_key
|
||||||
jwt = decode(
|
jwt = decode(
|
||||||
|
|
|
@ -38,7 +38,7 @@ class ProviderInfoView(View):
|
||||||
)
|
)
|
||||||
if SCOPE_OPENID not in scopes:
|
if SCOPE_OPENID not in scopes:
|
||||||
scopes.append(SCOPE_OPENID)
|
scopes.append(SCOPE_OPENID)
|
||||||
_, supported_alg = provider.get_jwt_key()
|
_, supported_alg = provider.jwt_key
|
||||||
return {
|
return {
|
||||||
"issuer": provider.get_issuer(self.request),
|
"issuer": provider.get_issuer(self.request),
|
||||||
"authorization_endpoint": self.request.build_absolute_uri(
|
"authorization_endpoint": self.request.build_absolute_uri(
|
||||||
|
|
Reference in a new issue