providers/oauth2: optimise client credentials JWT database lookup (#4606)
This commit is contained in:
parent
ccf956d5c6
commit
798245b8db
|
@ -12,7 +12,7 @@ from django.utils.timezone import datetime, now
|
||||||
from django.views import View
|
from django.views import View
|
||||||
from django.views.decorators.csrf import csrf_exempt
|
from django.views.decorators.csrf import csrf_exempt
|
||||||
from guardian.shortcuts import get_anonymous_user
|
from guardian.shortcuts import get_anonymous_user
|
||||||
from jwt import PyJWK, PyJWTError, decode
|
from jwt import PyJWK, PyJWT, PyJWTError, decode
|
||||||
from sentry_sdk.hub import Hub
|
from sentry_sdk.hub import Hub
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
@ -306,7 +306,24 @@ class TokenParams:
|
||||||
|
|
||||||
source: Optional[OAuthSource] = None
|
source: Optional[OAuthSource] = None
|
||||||
parsed_key: Optional[PyJWK] = None
|
parsed_key: Optional[PyJWK] = None
|
||||||
for source in self.provider.jwks_sources.all():
|
|
||||||
|
# Fully decode the JWT without verifying the signature, so we can get access to
|
||||||
|
# the header.
|
||||||
|
# Get the Key ID from the header, and use that to optimise our source query to only find
|
||||||
|
# sources that have a JWK for that Key ID
|
||||||
|
# The Key ID doesn't have a fixed format, but must match between an issued JWT
|
||||||
|
# and whatever is returned by the JWKS endpoint
|
||||||
|
try:
|
||||||
|
decode_unvalidated = PyJWT().decode_complete(
|
||||||
|
assertion, options={"verify_signature": False}
|
||||||
|
)
|
||||||
|
except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
|
||||||
|
LOGGER.warning("failed to parse jwt for kid lookup", exc=exc)
|
||||||
|
raise TokenError("invalid_grant")
|
||||||
|
expected_kid = decode_unvalidated["header"]["kid"]
|
||||||
|
for source in self.provider.jwks_sources.filter(
|
||||||
|
oidc_jwks__keys__contains=[{"kid": expected_kid}]
|
||||||
|
):
|
||||||
LOGGER.debug("verifying jwt with source", source=source.slug)
|
LOGGER.debug("verifying jwt with source", source=source.slug)
|
||||||
keys = source.oidc_jwks.get("keys", [])
|
keys = source.oidc_jwks.get("keys", [])
|
||||||
for key in keys:
|
for key in keys:
|
||||||
|
|
Reference in New Issue