providers/oauth2: check redirect_uri before request object

This commit is contained in:
Jens Langhammer 2020-12-27 15:54:49 +01:00
parent 7f7b7e37c1
commit b04c9a2098
2 changed files with 31 additions and 28 deletions

View file

@ -6,7 +6,6 @@ from typing import List, Optional, Tuple
from django.http import HttpRequest, HttpResponse, JsonResponse from django.http import HttpRequest, HttpResponse, JsonResponse
from django.utils.cache import patch_vary_headers from django.utils.cache import patch_vary_headers
from jwkest.jwt import JWT
from structlog import get_logger from structlog import get_logger
from authentik.providers.oauth2.errors import BearerTokenError from authentik.providers.oauth2.errors import BearerTokenError
@ -140,17 +139,3 @@ def protected_resource_view(scopes: List[str]):
return view_wrapper return view_wrapper
return wrapper return wrapper
def client_id_from_id_token(id_token):
"""
Extracts the client id from a JSON Web Token (JWT).
Returns a string or None.
"""
payload = JWT().unpack(id_token).payload()
aud = payload.get("aud", None)
if aud is None:
return None
if isinstance(aud, list):
return aud[0]
return aud

View file

@ -62,6 +62,7 @@ ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSNET, PROMPT_LOGIN}
@dataclass @dataclass
# pylint: disable=too-many-instance-attributes
class OAuthAuthorizationParams: class OAuthAuthorizationParams:
"""Parameteres required to authorize an OAuth Client""" """Parameteres required to authorize an OAuth Client"""
@ -76,6 +77,8 @@ class OAuthAuthorizationParams:
provider: OAuth2Provider = field(default_factory=OAuth2Provider) provider: OAuth2Provider = field(default_factory=OAuth2Provider)
request: Optional[str] = None
max_age: Optional[int] = None max_age: Optional[int] = None
code_challenge: Optional[str] = None code_challenge: Optional[str] = None
@ -118,11 +121,6 @@ class OAuthAuthorizationParams:
LOGGER.warning("Invalid response type", type=response_type) LOGGER.warning("Invalid response type", type=response_type)
raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state) raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state)
if "request" in query_dict:
raise AuthorizeError(
redirect_uri, "request_not_supported", grant_type, state
)
max_age = query_dict.get("max_age") max_age = query_dict.get("max_age")
return OAuthAuthorizationParams( return OAuthAuthorizationParams(
client_id=query_dict.get("client_id", ""), client_id=query_dict.get("client_id", ""),
@ -135,6 +133,7 @@ class OAuthAuthorizationParams:
prompt=ALLOWED_PROMPT_PARAMS.intersection( prompt=ALLOWED_PROMPT_PARAMS.intersection(
set(query_dict.get("prompt", "").split()) set(query_dict.get("prompt", "").split())
), ),
request=query_dict.get("request", None),
max_age=int(max_age) if max_age else None, max_age=int(max_age) if max_age else None,
code_challenge=query_dict.get("code_challenge"), code_challenge=query_dict.get("code_challenge"),
code_challenge_method=query_dict.get("code_challenge_method"), code_challenge_method=query_dict.get("code_challenge_method"),
@ -148,9 +147,14 @@ class OAuthAuthorizationParams:
except OAuth2Provider.DoesNotExist: except OAuth2Provider.DoesNotExist:
LOGGER.warning("Invalid client identifier", client_id=self.client_id) LOGGER.warning("Invalid client identifier", client_id=self.client_id)
raise ClientIdError() raise ClientIdError()
is_open_id = SCOPE_OPENID in self.scope self.check_redirect_uri()
self.check_scope()
self.check_nonce()
self.check_response_type()
self.check_code_challenge()
# Redirect URI validation. def check_redirect_uri(self):
"""Redirect URI validation."""
if not self.redirect_uri: if not self.redirect_uri:
LOGGER.warning("Missing redirect uri.") LOGGER.warning("Missing redirect uri.")
raise RedirectUriError() raise RedirectUriError()
@ -171,7 +175,14 @@ class OAuthAuthorizationParams:
) )
raise RedirectUriError() raise RedirectUriError()
if not is_open_id and ( if self.request:
raise AuthorizeError(
self.redirect_uri, "request_not_supported", self.grant_type, self.state
)
def check_scope(self):
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
if SCOPE_OPENID not in self.scope and (
self.grant_type == GrantTypes.HYBRID self.grant_type == GrantTypes.HYBRID
or self.response_type or self.response_type
in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN] in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
@ -181,14 +192,20 @@ class OAuthAuthorizationParams:
self.redirect_uri, "invalid_scope", self.grant_type, self.state self.redirect_uri, "invalid_scope", self.grant_type, self.state
) )
# Nonce parameter validation. def check_nonce(self):
if is_open_id and self.grant_type == GrantTypes.IMPLICIT and not self.nonce: """Nonce parameter validation."""
if (
SCOPE_OPENID in self.scope
and self.grant_type == GrantTypes.IMPLICIT
and not self.nonce
):
raise AuthorizeError( raise AuthorizeError(
self.redirect_uri, "invalid_request", self.grant_type, self.state self.redirect_uri, "invalid_request", self.grant_type, self.state
) )
# Response type parameter validation. def check_response_type(self):
if is_open_id: """Response type parameter validation."""
if SCOPE_OPENID in self.scope:
actual_response_type = self.provider.response_type actual_response_type = self.provider.response_type
if "#" in self.provider.response_type: if "#" in self.provider.response_type:
hash_index = actual_response_type.index("#") hash_index = actual_response_type.index("#")
@ -198,7 +215,8 @@ class OAuthAuthorizationParams:
self.redirect_uri, "invalid_request", self.grant_type, self.state self.redirect_uri, "invalid_request", self.grant_type, self.state
) )
# PKCE validation of the transformation method. def check_code_challenge(self):
"""PKCE validation of the transformation method."""
if self.code_challenge: if self.code_challenge:
if not (self.code_challenge_method in ["plain", "S256"]): if not (self.code_challenge_method in ["plain", "S256"]):
raise AuthorizeError( raise AuthorizeError(