sources/oauth: correctly concatenate URLs to allow custom parameters to be included

closes #3374

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-08-08 21:17:10 +02:00
parent 6356ddd9f3
commit 4c9878313c
6 changed files with 33 additions and 40 deletions

View file

@ -228,9 +228,9 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
return DEFAULT_AVATAR
if mode.startswith("attributes."):
return get_path_from_dict(self.attributes, mode[11:], default=DEFAULT_AVATAR)
# gravatar uses md5 for their URLs, so md5 can't be avoided
mail_hash = md5(self.email.lower().encode("utf-8")).hexdigest() # nosec
if mode == "gravatar":
# gravatar uses md5 for their URLs, so md5 can't be avoided
parameters = [
("s", "158"),
("r", "g"),

View file

@ -1,6 +1,6 @@
"""OAuth Clients"""
from typing import Any, Optional
from urllib.parse import quote, urlencode
from urllib.parse import parse_qs, quote, urlencode, urlparse, urlunparse
from django.http import HttpRequest
from requests import Session
@ -32,11 +32,11 @@ class BaseOAuthClient:
self.callback = callback
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request."
"""Fetch access token from callback request."""
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information."
"""Fetch user profile information."""
profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
@ -50,19 +50,11 @@ class BaseOAuthClient:
return response.json()
def get_redirect_args(self) -> dict[str, str]:
"Get request parameters for redirect url."
"""Get request parameters for redirect url."""
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def get_redirect_url(self, parameters=None):
"Build authentication redirect url."
args = self.get_redirect_args()
additional = parameters or {}
args.update(additional)
# Special handling for scope, since it's set as array
# to make additional scopes easier
args["scope"] = " ".join(sorted(set(args["scope"])))
params = urlencode(args, quote_via=quote)
LOGGER.info("redirect args", **args)
"""Build authentication redirect url."""
authorization_url = self.source.type.authorization_url or ""
if self.source.type.urls_customizable and self.source.authorization_url:
authorization_url = self.source.authorization_url
@ -72,10 +64,20 @@ class BaseOAuthClient:
source=self.source,
message="Source has an empty authorization URL.",
).save()
return f"{authorization_url}?{params}"
parsed_url = urlparse(authorization_url)
parsed_args = parse_qs(parsed_url.query)
args = self.get_redirect_args()
args.update(parameters or {})
args.update(parsed_args)
# Special handling for scope, since it's set as array
# to make additional scopes easier
args["scope"] = " ".join(sorted(set(args["scope"])))
params = urlencode(args, quote_via=quote)
LOGGER.info("redirect args", **args)
return urlunparse(parsed_url._replace(query=params))
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response."
"""Parse token and secret from raw token response."""
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def do_request(self, method: str, url: str, **kwargs) -> Response:

View file

@ -21,7 +21,7 @@ class OAuthClient(BaseOAuthClient):
}
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request."
"""Fetch access token from callback request."""
raw_token = self.request.session.get(self.session_key, None)
verifier = self.request.GET.get("oauth_verifier", None)
callback = self.request.build_absolute_uri(self.callback)
@ -48,7 +48,7 @@ class OAuthClient(BaseOAuthClient):
return None
def get_request_token(self) -> str:
"Fetch the OAuth request token. Only required for OAuth 1.0."
"""Fetch the OAuth request token. Only required for OAuth 1.0."""
callback = self.request.build_absolute_uri(self.callback)
try:
request_token_url = self.source.type.request_token_url or ""
@ -67,7 +67,7 @@ class OAuthClient(BaseOAuthClient):
return response.text
def get_redirect_args(self) -> dict[str, Any]:
"Get request parameters for redirect url."
"""Get request parameters for redirect url."""
callback = self.request.build_absolute_uri(self.callback)
raw_token = self.get_request_token()
token = self.parse_raw_token(raw_token)
@ -78,11 +78,11 @@ class OAuthClient(BaseOAuthClient):
}
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response."
"""Parse token and secret from raw token response."""
return dict(parse_qsl(raw_token))
def do_request(self, method: str, url: str, **kwargs) -> Response:
"Build remote url request. Constructs necessary auth."
"""Build remote url request. Constructs necessary auth."""
resource_owner_key = None
resource_owner_secret = None
if "token" in kwargs:

View file

@ -28,7 +28,7 @@ class OAuth2Client(BaseOAuthClient):
return self.request.GET.get(key, default)
def check_application_state(self) -> bool:
"Check optional state parameter."
"""Check optional state parameter."""
stored = self.request.session.get(self.session_key, None)
returned = self.get_request_arg("state", None)
check = False
@ -42,7 +42,7 @@ class OAuth2Client(BaseOAuthClient):
return check
def get_application_state(self) -> str:
"Generate state optional parameter."
"""Generate state optional parameter."""
return get_random_string(32)
def get_client_id(self) -> str:
@ -54,7 +54,7 @@ class OAuth2Client(BaseOAuthClient):
return self.source.consumer_secret
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request."
"""Fetch access token from callback request."""
callback = self.request.build_absolute_uri(self.callback or self.request.path)
if not self.check_application_state():
LOGGER.warning("Application state check failed.")
@ -87,7 +87,7 @@ class OAuth2Client(BaseOAuthClient):
return response.json()
def get_redirect_args(self) -> dict[str, str]:
"Get request parameters for redirect url."
"""Get request parameters for redirect url."""
callback = self.request.build_absolute_uri(self.callback)
client_id: str = self.get_client_id()
args: dict[str, str] = {
@ -102,7 +102,7 @@ class OAuth2Client(BaseOAuthClient):
return args
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response."
"""Parse token and secret from raw token response."""
# Load as json first then parse as query string
try:
token_data = loads(raw_token)
@ -112,7 +112,7 @@ class OAuth2Client(BaseOAuthClient):
return token_data
def do_request(self, method: str, url: str, **kwargs) -> Response:
"Build remote url request. Constructs necessary auth."
"""Build remote url request. Constructs necessary auth."""
if "token" in kwargs:
token = kwargs.pop("token")

View file

@ -1,5 +1,5 @@
"""saml sp views"""
from urllib.parse import ParseResult, parse_qsl, urlparse, urlunparse
from urllib.parse import parse_qsl, urlparse, urlunparse
from django.contrib.auth import logout
from django.contrib.auth.mixins import LoginRequiredMixin
@ -112,17 +112,8 @@ class InitiateView(View):
url_kwargs = dict(parse_qsl(sso_url.query))
# ... and update it with the SAML args
url_kwargs.update(auth_n_req.build_auth_n_detached())
# Encode it back into a string
res = ParseResult(
scheme=sso_url.scheme,
netloc=sso_url.netloc,
path=sso_url.path,
params=sso_url.params,
query=urlencode(url_kwargs),
fragment=sso_url.fragment,
)
# and merge it back into a URL
final_url = urlunparse(res)
# Update the url
final_url = urlunparse(sso_url._replace(query=urlencode(url_kwargs)))
return redirect(final_url)
# As POST Binding we show a form
try:

View file

@ -27,7 +27,7 @@ func (ws *WebServer) APISentryProxy(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusBadRequest)
return
}
lines := strings.Split(string(fb.Bytes()), "\n")
lines := strings.Split(fb.String(), "\n")
if len(lines) < 1 {
rw.WriteHeader(http.StatusBadRequest)
return