providers/oauth2: optimise and cache signing key, prevent key being loaded multiple times

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-12-23 12:04:31 +01:00
parent 6a3a3e5f8d
commit 01da8e1792
No known key found for this signature in database
5 changed files with 11 additions and 9 deletions

View file

@ -4,12 +4,14 @@ import binascii
import json import json
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import cached_property
from hashlib import sha256 from hashlib import sha256
from typing import Any, Optional from typing import Any, Optional
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES
from dacite.core import from_dict from dacite.core import from_dict
from django.db import models from django.db import models
from django.http import HttpRequest from django.http import HttpRequest
@ -259,7 +261,8 @@ class OAuth2Provider(Provider):
token.access_token = token.create_access_token(user, request) token.access_token = token.create_access_token(user, request)
return token return token
def get_jwt_key(self) -> tuple[str, str]: @cached_property
def jwt_key(self) -> tuple[str | PRIVATE_KEY_TYPES, str]:
"""Get either the configured certificate or the client secret""" """Get either the configured certificate or the client secret"""
if not self.signing_key: if not self.signing_key:
# No Certificate at all, assume HS256 # No Certificate at all, assume HS256
@ -267,9 +270,9 @@ class OAuth2Provider(Provider):
key: CertificateKeyPair = self.signing_key key: CertificateKeyPair = self.signing_key
private_key = key.private_key private_key = key.private_key
if isinstance(private_key, RSAPrivateKey): if isinstance(private_key, RSAPrivateKey):
return key.key_data, JWTAlgorithms.RS256 return private_key, JWTAlgorithms.RS256
if isinstance(private_key, EllipticCurvePrivateKey): if isinstance(private_key, EllipticCurvePrivateKey):
return key.key_data, JWTAlgorithms.ES256 return private_key, JWTAlgorithms.ES256
raise Exception(f"Invalid private key type: {type(private_key)}") raise Exception(f"Invalid private key type: {type(private_key)}")
def get_issuer(self, request: HttpRequest) -> Optional[str]: def get_issuer(self, request: HttpRequest) -> Optional[str]:
@ -312,10 +315,9 @@ class OAuth2Provider(Provider):
headers = {} headers = {}
if self.signing_key: if self.signing_key:
headers["kid"] = self.signing_key.kid headers["kid"] = self.signing_key.kid
key, alg = self.get_jwt_key() key, alg = self.jwt_key
# If the provider does not have an RSA Key assigned, it was switched to Symmetric # If the provider does not have an RSA Key assigned, it was switched to Symmetric
self.refresh_from_db() self.refresh_from_db()
# pyright: reportGeneralTypeIssues=false
return encode(payload, key, algorithm=alg, headers=headers) return encode(payload, key, algorithm=alg, headers=headers)
class Meta: class Meta:

View file

@ -143,7 +143,7 @@ class TestTokenClientCredentials(OAuthTestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
body = loads(response.content.decode()) body = loads(response.content.decode())
self.assertEqual(body["token_type"], "bearer") self.assertEqual(body["token_type"], "bearer")
_, alg = self.provider.get_jwt_key() _, alg = self.provider.jwt_key
jwt = decode( jwt = decode(
body["access_token"], body["access_token"],
key=self.provider.signing_key.public_key, key=self.provider.signing_key.public_key,

View file

@ -210,7 +210,7 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
body = loads(response.content.decode()) body = loads(response.content.decode())
self.assertEqual(body["token_type"], "bearer") self.assertEqual(body["token_type"], "bearer")
_, alg = self.provider.get_jwt_key() _, alg = self.provider.jwt_key
jwt = decode( jwt = decode(
body["access_token"], body["access_token"],
key=self.provider.signing_key.public_key, key=self.provider.signing_key.public_key,

View file

@ -29,7 +29,7 @@ class OAuthTestCase(TestCase):
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]: def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
"""Validate that all required fields are set""" """Validate that all required fields are set"""
key, alg = provider.get_jwt_key() key, alg = provider.jwt_key
if alg != JWTAlgorithms.HS256: if alg != JWTAlgorithms.HS256:
key = provider.signing_key.public_key key = provider.signing_key.public_key
jwt = decode( jwt = decode(

View file

@ -38,7 +38,7 @@ class ProviderInfoView(View):
) )
if SCOPE_OPENID not in scopes: if SCOPE_OPENID not in scopes:
scopes.append(SCOPE_OPENID) scopes.append(SCOPE_OPENID)
_, supported_alg = provider.get_jwt_key() _, supported_alg = provider.jwt_key
return { return {
"issuer": provider.get_issuer(self.request), "issuer": provider.get_issuer(self.request),
"authorization_endpoint": self.request.build_absolute_uri( "authorization_endpoint": self.request.build_absolute_uri(