providers/oauth2: fix CORS headers not being set for unsuccessful requests
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
84ec70c2a2
commit
3e60e956f4
|
@ -50,7 +50,7 @@ def cors_allow(request: HttpRequest, response: HttpResponse, *allowed_origins: s
|
||||||
if not allowed:
|
if not allowed:
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
"CORS: Origin is not an allowed origin",
|
"CORS: Origin is not an allowed origin",
|
||||||
requested=origin,
|
requested=received_origin,
|
||||||
allowed=allowed_origins,
|
allowed=allowed_origins,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
|
@ -30,6 +30,7 @@ LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
# pylint: disable=too-many-instance-attributes
|
||||||
class TokenParams:
|
class TokenParams:
|
||||||
"""Token params"""
|
"""Token params"""
|
||||||
|
|
||||||
|
@ -40,6 +41,8 @@ class TokenParams:
|
||||||
state: str
|
state: str
|
||||||
scope: list[str]
|
scope: list[str]
|
||||||
|
|
||||||
|
provider: OAuth2Provider
|
||||||
|
|
||||||
authorization_code: Optional[AuthorizationCode] = None
|
authorization_code: Optional[AuthorizationCode] = None
|
||||||
refresh_token: Optional[RefreshToken] = None
|
refresh_token: Optional[RefreshToken] = None
|
||||||
|
|
||||||
|
@ -47,35 +50,33 @@ class TokenParams:
|
||||||
|
|
||||||
raw_code: InitVar[str] = ""
|
raw_code: InitVar[str] = ""
|
||||||
raw_token: InitVar[str] = ""
|
raw_token: InitVar[str] = ""
|
||||||
|
request: InitVar[Optional[HttpRequest]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_request(request: HttpRequest) -> "TokenParams":
|
def parse(
|
||||||
"""Extract Token Parameters from http request"""
|
request: HttpRequest,
|
||||||
client_id, client_secret = extract_client_auth(request)
|
provider: OAuth2Provider,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
) -> "TokenParams":
|
||||||
return TokenParams(
|
return TokenParams(
|
||||||
|
# Init vars
|
||||||
|
raw_code=request.POST.get("code", ""),
|
||||||
|
raw_token=request.POST.get("refresh_token", ""),
|
||||||
|
request=request,
|
||||||
|
# Regular params
|
||||||
|
provider=provider,
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
client_secret=client_secret,
|
client_secret=client_secret,
|
||||||
redirect_uri=request.POST.get("redirect_uri", ""),
|
redirect_uri=request.POST.get("redirect_uri", ""),
|
||||||
grant_type=request.POST.get("grant_type", ""),
|
grant_type=request.POST.get("grant_type", ""),
|
||||||
raw_code=request.POST.get("code", ""),
|
|
||||||
raw_token=request.POST.get("refresh_token", ""),
|
|
||||||
state=request.POST.get("state", ""),
|
state=request.POST.get("state", ""),
|
||||||
scope=request.POST.get("scope", "").split(),
|
scope=request.POST.get("scope", "").split(),
|
||||||
# PKCE parameter.
|
# PKCE parameter.
|
||||||
code_verifier=request.POST.get("code_verifier"),
|
code_verifier=request.POST.get("code_verifier"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self, raw_code, raw_token):
|
def __post_init__(self, raw_code: str, raw_token: str, request: HttpRequest):
|
||||||
try:
|
|
||||||
provider: OAuth2Provider = OAuth2Provider.objects.get(
|
|
||||||
client_id=self.client_id
|
|
||||||
)
|
|
||||||
self.provider = provider
|
|
||||||
except OAuth2Provider.DoesNotExist:
|
|
||||||
LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id)
|
|
||||||
raise TokenError("invalid_client")
|
|
||||||
|
|
||||||
if self.provider.client_type == ClientTypes.CONFIDENTIAL:
|
if self.provider.client_type == ClientTypes.CONFIDENTIAL:
|
||||||
if self.provider.client_secret != self.client_secret:
|
if self.provider.client_secret != self.client_secret:
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
|
@ -87,7 +88,6 @@ class TokenParams:
|
||||||
|
|
||||||
if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
|
if self.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
|
||||||
self.__post_init_code(raw_code)
|
self.__post_init_code(raw_code)
|
||||||
|
|
||||||
elif self.grant_type == GRANT_TYPE_REFRESH_TOKEN:
|
elif self.grant_type == GRANT_TYPE_REFRESH_TOKEN:
|
||||||
if not raw_token:
|
if not raw_token:
|
||||||
LOGGER.warning("Missing refresh token")
|
LOGGER.warning("Missing refresh token")
|
||||||
|
@ -159,13 +159,14 @@ class TokenParams:
|
||||||
class TokenView(View):
|
class TokenView(View):
|
||||||
"""Generate tokens for clients"""
|
"""Generate tokens for clients"""
|
||||||
|
|
||||||
|
provider: Optional[OAuth2Provider] = None
|
||||||
params: Optional[TokenParams] = None
|
params: Optional[TokenParams] = None
|
||||||
|
|
||||||
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||||
response = super().dispatch(request, *args, **kwargs)
|
response = super().dispatch(request, *args, **kwargs)
|
||||||
allowed_origins = []
|
allowed_origins = []
|
||||||
if self.params:
|
if self.provider:
|
||||||
allowed_origins = self.params.provider.redirect_uris.split("\n")
|
allowed_origins = self.provider.redirect_uris.split("\n")
|
||||||
cors_allow(self.request, response, *allowed_origins)
|
cors_allow(self.request, response, *allowed_origins)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -175,19 +176,32 @@ class TokenView(View):
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
def post(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Generate tokens for clients"""
|
"""Generate tokens for clients"""
|
||||||
try:
|
try:
|
||||||
self.params = TokenParams.from_request(request)
|
client_id, client_secret = extract_client_auth(request)
|
||||||
|
try:
|
||||||
|
self.provider: OAuth2Provider = OAuth2Provider.objects.get(
|
||||||
|
client_id=client_id
|
||||||
|
)
|
||||||
|
except OAuth2Provider.DoesNotExist:
|
||||||
|
LOGGER.warning(
|
||||||
|
"OAuth2Provider does not exist", client_id=self.client_id
|
||||||
|
)
|
||||||
|
raise TokenError("invalid_client")
|
||||||
|
|
||||||
|
self.params = TokenParams.parse(
|
||||||
|
request, self.provider, client_id, client_secret
|
||||||
|
)
|
||||||
|
|
||||||
if self.params.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
|
if self.params.grant_type == GRANT_TYPE_AUTHORIZATION_CODE:
|
||||||
return TokenResponse(self.create_code_response_dic())
|
return TokenResponse(self.create_code_response())
|
||||||
if self.params.grant_type == GRANT_TYPE_REFRESH_TOKEN:
|
if self.params.grant_type == GRANT_TYPE_REFRESH_TOKEN:
|
||||||
return TokenResponse(self.create_refresh_response_dic())
|
return TokenResponse(self.create_refresh_response())
|
||||||
raise ValueError(f"Invalid grant_type: {self.params.grant_type}")
|
raise ValueError(f"Invalid grant_type: {self.params.grant_type}")
|
||||||
except TokenError as error:
|
except TokenError as error:
|
||||||
return TokenResponse(error.create_dict(), status=400)
|
return TokenResponse(error.create_dict(), status=400)
|
||||||
except UserAuthError as error:
|
except UserAuthError as error:
|
||||||
return TokenResponse(error.create_dict(), status=403)
|
return TokenResponse(error.create_dict(), status=403)
|
||||||
|
|
||||||
def create_code_response_dic(self) -> dict[str, Any]:
|
def create_code_response(self) -> dict[str, Any]:
|
||||||
"""See https://tools.ietf.org/html/rfc6749#section-4.1"""
|
"""See https://tools.ietf.org/html/rfc6749#section-4.1"""
|
||||||
|
|
||||||
refresh_token = self.params.authorization_code.provider.create_refresh_token(
|
refresh_token = self.params.authorization_code.provider.create_refresh_token(
|
||||||
|
@ -211,7 +225,7 @@ class TokenView(View):
|
||||||
# We don't need to store the code anymore.
|
# We don't need to store the code anymore.
|
||||||
self.params.authorization_code.delete()
|
self.params.authorization_code.delete()
|
||||||
|
|
||||||
response_dict = {
|
return {
|
||||||
"access_token": refresh_token.access_token,
|
"access_token": refresh_token.access_token,
|
||||||
"refresh_token": refresh_token.refresh_token,
|
"refresh_token": refresh_token.refresh_token,
|
||||||
"token_type": "bearer",
|
"token_type": "bearer",
|
||||||
|
@ -223,9 +237,7 @@ class TokenView(View):
|
||||||
"id_token": refresh_token.provider.encode(refresh_token.id_token.to_dict()),
|
"id_token": refresh_token.provider.encode(refresh_token.id_token.to_dict()),
|
||||||
}
|
}
|
||||||
|
|
||||||
return response_dict
|
def create_refresh_response(self) -> dict[str, Any]:
|
||||||
|
|
||||||
def create_refresh_response_dic(self) -> dict[str, Any]:
|
|
||||||
"""See https://tools.ietf.org/html/rfc6749#section-6"""
|
"""See https://tools.ietf.org/html/rfc6749#section-6"""
|
||||||
|
|
||||||
unauthorized_scopes = set(self.params.scope) - set(
|
unauthorized_scopes = set(self.params.scope) - set(
|
||||||
|
@ -256,7 +268,7 @@ class TokenView(View):
|
||||||
# Forget the old token.
|
# Forget the old token.
|
||||||
self.params.refresh_token.delete()
|
self.params.refresh_token.delete()
|
||||||
|
|
||||||
dic = {
|
return {
|
||||||
"access_token": refresh_token.access_token,
|
"access_token": refresh_token.access_token,
|
||||||
"refresh_token": refresh_token.refresh_token,
|
"refresh_token": refresh_token.refresh_token,
|
||||||
"token_type": "bearer",
|
"token_type": "bearer",
|
||||||
|
@ -267,5 +279,3 @@ class TokenView(View):
|
||||||
),
|
),
|
||||||
"id_token": self.params.provider.encode(refresh_token.id_token.to_dict()),
|
"id_token": self.params.provider.encode(refresh_token.id_token.to_dict()),
|
||||||
}
|
}
|
||||||
|
|
||||||
return dic
|
|
||||||
|
|
Reference in a new issue