providers/oauth2: fix CORS headers not being set for unsuccessful requests

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-07-03 15:44:36 +02:00
parent 84ec70c2a2
commit 3e60e956f4
2 changed files with 41 additions and 31 deletions

View file

@ -50,7 +50,7 @@ def cors_allow(request: HttpRequest, response: HttpResponse, *allowed_origins: s
if not allowed: if not allowed:
LOGGER.warning( LOGGER.warning(
"CORS: Origin is not an allowed origin", "CORS: Origin is not an allowed origin",
requested=origin, requested=received_origin,
allowed=allowed_origins, allowed=allowed_origins,
) )
return response return response

View file

@ -30,6 +30,7 @@ LOGGER = get_logger()
@dataclass @dataclass
# pylint: disable=too-many-instance-attributes
class TokenParams: class TokenParams:
"""Token params""" """Token params"""
@ -40,6 +41,8 @@ class TokenParams:
state: str state: str
scope: list[str] scope: list[str]
provider: OAuth2Provider
authorization_code: Optional[AuthorizationCode] = None authorization_code: Optional[AuthorizationCode] = None
refresh_token: Optional[RefreshToken] = None refresh_token: Optional[RefreshToken] = None
@ -47,35 +50,33 @@ class TokenParams:
raw_code: InitVar[str] = "" raw_code: InitVar[str] = ""
raw_token: InitVar[str] = "" raw_token: InitVar[str] = ""
request: InitVar[Optional[HttpRequest]] = None
@staticmethod @staticmethod
def from_request(request: HttpRequest) -> "TokenParams": def parse(
"""Extract Token Parameters from http request""" request: HttpRequest,
client_id, client_secret = extract_client_auth(request) provider: OAuth2Provider,
client_id: str,
client_secret: str,
) -> "TokenParams":
return TokenParams( return TokenParams(
# Init vars
raw_code=request.POST.get("code", ""),
raw_token=request.POST.get("refresh_token", ""),
request=request,
# Regular params
provider=provider,
client_id=client_id, client_id=client_id,
client_secret=client_secret, client_secret=client_secret,
redirect_uri=request.POST.get("redirect_uri", ""), redirect_uri=request.POST.get("redirect_uri", ""),
grant_type=request.POST.get("grant_type", ""), grant_type=request.POST.get("grant_type", ""),
raw_code=request.POST.get("code", ""),
raw_token=request.POST.get("refresh_token", ""),
state=request.POST.get("state", ""), state=request.POST.get("state", ""),
scope=request.POST.get("scope", "").split(), scope=request.POST.get("scope", "").split(),
# PKCE parameter. # PKCE parameter.
code_verifier=request.POST.get("code_verifier"), code_verifier=request.POST.get("code_verifier"),
) )
def __post_init__(self, raw_code, raw_token): def __post_init__(self, raw_code: str, raw_token: str, request: HttpRequest):
try:
provider: OAuth2Provider = OAuth2Provider.objects.get(
client_id=self.client_id
)
self.provider = provider
except OAuth2Provider.DoesNotExist:
LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id)
raise TokenError("invalid_client")
if self.provider.client_type == ClientTypes.CONFIDENTIAL: if self.provider.client_type == ClientTypes.CONFIDENTIAL:
if self.provider.client_secret != self.client_secret: if self.provider.client_secret != self.client_secret:
LOGGER.warning( LOGGER.warning(
@ -87,7 +88,6 @@ class TokenParams:
if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE: if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
self.__post_init_code(raw_code) self.__post_init_code(raw_code)
elif self.grant_type == GRANT_TYPE_REFRESH_TOKEN: elif self.grant_type == GRANT_TYPE_REFRESH_TOKEN:
if not raw_token: if not raw_token:
LOGGER.warning("Missing refresh token") LOGGER.warning("Missing refresh token")
@ -159,13 +159,14 @@ class TokenParams:
class TokenView(View): class TokenView(View):
"""Generate tokens for clients""" """Generate tokens for clients"""
provider: Optional[OAuth2Provider] = None
params: Optional[TokenParams] = None params: Optional[TokenParams] = None
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
response = super().dispatch(request, *args, **kwargs) response = super().dispatch(request, *args, **kwargs)
allowed_origins = [] allowed_origins = []
if self.params: if self.provider:
allowed_origins = self.params.provider.redirect_uris.split("\n") allowed_origins = self.provider.redirect_uris.split("\n")
cors_allow(self.request, response, *allowed_origins) cors_allow(self.request, response, *allowed_origins)
return response return response
@ -175,19 +176,32 @@ class TokenView(View):
def post(self, request: HttpRequest) -> HttpResponse: def post(self, request: HttpRequest) -> HttpResponse:
"""Generate tokens for clients""" """Generate tokens for clients"""
try: try:
self.params = TokenParams.from_request(request) client_id, client_secret = extract_client_auth(request)
try:
self.provider: OAuth2Provider = OAuth2Provider.objects.get(
client_id=client_id
)
except OAuth2Provider.DoesNotExist:
LOGGER.warning(
"OAuth2Provider does not exist", client_id=self.client_id
)
raise TokenError("invalid_client")
self.params = TokenParams.parse(
request, self.provider, client_id, client_secret
)
if self.params.grant_type == GRANT_TYPE_AUTHORIZATION_CODE: if self.params.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
return TokenResponse(self.create_code_response_dic()) return TokenResponse(self.create_code_response())
if self.params.grant_type == GRANT_TYPE_REFRESH_TOKEN: if self.params.grant_type == GRANT_TYPE_REFRESH_TOKEN:
return TokenResponse(self.create_refresh_response_dic()) return TokenResponse(self.create_refresh_response())
raise ValueError(f"Invalid grant_type: {self.params.grant_type}") raise ValueError(f"Invalid grant_type: {self.params.grant_type}")
except TokenError as error: except TokenError as error:
return TokenResponse(error.create_dict(), status=400) return TokenResponse(error.create_dict(), status=400)
except UserAuthError as error: except UserAuthError as error:
return TokenResponse(error.create_dict(), status=403) return TokenResponse(error.create_dict(), status=403)
def create_code_response_dic(self) -> dict[str, Any]: def create_code_response(self) -> dict[str, Any]:
"""See https://tools.ietf.org/html/rfc6749#section-4.1""" """See https://tools.ietf.org/html/rfc6749#section-4.1"""
refresh_token = self.params.authorization_code.provider.create_refresh_token( refresh_token = self.params.authorization_code.provider.create_refresh_token(
@ -211,7 +225,7 @@ class TokenView(View):
# We don't need to store the code anymore. # We don't need to store the code anymore.
self.params.authorization_code.delete() self.params.authorization_code.delete()
response_dict = { return {
"access_token": refresh_token.access_token, "access_token": refresh_token.access_token,
"refresh_token": refresh_token.refresh_token, "refresh_token": refresh_token.refresh_token,
"token_type": "bearer", "token_type": "bearer",
@ -223,9 +237,7 @@ class TokenView(View):
"id_token": refresh_token.provider.encode(refresh_token.id_token.to_dict()), "id_token": refresh_token.provider.encode(refresh_token.id_token.to_dict()),
} }
return response_dict def create_refresh_response(self) -> dict[str, Any]:
def create_refresh_response_dic(self) -> dict[str, Any]:
"""See https://tools.ietf.org/html/rfc6749#section-6""" """See https://tools.ietf.org/html/rfc6749#section-6"""
unauthorized_scopes = set(self.params.scope) - set( unauthorized_scopes = set(self.params.scope) - set(
@ -256,7 +268,7 @@ class TokenView(View):
# Forget the old token. # Forget the old token.
self.params.refresh_token.delete() self.params.refresh_token.delete()
dic = { return {
"access_token": refresh_token.access_token, "access_token": refresh_token.access_token,
"refresh_token": refresh_token.refresh_token, "refresh_token": refresh_token.refresh_token,
"token_type": "bearer", "token_type": "bearer",
@ -267,5 +279,3 @@ class TokenView(View):
), ),
"id_token": self.params.provider.encode(refresh_token.id_token.to_dict()), "id_token": self.params.provider.encode(refresh_token.id_token.to_dict()),
} }
return dic