diff --git a/authentik/providers/oauth2/utils.py b/authentik/providers/oauth2/utils.py index a022ef04d..a8c65d5f8 100644 --- a/authentik/providers/oauth2/utils.py +++ b/authentik/providers/oauth2/utils.py @@ -50,7 +50,7 @@ def cors_allow(request: HttpRequest, response: HttpResponse, *allowed_origins: s if not allowed: LOGGER.warning( "CORS: Origin is not an allowed origin", - requested=origin, + requested=received_origin, allowed=allowed_origins, ) return response diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index 7c78eed16..40bb55c1e 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -30,6 +30,7 @@ LOGGER = get_logger() @dataclass +# pylint: disable=too-many-instance-attributes class TokenParams: """Token params""" @@ -40,6 +41,8 @@ class TokenParams: state: str scope: list[str] + provider: OAuth2Provider + authorization_code: Optional[AuthorizationCode] = None refresh_token: Optional[RefreshToken] = None @@ -47,35 +50,33 @@ class TokenParams: raw_code: InitVar[str] = "" raw_token: InitVar[str] = "" + request: InitVar[Optional[HttpRequest]] = None @staticmethod - def from_request(request: HttpRequest) -> "TokenParams": - """Extract Token Parameters from http request""" - client_id, client_secret = extract_client_auth(request) - + def parse( + request: HttpRequest, + provider: OAuth2Provider, + client_id: str, + client_secret: str, + ) -> "TokenParams": return TokenParams( + # Init vars + raw_code=request.POST.get("code", ""), + raw_token=request.POST.get("refresh_token", ""), + request=request, + # Regular params + provider=provider, client_id=client_id, client_secret=client_secret, redirect_uri=request.POST.get("redirect_uri", ""), grant_type=request.POST.get("grant_type", ""), - raw_code=request.POST.get("code", ""), - raw_token=request.POST.get("refresh_token", ""), state=request.POST.get("state", ""), scope=request.POST.get("scope", "").split(), # PKCE parameter. code_verifier=request.POST.get("code_verifier"), ) - def __post_init__(self, raw_code, raw_token): - try: - provider: OAuth2Provider = OAuth2Provider.objects.get( - client_id=self.client_id - ) - self.provider = provider - except OAuth2Provider.DoesNotExist: - LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id) - raise TokenError("invalid_client") - + def __post_init__(self, raw_code: str, raw_token: str, request: HttpRequest): if self.provider.client_type == ClientTypes.CONFIDENTIAL: if self.provider.client_secret != self.client_secret: LOGGER.warning( @@ -87,7 +88,6 @@ class TokenParams: if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE: self.__post_init_code(raw_code) - elif self.grant_type == GRANT_TYPE_REFRESH_TOKEN: if not raw_token: LOGGER.warning("Missing refresh token") @@ -159,13 +159,14 @@ class TokenParams: class TokenView(View): """Generate tokens for clients""" + provider: Optional[OAuth2Provider] = None params: Optional[TokenParams] = None def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: response = super().dispatch(request, *args, **kwargs) allowed_origins = [] - if self.params: - allowed_origins = self.params.provider.redirect_uris.split("\n") + if self.provider: + allowed_origins = self.provider.redirect_uris.split("\n") cors_allow(self.request, response, *allowed_origins) return response @@ -175,19 +176,32 @@ class TokenView(View): def post(self, request: HttpRequest) -> HttpResponse: """Generate tokens for clients""" try: - self.params = TokenParams.from_request(request) + client_id, client_secret = extract_client_auth(request) + try: + self.provider: OAuth2Provider = OAuth2Provider.objects.get( + client_id=client_id + ) + except OAuth2Provider.DoesNotExist: + LOGGER.warning( + "OAuth2Provider does not exist", client_id=self.client_id + ) + raise TokenError("invalid_client") + + self.params = TokenParams.parse( + request, self.provider, client_id, client_secret + ) if self.params.grant_type == GRANT_TYPE_AUTHORIZATION_CODE: - return TokenResponse(self.create_code_response_dic()) + return TokenResponse(self.create_code_response()) if self.params.grant_type == GRANT_TYPE_REFRESH_TOKEN: - return TokenResponse(self.create_refresh_response_dic()) + return TokenResponse(self.create_refresh_response()) raise ValueError(f"Invalid grant_type: {self.params.grant_type}") except TokenError as error: return TokenResponse(error.create_dict(), status=400) except UserAuthError as error: return TokenResponse(error.create_dict(), status=403) - def create_code_response_dic(self) -> dict[str, Any]: + def create_code_response(self) -> dict[str, Any]: """See https://tools.ietf.org/html/rfc6749#section-4.1""" refresh_token = self.params.authorization_code.provider.create_refresh_token( @@ -211,7 +225,7 @@ class TokenView(View): # We don't need to store the code anymore. self.params.authorization_code.delete() - response_dict = { + return { "access_token": refresh_token.access_token, "refresh_token": refresh_token.refresh_token, "token_type": "bearer", @@ -223,9 +237,7 @@ class TokenView(View): "id_token": refresh_token.provider.encode(refresh_token.id_token.to_dict()), } - return response_dict - - def create_refresh_response_dic(self) -> dict[str, Any]: + def create_refresh_response(self) -> dict[str, Any]: """See https://tools.ietf.org/html/rfc6749#section-6""" unauthorized_scopes = set(self.params.scope) - set( @@ -256,7 +268,7 @@ class TokenView(View): # Forget the old token. self.params.refresh_token.delete() - dic = { + return { "access_token": refresh_token.access_token, "refresh_token": refresh_token.refresh_token, "token_type": "bearer", @@ -267,5 +279,3 @@ class TokenView(View): ), "id_token": self.params.provider.encode(refresh_token.id_token.to_dict()), } - - return dic