providers/oauth2: make PKCE required for public clients

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer 2023-02-17 18:08:39 +01:00
parent f749027143
commit c6ead3dc49
No known key found for this signature in database
2 changed files with 20 additions and 5 deletions

View File

@ -147,10 +147,14 @@ class AuthorizeError(OAuth2Error):
error: str, error: str,
grant_type: str, grant_type: str,
state: str, state: str,
description: Optional[str] = None,
): ):
super().__init__() super().__init__()
self.error = error self.error = error
self.description = self.errors[error] if description:
self.description = description
else:
self.description = self.errors[error]
self.redirect_uri = redirect_uri self.redirect_uri = redirect_uri
self.grant_type = grant_type self.grant_type = grant_type
self.state = state self.state = state

View File

@ -50,6 +50,7 @@ from authentik.providers.oauth2.id_token import IDToken
from authentik.providers.oauth2.models import ( from authentik.providers.oauth2.models import (
AccessToken, AccessToken,
AuthorizationCode, AuthorizationCode,
ClientTypes,
GrantTypes, GrantTypes,
OAuth2Provider, OAuth2Provider,
ResponseMode, ResponseMode,
@ -158,13 +159,14 @@ class OAuthAuthorizationParams:
request=query_dict.get("request", None), request=query_dict.get("request", None),
max_age=int(max_age) if max_age else None, max_age=int(max_age) if max_age else None,
code_challenge=query_dict.get("code_challenge"), code_challenge=query_dict.get("code_challenge"),
code_challenge_method=query_dict.get("code_challenge_method"), code_challenge_method=query_dict.get("code_challenge_method", "plain"),
) )
def __post_init__(self): def __post_init__(self):
try: self.provider: OAuth2Provider = OAuth2Provider.objects.filter(
self.provider: OAuth2Provider = OAuth2Provider.objects.get(client_id=self.client_id) client_id=self.client_id
except OAuth2Provider.DoesNotExist: ).first()
if not self.provider:
LOGGER.warning("Invalid client identifier", client_id=self.client_id) LOGGER.warning("Invalid client identifier", client_id=self.client_id)
raise ClientIdError(client_id=self.client_id) raise ClientIdError(client_id=self.client_id)
self.check_redirect_uri() self.check_redirect_uri()
@ -251,6 +253,15 @@ class OAuthAuthorizationParams:
def check_code_challenge(self): def check_code_challenge(self):
"""PKCE validation of the transformation method.""" """PKCE validation of the transformation method."""
if self.code_challenge and self.code_challenge_method not in ["plain", "S256"]: if self.code_challenge and self.code_challenge_method not in ["plain", "S256"]:
raise AuthorizeError(
self.redirect_uri,
"invalid_request",
self.grant_type,
self.state,
f"Unsupported challenge method {self.code_challenge_method}",
)
if self.provider.client_type == ClientTypes.PUBLIC and not self.code_challenge:
LOGGER.warning("Public clients require PKCE", client_id=self.provider.client_id)
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state) raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
def create_code(self, request: HttpRequest) -> AuthorizationCode: def create_code(self, request: HttpRequest) -> AuthorizationCode: