diff --git a/authentik/providers/oauth2/errors.py b/authentik/providers/oauth2/errors.py index 53e338256..f82b434bd 100644 --- a/authentik/providers/oauth2/errors.py +++ b/authentik/providers/oauth2/errors.py @@ -147,10 +147,14 @@ class AuthorizeError(OAuth2Error): error: str, grant_type: str, state: str, + description: Optional[str] = None, ): super().__init__() self.error = error - self.description = self.errors[error] + if description: + self.description = description + else: + self.description = self.errors[error] self.redirect_uri = redirect_uri self.grant_type = grant_type self.state = state diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index 98524aeff..34a068d39 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -50,6 +50,7 @@ from authentik.providers.oauth2.id_token import IDToken from authentik.providers.oauth2.models import ( AccessToken, AuthorizationCode, + ClientTypes, GrantTypes, OAuth2Provider, ResponseMode, @@ -158,13 +159,14 @@ class OAuthAuthorizationParams: request=query_dict.get("request", None), max_age=int(max_age) if max_age else None, code_challenge=query_dict.get("code_challenge"), - code_challenge_method=query_dict.get("code_challenge_method"), + code_challenge_method=query_dict.get("code_challenge_method", "plain"), ) def __post_init__(self): - try: - self.provider: OAuth2Provider = OAuth2Provider.objects.get(client_id=self.client_id) - except OAuth2Provider.DoesNotExist: + self.provider: OAuth2Provider = OAuth2Provider.objects.filter( + client_id=self.client_id + ).first() + if not self.provider: LOGGER.warning("Invalid client identifier", client_id=self.client_id) raise ClientIdError(client_id=self.client_id) self.check_redirect_uri() @@ -251,6 +253,15 @@ class OAuthAuthorizationParams: def check_code_challenge(self): """PKCE validation of the transformation method.""" if self.code_challenge and self.code_challenge_method not in ["plain", "S256"]: + raise AuthorizeError( + self.redirect_uri, + "invalid_request", + self.grant_type, + self.state, + f"Unsupported challenge method {self.code_challenge_method}", + ) + if self.provider.client_type == ClientTypes.PUBLIC and not self.code_challenge: + LOGGER.warning("Public clients require PKCE", client_id=self.provider.client_id) raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state) def create_code(self, request: HttpRequest) -> AuthorizationCode: