sources/oauth: correctly concatenate URLs to allow custom parameters to be included
closes #3374 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
6356ddd9f3
commit
4c9878313c
|
@ -228,9 +228,9 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
|||
return DEFAULT_AVATAR
|
||||
if mode.startswith("attributes."):
|
||||
return get_path_from_dict(self.attributes, mode[11:], default=DEFAULT_AVATAR)
|
||||
# gravatar uses md5 for their URLs, so md5 can't be avoided
|
||||
mail_hash = md5(self.email.lower().encode("utf-8")).hexdigest() # nosec
|
||||
if mode == "gravatar":
|
||||
# gravatar uses md5 for their URLs, so md5 can't be avoided
|
||||
parameters = [
|
||||
("s", "158"),
|
||||
("r", "g"),
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""OAuth Clients"""
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import quote, urlencode
|
||||
from urllib.parse import parse_qs, quote, urlencode, urlparse, urlunparse
|
||||
|
||||
from django.http import HttpRequest
|
||||
from requests import Session
|
||||
|
@ -32,11 +32,11 @@ class BaseOAuthClient:
|
|||
self.callback = callback
|
||||
|
||||
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||
"Fetch access token from callback request."
|
||||
"""Fetch access token from callback request."""
|
||||
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
||||
|
||||
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
|
||||
"Fetch user profile information."
|
||||
"""Fetch user profile information."""
|
||||
profile_url = self.source.type.profile_url or ""
|
||||
if self.source.type.urls_customizable and self.source.profile_url:
|
||||
profile_url = self.source.profile_url
|
||||
|
@ -50,19 +50,11 @@ class BaseOAuthClient:
|
|||
return response.json()
|
||||
|
||||
def get_redirect_args(self) -> dict[str, str]:
|
||||
"Get request parameters for redirect url."
|
||||
"""Get request parameters for redirect url."""
|
||||
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
||||
|
||||
def get_redirect_url(self, parameters=None):
|
||||
"Build authentication redirect url."
|
||||
args = self.get_redirect_args()
|
||||
additional = parameters or {}
|
||||
args.update(additional)
|
||||
# Special handling for scope, since it's set as array
|
||||
# to make additional scopes easier
|
||||
args["scope"] = " ".join(sorted(set(args["scope"])))
|
||||
params = urlencode(args, quote_via=quote)
|
||||
LOGGER.info("redirect args", **args)
|
||||
"""Build authentication redirect url."""
|
||||
authorization_url = self.source.type.authorization_url or ""
|
||||
if self.source.type.urls_customizable and self.source.authorization_url:
|
||||
authorization_url = self.source.authorization_url
|
||||
|
@ -72,10 +64,20 @@ class BaseOAuthClient:
|
|||
source=self.source,
|
||||
message="Source has an empty authorization URL.",
|
||||
).save()
|
||||
return f"{authorization_url}?{params}"
|
||||
parsed_url = urlparse(authorization_url)
|
||||
parsed_args = parse_qs(parsed_url.query)
|
||||
args = self.get_redirect_args()
|
||||
args.update(parameters or {})
|
||||
args.update(parsed_args)
|
||||
# Special handling for scope, since it's set as array
|
||||
# to make additional scopes easier
|
||||
args["scope"] = " ".join(sorted(set(args["scope"])))
|
||||
params = urlencode(args, quote_via=quote)
|
||||
LOGGER.info("redirect args", **args)
|
||||
return urlunparse(parsed_url._replace(query=params))
|
||||
|
||||
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||
"Parse token and secret from raw token response."
|
||||
"""Parse token and secret from raw token response."""
|
||||
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
||||
|
||||
def do_request(self, method: str, url: str, **kwargs) -> Response:
|
||||
|
|
|
@ -21,7 +21,7 @@ class OAuthClient(BaseOAuthClient):
|
|||
}
|
||||
|
||||
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||
"Fetch access token from callback request."
|
||||
"""Fetch access token from callback request."""
|
||||
raw_token = self.request.session.get(self.session_key, None)
|
||||
verifier = self.request.GET.get("oauth_verifier", None)
|
||||
callback = self.request.build_absolute_uri(self.callback)
|
||||
|
@ -48,7 +48,7 @@ class OAuthClient(BaseOAuthClient):
|
|||
return None
|
||||
|
||||
def get_request_token(self) -> str:
|
||||
"Fetch the OAuth request token. Only required for OAuth 1.0."
|
||||
"""Fetch the OAuth request token. Only required for OAuth 1.0."""
|
||||
callback = self.request.build_absolute_uri(self.callback)
|
||||
try:
|
||||
request_token_url = self.source.type.request_token_url or ""
|
||||
|
@ -67,7 +67,7 @@ class OAuthClient(BaseOAuthClient):
|
|||
return response.text
|
||||
|
||||
def get_redirect_args(self) -> dict[str, Any]:
|
||||
"Get request parameters for redirect url."
|
||||
"""Get request parameters for redirect url."""
|
||||
callback = self.request.build_absolute_uri(self.callback)
|
||||
raw_token = self.get_request_token()
|
||||
token = self.parse_raw_token(raw_token)
|
||||
|
@ -78,11 +78,11 @@ class OAuthClient(BaseOAuthClient):
|
|||
}
|
||||
|
||||
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||
"Parse token and secret from raw token response."
|
||||
"""Parse token and secret from raw token response."""
|
||||
return dict(parse_qsl(raw_token))
|
||||
|
||||
def do_request(self, method: str, url: str, **kwargs) -> Response:
|
||||
"Build remote url request. Constructs necessary auth."
|
||||
"""Build remote url request. Constructs necessary auth."""
|
||||
resource_owner_key = None
|
||||
resource_owner_secret = None
|
||||
if "token" in kwargs:
|
||||
|
|
|
@ -28,7 +28,7 @@ class OAuth2Client(BaseOAuthClient):
|
|||
return self.request.GET.get(key, default)
|
||||
|
||||
def check_application_state(self) -> bool:
|
||||
"Check optional state parameter."
|
||||
"""Check optional state parameter."""
|
||||
stored = self.request.session.get(self.session_key, None)
|
||||
returned = self.get_request_arg("state", None)
|
||||
check = False
|
||||
|
@ -42,7 +42,7 @@ class OAuth2Client(BaseOAuthClient):
|
|||
return check
|
||||
|
||||
def get_application_state(self) -> str:
|
||||
"Generate state optional parameter."
|
||||
"""Generate state optional parameter."""
|
||||
return get_random_string(32)
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
|
@ -54,7 +54,7 @@ class OAuth2Client(BaseOAuthClient):
|
|||
return self.source.consumer_secret
|
||||
|
||||
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||
"Fetch access token from callback request."
|
||||
"""Fetch access token from callback request."""
|
||||
callback = self.request.build_absolute_uri(self.callback or self.request.path)
|
||||
if not self.check_application_state():
|
||||
LOGGER.warning("Application state check failed.")
|
||||
|
@ -87,7 +87,7 @@ class OAuth2Client(BaseOAuthClient):
|
|||
return response.json()
|
||||
|
||||
def get_redirect_args(self) -> dict[str, str]:
|
||||
"Get request parameters for redirect url."
|
||||
"""Get request parameters for redirect url."""
|
||||
callback = self.request.build_absolute_uri(self.callback)
|
||||
client_id: str = self.get_client_id()
|
||||
args: dict[str, str] = {
|
||||
|
@ -102,7 +102,7 @@ class OAuth2Client(BaseOAuthClient):
|
|||
return args
|
||||
|
||||
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||
"Parse token and secret from raw token response."
|
||||
"""Parse token and secret from raw token response."""
|
||||
# Load as json first then parse as query string
|
||||
try:
|
||||
token_data = loads(raw_token)
|
||||
|
@ -112,7 +112,7 @@ class OAuth2Client(BaseOAuthClient):
|
|||
return token_data
|
||||
|
||||
def do_request(self, method: str, url: str, **kwargs) -> Response:
|
||||
"Build remote url request. Constructs necessary auth."
|
||||
"""Build remote url request. Constructs necessary auth."""
|
||||
if "token" in kwargs:
|
||||
token = kwargs.pop("token")
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""saml sp views"""
|
||||
from urllib.parse import ParseResult, parse_qsl, urlparse, urlunparse
|
||||
from urllib.parse import parse_qsl, urlparse, urlunparse
|
||||
|
||||
from django.contrib.auth import logout
|
||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||
|
@ -112,17 +112,8 @@ class InitiateView(View):
|
|||
url_kwargs = dict(parse_qsl(sso_url.query))
|
||||
# ... and update it with the SAML args
|
||||
url_kwargs.update(auth_n_req.build_auth_n_detached())
|
||||
# Encode it back into a string
|
||||
res = ParseResult(
|
||||
scheme=sso_url.scheme,
|
||||
netloc=sso_url.netloc,
|
||||
path=sso_url.path,
|
||||
params=sso_url.params,
|
||||
query=urlencode(url_kwargs),
|
||||
fragment=sso_url.fragment,
|
||||
)
|
||||
# and merge it back into a URL
|
||||
final_url = urlunparse(res)
|
||||
# Update the url
|
||||
final_url = urlunparse(sso_url._replace(query=urlencode(url_kwargs)))
|
||||
return redirect(final_url)
|
||||
# As POST Binding we show a form
|
||||
try:
|
||||
|
|
|
@ -27,7 +27,7 @@ func (ws *WebServer) APISentryProxy(rw http.ResponseWriter, r *http.Request) {
|
|||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
lines := strings.Split(string(fb.Bytes()), "\n")
|
||||
lines := strings.Split(fb.String(), "\n")
|
||||
if len(lines) < 1 {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
|
|
Reference in a new issue