providers/oauth2: redirect back correctly with state on AuthorizationError
This commit is contained in:
parent
55322995a1
commit
bcd0686a33
|
@ -1,7 +1,9 @@
|
||||||
"""OAuth errors"""
|
"""OAuth errors"""
|
||||||
|
from typing import Optional
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from authentik.lib.sentry import SentryIgnoredException
|
from authentik.lib.sentry import SentryIgnoredException
|
||||||
|
from authentik.providers.oauth2.models import GrantTypes
|
||||||
|
|
||||||
|
|
||||||
class OAuth2Error(SentryIgnoredException):
|
class OAuth2Error(SentryIgnoredException):
|
||||||
|
@ -98,27 +100,34 @@ class AuthorizeError(OAuth2Error):
|
||||||
"the registration parameter",
|
"the registration parameter",
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, redirect_uri, error, grant_type):
|
def __init__(
|
||||||
|
self,
|
||||||
|
redirect_uri: str,
|
||||||
|
error: str,
|
||||||
|
grant_type: str,
|
||||||
|
state: Optional[str] = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.error = error
|
self.error = error
|
||||||
self.description = self._errors[error]
|
self.description = self._errors[error]
|
||||||
self.redirect_uri = redirect_uri
|
self.redirect_uri = redirect_uri
|
||||||
self.grant_type = grant_type
|
self.grant_type = grant_type
|
||||||
|
self.state = state
|
||||||
|
|
||||||
def create_uri(self, redirect_uri: str, state: str) -> str:
|
def create_uri(self, redirect_uri: str) -> str:
|
||||||
"""Get a redirect URI with the error message"""
|
"""Get a redirect URI with the error message"""
|
||||||
description = quote(str(self.description))
|
description = quote(str(self.description))
|
||||||
|
|
||||||
# See:
|
# See:
|
||||||
# http://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthError
|
# http://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthError
|
||||||
hash_or_question = "#" if self.grant_type == "implicit" else "?"
|
hash_or_question = "#" if self.grant_type == GrantTypes.IMPLICIT else "?"
|
||||||
|
|
||||||
uri = "{0}{1}error={2}&error_description={3}".format(
|
uri = "{0}{1}error={2}&error_description={3}".format(
|
||||||
redirect_uri, hash_or_question, self.error, description
|
redirect_uri, hash_or_question, self.error, description
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add state if present.
|
# Add state if present.
|
||||||
uri = uri + ("&state={0}".format(state) if state else "")
|
uri = uri + ("&state={0}".format(self.state) if self.state else "")
|
||||||
|
|
||||||
return uri
|
return uri
|
||||||
|
|
||||||
|
|
|
@ -91,6 +91,8 @@ class OAuthAuthorizationParams:
|
||||||
# Because in this endpoint we handle both GET
|
# Because in this endpoint we handle both GET
|
||||||
# and POST request.
|
# and POST request.
|
||||||
query_dict = request.POST if request.method == "POST" else request.GET
|
query_dict = request.POST if request.method == "POST" else request.GET
|
||||||
|
state = query_dict.get("state", "")
|
||||||
|
redirect_uri = query_dict.get("redirect_uri", "")
|
||||||
|
|
||||||
response_type = query_dict.get("response_type", "")
|
response_type = query_dict.get("response_type", "")
|
||||||
grant_type = None
|
grant_type = None
|
||||||
|
@ -113,20 +115,21 @@ class OAuthAuthorizationParams:
|
||||||
# Grant type validation.
|
# Grant type validation.
|
||||||
if not grant_type:
|
if not grant_type:
|
||||||
LOGGER.warning("Invalid response type", type=response_type)
|
LOGGER.warning("Invalid response type", type=response_type)
|
||||||
|
raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state)
|
||||||
|
|
||||||
|
if "request" in query_dict:
|
||||||
raise AuthorizeError(
|
raise AuthorizeError(
|
||||||
query_dict.get("redirect_uri", ""),
|
redirect_uri, "request_not_supported", grant_type, state
|
||||||
"unsupported_response_type",
|
|
||||||
grant_type,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
max_age = query_dict.get("max_age")
|
max_age = query_dict.get("max_age")
|
||||||
return OAuthAuthorizationParams(
|
return OAuthAuthorizationParams(
|
||||||
client_id=query_dict.get("client_id", ""),
|
client_id=query_dict.get("client_id", ""),
|
||||||
redirect_uri=query_dict.get("redirect_uri", ""),
|
redirect_uri=redirect_uri,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
grant_type=grant_type,
|
grant_type=grant_type,
|
||||||
scope=query_dict.get("scope", "").split(),
|
scope=query_dict.get("scope", "").split(),
|
||||||
state=query_dict.get("state", ""),
|
state=state,
|
||||||
nonce=query_dict.get("nonce", ""),
|
nonce=query_dict.get("nonce", ""),
|
||||||
prompt=ALLOWED_PROMPT_PARAMS.intersection(
|
prompt=ALLOWED_PROMPT_PARAMS.intersection(
|
||||||
set(query_dict.get("prompt", "").split())
|
set(query_dict.get("prompt", "").split())
|
||||||
|
@ -253,7 +256,7 @@ class OAuthFulfillmentStage(StageView):
|
||||||
return bad_request_message(request, error.description, title=error.error)
|
return bad_request_message(request, error.description, title=error.error)
|
||||||
except AuthorizeError as error:
|
except AuthorizeError as error:
|
||||||
self.executor.stage_invalid()
|
self.executor.stage_invalid()
|
||||||
uri = error.create_uri(self.params.redirect_uri, self.params.state)
|
uri = error.create_uri(self.params.redirect_uri)
|
||||||
return redirect(uri)
|
return redirect(uri)
|
||||||
|
|
||||||
def create_response_uri(self) -> str:
|
def create_response_uri(self) -> str:
|
||||||
|
@ -332,7 +335,10 @@ class OAuthFulfillmentStage(StageView):
|
||||||
except OAuth2Error as error:
|
except OAuth2Error as error:
|
||||||
LOGGER.exception("Error when trying to create response uri", error=error)
|
LOGGER.exception("Error when trying to create response uri", error=error)
|
||||||
raise AuthorizeError(
|
raise AuthorizeError(
|
||||||
self.params.redirect_uri, "server_error", self.params.grant_type
|
self.params.redirect_uri,
|
||||||
|
"server_error",
|
||||||
|
self.params.grant_type,
|
||||||
|
self.params.state,
|
||||||
)
|
)
|
||||||
|
|
||||||
uri = uri._replace(
|
uri = uri._replace(
|
||||||
|
@ -353,6 +359,8 @@ class AuthorizationFlowInitView(PolicyAccessView):
|
||||||
see https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.3.1.2.6"""
|
see https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.3.1.2.6"""
|
||||||
try:
|
try:
|
||||||
self.params = OAuthAuthorizationParams.from_request(self.request)
|
self.params = OAuthAuthorizationParams.from_request(self.request)
|
||||||
|
except AuthorizeError as error:
|
||||||
|
raise RequestValidationError(redirect(error.create_uri(error.redirect_uri)))
|
||||||
except OAuth2Error as error:
|
except OAuth2Error as error:
|
||||||
raise RequestValidationError(
|
raise RequestValidationError(
|
||||||
bad_request_message(self.request, error.description, title=error.error)
|
bad_request_message(self.request, error.description, title=error.error)
|
||||||
|
@ -365,7 +373,7 @@ class AuthorizationFlowInitView(PolicyAccessView):
|
||||||
self.params.redirect_uri, "login_required", self.params.grant_type
|
self.params.redirect_uri, "login_required", self.params.grant_type
|
||||||
)
|
)
|
||||||
raise RequestValidationError(
|
raise RequestValidationError(
|
||||||
redirect(error.create_uri(self.params.redirect_uri, self.params.state))
|
redirect(error.create_uri(self.params.redirect_uri))
|
||||||
)
|
)
|
||||||
|
|
||||||
def resolve_provider_application(self):
|
def resolve_provider_application(self):
|
||||||
|
|
Reference in a new issue