sources/oauth: correctly concatenate URLs to allow custom parameters to be included

closes #3374

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-08-08 21:17:10 +02:00
parent 6356ddd9f3
commit 4c9878313c
6 changed files with 33 additions and 40 deletions

View File

@ -228,9 +228,9 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
return DEFAULT_AVATAR return DEFAULT_AVATAR
if mode.startswith("attributes."): if mode.startswith("attributes."):
return get_path_from_dict(self.attributes, mode[11:], default=DEFAULT_AVATAR) return get_path_from_dict(self.attributes, mode[11:], default=DEFAULT_AVATAR)
# gravatar uses md5 for their URLs, so md5 can't be avoided
mail_hash = md5(self.email.lower().encode("utf-8")).hexdigest() # nosec mail_hash = md5(self.email.lower().encode("utf-8")).hexdigest() # nosec
if mode == "gravatar": if mode == "gravatar":
# gravatar uses md5 for their URLs, so md5 can't be avoided
parameters = [ parameters = [
("s", "158"), ("s", "158"),
("r", "g"), ("r", "g"),

View File

@ -1,6 +1,6 @@
"""OAuth Clients""" """OAuth Clients"""
from typing import Any, Optional from typing import Any, Optional
from urllib.parse import quote, urlencode from urllib.parse import parse_qs, quote, urlencode, urlparse, urlunparse
from django.http import HttpRequest from django.http import HttpRequest
from requests import Session from requests import Session
@ -32,11 +32,11 @@ class BaseOAuthClient:
self.callback = callback self.callback = callback
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request." """Fetch access token from callback request."""
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information." """Fetch user profile information."""
profile_url = self.source.type.profile_url or "" profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url: if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url profile_url = self.source.profile_url
@ -50,19 +50,11 @@ class BaseOAuthClient:
return response.json() return response.json()
def get_redirect_args(self) -> dict[str, str]: def get_redirect_args(self) -> dict[str, str]:
"Get request parameters for redirect url." """Get request parameters for redirect url."""
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def get_redirect_url(self, parameters=None): def get_redirect_url(self, parameters=None):
"Build authentication redirect url." """Build authentication redirect url."""
args = self.get_redirect_args()
additional = parameters or {}
args.update(additional)
# Special handling for scope, since it's set as array
# to make additional scopes easier
args["scope"] = " ".join(sorted(set(args["scope"])))
params = urlencode(args, quote_via=quote)
LOGGER.info("redirect args", **args)
authorization_url = self.source.type.authorization_url or "" authorization_url = self.source.type.authorization_url or ""
if self.source.type.urls_customizable and self.source.authorization_url: if self.source.type.urls_customizable and self.source.authorization_url:
authorization_url = self.source.authorization_url authorization_url = self.source.authorization_url
@ -72,10 +64,20 @@ class BaseOAuthClient:
source=self.source, source=self.source,
message="Source has an empty authorization URL.", message="Source has an empty authorization URL.",
).save() ).save()
return f"{authorization_url}?{params}" parsed_url = urlparse(authorization_url)
parsed_args = parse_qs(parsed_url.query)
args = self.get_redirect_args()
args.update(parameters or {})
args.update(parsed_args)
# Special handling for scope, since it's set as array
# to make additional scopes easier
args["scope"] = " ".join(sorted(set(args["scope"])))
params = urlencode(args, quote_via=quote)
LOGGER.info("redirect args", **args)
return urlunparse(parsed_url._replace(query=params))
def parse_raw_token(self, raw_token: str) -> dict[str, Any]: def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response." """Parse token and secret from raw token response."""
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def do_request(self, method: str, url: str, **kwargs) -> Response: def do_request(self, method: str, url: str, **kwargs) -> Response:

View File

@ -21,7 +21,7 @@ class OAuthClient(BaseOAuthClient):
} }
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request." """Fetch access token from callback request."""
raw_token = self.request.session.get(self.session_key, None) raw_token = self.request.session.get(self.session_key, None)
verifier = self.request.GET.get("oauth_verifier", None) verifier = self.request.GET.get("oauth_verifier", None)
callback = self.request.build_absolute_uri(self.callback) callback = self.request.build_absolute_uri(self.callback)
@ -48,7 +48,7 @@ class OAuthClient(BaseOAuthClient):
return None return None
def get_request_token(self) -> str: def get_request_token(self) -> str:
"Fetch the OAuth request token. Only required for OAuth 1.0." """Fetch the OAuth request token. Only required for OAuth 1.0."""
callback = self.request.build_absolute_uri(self.callback) callback = self.request.build_absolute_uri(self.callback)
try: try:
request_token_url = self.source.type.request_token_url or "" request_token_url = self.source.type.request_token_url or ""
@ -67,7 +67,7 @@ class OAuthClient(BaseOAuthClient):
return response.text return response.text
def get_redirect_args(self) -> dict[str, Any]: def get_redirect_args(self) -> dict[str, Any]:
"Get request parameters for redirect url." """Get request parameters for redirect url."""
callback = self.request.build_absolute_uri(self.callback) callback = self.request.build_absolute_uri(self.callback)
raw_token = self.get_request_token() raw_token = self.get_request_token()
token = self.parse_raw_token(raw_token) token = self.parse_raw_token(raw_token)
@ -78,11 +78,11 @@ class OAuthClient(BaseOAuthClient):
} }
def parse_raw_token(self, raw_token: str) -> dict[str, Any]: def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response." """Parse token and secret from raw token response."""
return dict(parse_qsl(raw_token)) return dict(parse_qsl(raw_token))
def do_request(self, method: str, url: str, **kwargs) -> Response: def do_request(self, method: str, url: str, **kwargs) -> Response:
"Build remote url request. Constructs necessary auth." """Build remote url request. Constructs necessary auth."""
resource_owner_key = None resource_owner_key = None
resource_owner_secret = None resource_owner_secret = None
if "token" in kwargs: if "token" in kwargs:

View File

@ -28,7 +28,7 @@ class OAuth2Client(BaseOAuthClient):
return self.request.GET.get(key, default) return self.request.GET.get(key, default)
def check_application_state(self) -> bool: def check_application_state(self) -> bool:
"Check optional state parameter." """Check optional state parameter."""
stored = self.request.session.get(self.session_key, None) stored = self.request.session.get(self.session_key, None)
returned = self.get_request_arg("state", None) returned = self.get_request_arg("state", None)
check = False check = False
@ -42,7 +42,7 @@ class OAuth2Client(BaseOAuthClient):
return check return check
def get_application_state(self) -> str: def get_application_state(self) -> str:
"Generate state optional parameter." """Generate state optional parameter."""
return get_random_string(32) return get_random_string(32)
def get_client_id(self) -> str: def get_client_id(self) -> str:
@ -54,7 +54,7 @@ class OAuth2Client(BaseOAuthClient):
return self.source.consumer_secret return self.source.consumer_secret
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request." """Fetch access token from callback request."""
callback = self.request.build_absolute_uri(self.callback or self.request.path) callback = self.request.build_absolute_uri(self.callback or self.request.path)
if not self.check_application_state(): if not self.check_application_state():
LOGGER.warning("Application state check failed.") LOGGER.warning("Application state check failed.")
@ -87,7 +87,7 @@ class OAuth2Client(BaseOAuthClient):
return response.json() return response.json()
def get_redirect_args(self) -> dict[str, str]: def get_redirect_args(self) -> dict[str, str]:
"Get request parameters for redirect url." """Get request parameters for redirect url."""
callback = self.request.build_absolute_uri(self.callback) callback = self.request.build_absolute_uri(self.callback)
client_id: str = self.get_client_id() client_id: str = self.get_client_id()
args: dict[str, str] = { args: dict[str, str] = {
@ -102,7 +102,7 @@ class OAuth2Client(BaseOAuthClient):
return args return args
def parse_raw_token(self, raw_token: str) -> dict[str, Any]: def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response." """Parse token and secret from raw token response."""
# Load as json first then parse as query string # Load as json first then parse as query string
try: try:
token_data = loads(raw_token) token_data = loads(raw_token)
@ -112,7 +112,7 @@ class OAuth2Client(BaseOAuthClient):
return token_data return token_data
def do_request(self, method: str, url: str, **kwargs) -> Response: def do_request(self, method: str, url: str, **kwargs) -> Response:
"Build remote url request. Constructs necessary auth." """Build remote url request. Constructs necessary auth."""
if "token" in kwargs: if "token" in kwargs:
token = kwargs.pop("token") token = kwargs.pop("token")

View File

@ -1,5 +1,5 @@
"""saml sp views""" """saml sp views"""
from urllib.parse import ParseResult, parse_qsl, urlparse, urlunparse from urllib.parse import parse_qsl, urlparse, urlunparse
from django.contrib.auth import logout from django.contrib.auth import logout
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
@ -112,17 +112,8 @@ class InitiateView(View):
url_kwargs = dict(parse_qsl(sso_url.query)) url_kwargs = dict(parse_qsl(sso_url.query))
# ... and update it with the SAML args # ... and update it with the SAML args
url_kwargs.update(auth_n_req.build_auth_n_detached()) url_kwargs.update(auth_n_req.build_auth_n_detached())
# Encode it back into a string # Update the url
res = ParseResult( final_url = urlunparse(sso_url._replace(query=urlencode(url_kwargs)))
scheme=sso_url.scheme,
netloc=sso_url.netloc,
path=sso_url.path,
params=sso_url.params,
query=urlencode(url_kwargs),
fragment=sso_url.fragment,
)
# and merge it back into a URL
final_url = urlunparse(res)
return redirect(final_url) return redirect(final_url)
# As POST Binding we show a form # As POST Binding we show a form
try: try:

View File

@ -27,7 +27,7 @@ func (ws *WebServer) APISentryProxy(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusBadRequest) rw.WriteHeader(http.StatusBadRequest)
return return
} }
lines := strings.Split(string(fb.Bytes()), "\n") lines := strings.Split(fb.String(), "\n")
if len(lines) < 1 { if len(lines) < 1 {
rw.WriteHeader(http.StatusBadRequest) rw.WriteHeader(http.StatusBadRequest)
return return