providers/oauth2: check redirect_uri before request object
This commit is contained in:
parent
7f7b7e37c1
commit
b04c9a2098
|
@ -6,7 +6,6 @@ from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||||
from django.utils.cache import patch_vary_headers
|
from django.utils.cache import patch_vary_headers
|
||||||
from jwkest.jwt import JWT
|
|
||||||
from structlog import get_logger
|
from structlog import get_logger
|
||||||
|
|
||||||
from authentik.providers.oauth2.errors import BearerTokenError
|
from authentik.providers.oauth2.errors import BearerTokenError
|
||||||
|
@ -140,17 +139,3 @@ def protected_resource_view(scopes: List[str]):
|
||||||
return view_wrapper
|
return view_wrapper
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def client_id_from_id_token(id_token):
|
|
||||||
"""
|
|
||||||
Extracts the client id from a JSON Web Token (JWT).
|
|
||||||
Returns a string or None.
|
|
||||||
"""
|
|
||||||
payload = JWT().unpack(id_token).payload()
|
|
||||||
aud = payload.get("aud", None)
|
|
||||||
if aud is None:
|
|
||||||
return None
|
|
||||||
if isinstance(aud, list):
|
|
||||||
return aud[0]
|
|
||||||
return aud
|
|
||||||
|
|
|
@ -62,6 +62,7 @@ ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSNET, PROMPT_LOGIN}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
# pylint: disable=too-many-instance-attributes
|
||||||
class OAuthAuthorizationParams:
|
class OAuthAuthorizationParams:
|
||||||
"""Parameteres required to authorize an OAuth Client"""
|
"""Parameteres required to authorize an OAuth Client"""
|
||||||
|
|
||||||
|
@ -76,6 +77,8 @@ class OAuthAuthorizationParams:
|
||||||
|
|
||||||
provider: OAuth2Provider = field(default_factory=OAuth2Provider)
|
provider: OAuth2Provider = field(default_factory=OAuth2Provider)
|
||||||
|
|
||||||
|
request: Optional[str] = None
|
||||||
|
|
||||||
max_age: Optional[int] = None
|
max_age: Optional[int] = None
|
||||||
|
|
||||||
code_challenge: Optional[str] = None
|
code_challenge: Optional[str] = None
|
||||||
|
@ -118,11 +121,6 @@ class OAuthAuthorizationParams:
|
||||||
LOGGER.warning("Invalid response type", type=response_type)
|
LOGGER.warning("Invalid response type", type=response_type)
|
||||||
raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state)
|
raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state)
|
||||||
|
|
||||||
if "request" in query_dict:
|
|
||||||
raise AuthorizeError(
|
|
||||||
redirect_uri, "request_not_supported", grant_type, state
|
|
||||||
)
|
|
||||||
|
|
||||||
max_age = query_dict.get("max_age")
|
max_age = query_dict.get("max_age")
|
||||||
return OAuthAuthorizationParams(
|
return OAuthAuthorizationParams(
|
||||||
client_id=query_dict.get("client_id", ""),
|
client_id=query_dict.get("client_id", ""),
|
||||||
|
@ -135,6 +133,7 @@ class OAuthAuthorizationParams:
|
||||||
prompt=ALLOWED_PROMPT_PARAMS.intersection(
|
prompt=ALLOWED_PROMPT_PARAMS.intersection(
|
||||||
set(query_dict.get("prompt", "").split())
|
set(query_dict.get("prompt", "").split())
|
||||||
),
|
),
|
||||||
|
request=query_dict.get("request", None),
|
||||||
max_age=int(max_age) if max_age else None,
|
max_age=int(max_age) if max_age else None,
|
||||||
code_challenge=query_dict.get("code_challenge"),
|
code_challenge=query_dict.get("code_challenge"),
|
||||||
code_challenge_method=query_dict.get("code_challenge_method"),
|
code_challenge_method=query_dict.get("code_challenge_method"),
|
||||||
|
@ -148,9 +147,14 @@ class OAuthAuthorizationParams:
|
||||||
except OAuth2Provider.DoesNotExist:
|
except OAuth2Provider.DoesNotExist:
|
||||||
LOGGER.warning("Invalid client identifier", client_id=self.client_id)
|
LOGGER.warning("Invalid client identifier", client_id=self.client_id)
|
||||||
raise ClientIdError()
|
raise ClientIdError()
|
||||||
is_open_id = SCOPE_OPENID in self.scope
|
self.check_redirect_uri()
|
||||||
|
self.check_scope()
|
||||||
|
self.check_nonce()
|
||||||
|
self.check_response_type()
|
||||||
|
self.check_code_challenge()
|
||||||
|
|
||||||
# Redirect URI validation.
|
def check_redirect_uri(self):
|
||||||
|
"""Redirect URI validation."""
|
||||||
if not self.redirect_uri:
|
if not self.redirect_uri:
|
||||||
LOGGER.warning("Missing redirect uri.")
|
LOGGER.warning("Missing redirect uri.")
|
||||||
raise RedirectUriError()
|
raise RedirectUriError()
|
||||||
|
@ -171,7 +175,14 @@ class OAuthAuthorizationParams:
|
||||||
)
|
)
|
||||||
raise RedirectUriError()
|
raise RedirectUriError()
|
||||||
|
|
||||||
if not is_open_id and (
|
if self.request:
|
||||||
|
raise AuthorizeError(
|
||||||
|
self.redirect_uri, "request_not_supported", self.grant_type, self.state
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_scope(self):
|
||||||
|
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
|
||||||
|
if SCOPE_OPENID not in self.scope and (
|
||||||
self.grant_type == GrantTypes.HYBRID
|
self.grant_type == GrantTypes.HYBRID
|
||||||
or self.response_type
|
or self.response_type
|
||||||
in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
|
in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
|
||||||
|
@ -181,14 +192,20 @@ class OAuthAuthorizationParams:
|
||||||
self.redirect_uri, "invalid_scope", self.grant_type, self.state
|
self.redirect_uri, "invalid_scope", self.grant_type, self.state
|
||||||
)
|
)
|
||||||
|
|
||||||
# Nonce parameter validation.
|
def check_nonce(self):
|
||||||
if is_open_id and self.grant_type == GrantTypes.IMPLICIT and not self.nonce:
|
"""Nonce parameter validation."""
|
||||||
|
if (
|
||||||
|
SCOPE_OPENID in self.scope
|
||||||
|
and self.grant_type == GrantTypes.IMPLICIT
|
||||||
|
and not self.nonce
|
||||||
|
):
|
||||||
raise AuthorizeError(
|
raise AuthorizeError(
|
||||||
self.redirect_uri, "invalid_request", self.grant_type, self.state
|
self.redirect_uri, "invalid_request", self.grant_type, self.state
|
||||||
)
|
)
|
||||||
|
|
||||||
# Response type parameter validation.
|
def check_response_type(self):
|
||||||
if is_open_id:
|
"""Response type parameter validation."""
|
||||||
|
if SCOPE_OPENID in self.scope:
|
||||||
actual_response_type = self.provider.response_type
|
actual_response_type = self.provider.response_type
|
||||||
if "#" in self.provider.response_type:
|
if "#" in self.provider.response_type:
|
||||||
hash_index = actual_response_type.index("#")
|
hash_index = actual_response_type.index("#")
|
||||||
|
@ -198,7 +215,8 @@ class OAuthAuthorizationParams:
|
||||||
self.redirect_uri, "invalid_request", self.grant_type, self.state
|
self.redirect_uri, "invalid_request", self.grant_type, self.state
|
||||||
)
|
)
|
||||||
|
|
||||||
# PKCE validation of the transformation method.
|
def check_code_challenge(self):
|
||||||
|
"""PKCE validation of the transformation method."""
|
||||||
if self.code_challenge:
|
if self.code_challenge:
|
||||||
if not (self.code_challenge_method in ["plain", "S256"]):
|
if not (self.code_challenge_method in ["plain", "S256"]):
|
||||||
raise AuthorizeError(
|
raise AuthorizeError(
|
||||||
|
|
Reference in a new issue