providers/oauth2: don't use policy cache for token requests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer 2023-02-07 23:53:50 +01:00
parent c6e638ddc2
commit ec9085ff06
No known key found for this signature in database
1 changed files with 8 additions and 5 deletions

View File

@ -109,6 +109,9 @@ class TokenParams:
): ):
user = self.user if self.user else get_anonymous_user() user = self.user if self.user else get_anonymous_user()
engine = PolicyEngine(app, user, request) engine = PolicyEngine(app, user, request)
# Don't cache as for client_credentials flows the user will not be set
# so we'll get generic cache results
engine.use_cache = False
engine.request.context["oauth_scopes"] = self.scope engine.request.context["oauth_scopes"] = self.scope
engine.request.context["oauth_grant_type"] = self.grant_type engine.request.context["oauth_grant_type"] = self.grant_type
engine.request.context["oauth_code_verifier"] = self.code_verifier engine.request.context["oauth_code_verifier"] = self.code_verifier
@ -322,16 +325,16 @@ class TokenParams:
assertion, options={"verify_signature": False} assertion, options={"verify_signature": False}
) )
except (PyJWTError, ValueError, TypeError, AttributeError) as exc: except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
LOGGER.warning("failed to parse jwt for kid lookup", exc=exc) LOGGER.warning("failed to parse JWT for kid lookup", exc=exc)
raise TokenError("invalid_grant") raise TokenError("invalid_grant")
expected_kid = decode_unvalidated["header"]["kid"] expected_kid = decode_unvalidated["header"]["kid"]
for source in self.provider.jwks_sources.filter( for source in self.provider.jwks_sources.filter(
oidc_jwks__keys__contains=[{"kid": expected_kid}] oidc_jwks__keys__contains=[{"kid": expected_kid}]
): ):
LOGGER.debug("verifying jwt with source", source=source.slug) LOGGER.debug("verifying JWT with source", source=source.slug)
keys = source.oidc_jwks.get("keys", []) keys = source.oidc_jwks.get("keys", [])
for key in keys: for key in keys:
LOGGER.debug("verifying jwt with key", source=source.slug, key=key.get("kid")) LOGGER.debug("verifying JWT with key", source=source.slug, key=key.get("kid"))
try: try:
parsed_key = PyJWK.from_dict(key) parsed_key = PyJWK.from_dict(key)
token = decode( token = decode(
@ -345,13 +348,13 @@ class TokenParams:
# AttributeError is raised when the configured JWK is a private key # AttributeError is raised when the configured JWK is a private key
# and not a public key # and not a public key
except (PyJWTError, ValueError, TypeError, AttributeError) as exc: except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
LOGGER.warning("failed to verify jwt", exc=exc, source=source.slug) LOGGER.warning("failed to verify JWT", exc=exc, source=source.slug)
if not token: if not token:
LOGGER.warning("No token could be verified") LOGGER.warning("No token could be verified")
raise TokenError("invalid_grant") raise TokenError("invalid_grant")
LOGGER.debug("successfully verified jwt with source", source=source.slug) LOGGER.info("successfully verified JWT with source", source=source.slug)
if "exp" in token: if "exp" in token:
exp = datetime.fromtimestamp(token["exp"]) exp = datetime.fromtimestamp(token["exp"])