sources/oauth: correctly concatenate URLs to allow custom parameters to be included
closes #3374 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
6356ddd9f3
commit
4c9878313c
|
@ -228,9 +228,9 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
||||||
return DEFAULT_AVATAR
|
return DEFAULT_AVATAR
|
||||||
if mode.startswith("attributes."):
|
if mode.startswith("attributes."):
|
||||||
return get_path_from_dict(self.attributes, mode[11:], default=DEFAULT_AVATAR)
|
return get_path_from_dict(self.attributes, mode[11:], default=DEFAULT_AVATAR)
|
||||||
|
# gravatar uses md5 for their URLs, so md5 can't be avoided
|
||||||
mail_hash = md5(self.email.lower().encode("utf-8")).hexdigest() # nosec
|
mail_hash = md5(self.email.lower().encode("utf-8")).hexdigest() # nosec
|
||||||
if mode == "gravatar":
|
if mode == "gravatar":
|
||||||
# gravatar uses md5 for their URLs, so md5 can't be avoided
|
|
||||||
parameters = [
|
parameters = [
|
||||||
("s", "158"),
|
("s", "158"),
|
||||||
("r", "g"),
|
("r", "g"),
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""OAuth Clients"""
|
"""OAuth Clients"""
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from urllib.parse import quote, urlencode
|
from urllib.parse import parse_qs, quote, urlencode, urlparse, urlunparse
|
||||||
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from requests import Session
|
from requests import Session
|
||||||
|
@ -32,11 +32,11 @@ class BaseOAuthClient:
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
|
||||||
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||||
"Fetch access token from callback request."
|
"""Fetch access token from callback request."""
|
||||||
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
||||||
|
|
||||||
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
|
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
|
||||||
"Fetch user profile information."
|
"""Fetch user profile information."""
|
||||||
profile_url = self.source.type.profile_url or ""
|
profile_url = self.source.type.profile_url or ""
|
||||||
if self.source.type.urls_customizable and self.source.profile_url:
|
if self.source.type.urls_customizable and self.source.profile_url:
|
||||||
profile_url = self.source.profile_url
|
profile_url = self.source.profile_url
|
||||||
|
@ -50,19 +50,11 @@ class BaseOAuthClient:
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
def get_redirect_args(self) -> dict[str, str]:
|
def get_redirect_args(self) -> dict[str, str]:
|
||||||
"Get request parameters for redirect url."
|
"""Get request parameters for redirect url."""
|
||||||
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
||||||
|
|
||||||
def get_redirect_url(self, parameters=None):
|
def get_redirect_url(self, parameters=None):
|
||||||
"Build authentication redirect url."
|
"""Build authentication redirect url."""
|
||||||
args = self.get_redirect_args()
|
|
||||||
additional = parameters or {}
|
|
||||||
args.update(additional)
|
|
||||||
# Special handling for scope, since it's set as array
|
|
||||||
# to make additional scopes easier
|
|
||||||
args["scope"] = " ".join(sorted(set(args["scope"])))
|
|
||||||
params = urlencode(args, quote_via=quote)
|
|
||||||
LOGGER.info("redirect args", **args)
|
|
||||||
authorization_url = self.source.type.authorization_url or ""
|
authorization_url = self.source.type.authorization_url or ""
|
||||||
if self.source.type.urls_customizable and self.source.authorization_url:
|
if self.source.type.urls_customizable and self.source.authorization_url:
|
||||||
authorization_url = self.source.authorization_url
|
authorization_url = self.source.authorization_url
|
||||||
|
@ -72,10 +64,20 @@ class BaseOAuthClient:
|
||||||
source=self.source,
|
source=self.source,
|
||||||
message="Source has an empty authorization URL.",
|
message="Source has an empty authorization URL.",
|
||||||
).save()
|
).save()
|
||||||
return f"{authorization_url}?{params}"
|
parsed_url = urlparse(authorization_url)
|
||||||
|
parsed_args = parse_qs(parsed_url.query)
|
||||||
|
args = self.get_redirect_args()
|
||||||
|
args.update(parameters or {})
|
||||||
|
args.update(parsed_args)
|
||||||
|
# Special handling for scope, since it's set as array
|
||||||
|
# to make additional scopes easier
|
||||||
|
args["scope"] = " ".join(sorted(set(args["scope"])))
|
||||||
|
params = urlencode(args, quote_via=quote)
|
||||||
|
LOGGER.info("redirect args", **args)
|
||||||
|
return urlunparse(parsed_url._replace(query=params))
|
||||||
|
|
||||||
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||||
"Parse token and secret from raw token response."
|
"""Parse token and secret from raw token response."""
|
||||||
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
||||||
|
|
||||||
def do_request(self, method: str, url: str, **kwargs) -> Response:
|
def do_request(self, method: str, url: str, **kwargs) -> Response:
|
||||||
|
|
|
@ -21,7 +21,7 @@ class OAuthClient(BaseOAuthClient):
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||||
"Fetch access token from callback request."
|
"""Fetch access token from callback request."""
|
||||||
raw_token = self.request.session.get(self.session_key, None)
|
raw_token = self.request.session.get(self.session_key, None)
|
||||||
verifier = self.request.GET.get("oauth_verifier", None)
|
verifier = self.request.GET.get("oauth_verifier", None)
|
||||||
callback = self.request.build_absolute_uri(self.callback)
|
callback = self.request.build_absolute_uri(self.callback)
|
||||||
|
@ -48,7 +48,7 @@ class OAuthClient(BaseOAuthClient):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_request_token(self) -> str:
|
def get_request_token(self) -> str:
|
||||||
"Fetch the OAuth request token. Only required for OAuth 1.0."
|
"""Fetch the OAuth request token. Only required for OAuth 1.0."""
|
||||||
callback = self.request.build_absolute_uri(self.callback)
|
callback = self.request.build_absolute_uri(self.callback)
|
||||||
try:
|
try:
|
||||||
request_token_url = self.source.type.request_token_url or ""
|
request_token_url = self.source.type.request_token_url or ""
|
||||||
|
@ -67,7 +67,7 @@ class OAuthClient(BaseOAuthClient):
|
||||||
return response.text
|
return response.text
|
||||||
|
|
||||||
def get_redirect_args(self) -> dict[str, Any]:
|
def get_redirect_args(self) -> dict[str, Any]:
|
||||||
"Get request parameters for redirect url."
|
"""Get request parameters for redirect url."""
|
||||||
callback = self.request.build_absolute_uri(self.callback)
|
callback = self.request.build_absolute_uri(self.callback)
|
||||||
raw_token = self.get_request_token()
|
raw_token = self.get_request_token()
|
||||||
token = self.parse_raw_token(raw_token)
|
token = self.parse_raw_token(raw_token)
|
||||||
|
@ -78,11 +78,11 @@ class OAuthClient(BaseOAuthClient):
|
||||||
}
|
}
|
||||||
|
|
||||||
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||||
"Parse token and secret from raw token response."
|
"""Parse token and secret from raw token response."""
|
||||||
return dict(parse_qsl(raw_token))
|
return dict(parse_qsl(raw_token))
|
||||||
|
|
||||||
def do_request(self, method: str, url: str, **kwargs) -> Response:
|
def do_request(self, method: str, url: str, **kwargs) -> Response:
|
||||||
"Build remote url request. Constructs necessary auth."
|
"""Build remote url request. Constructs necessary auth."""
|
||||||
resource_owner_key = None
|
resource_owner_key = None
|
||||||
resource_owner_secret = None
|
resource_owner_secret = None
|
||||||
if "token" in kwargs:
|
if "token" in kwargs:
|
||||||
|
|
|
@ -28,7 +28,7 @@ class OAuth2Client(BaseOAuthClient):
|
||||||
return self.request.GET.get(key, default)
|
return self.request.GET.get(key, default)
|
||||||
|
|
||||||
def check_application_state(self) -> bool:
|
def check_application_state(self) -> bool:
|
||||||
"Check optional state parameter."
|
"""Check optional state parameter."""
|
||||||
stored = self.request.session.get(self.session_key, None)
|
stored = self.request.session.get(self.session_key, None)
|
||||||
returned = self.get_request_arg("state", None)
|
returned = self.get_request_arg("state", None)
|
||||||
check = False
|
check = False
|
||||||
|
@ -42,7 +42,7 @@ class OAuth2Client(BaseOAuthClient):
|
||||||
return check
|
return check
|
||||||
|
|
||||||
def get_application_state(self) -> str:
|
def get_application_state(self) -> str:
|
||||||
"Generate state optional parameter."
|
"""Generate state optional parameter."""
|
||||||
return get_random_string(32)
|
return get_random_string(32)
|
||||||
|
|
||||||
def get_client_id(self) -> str:
|
def get_client_id(self) -> str:
|
||||||
|
@ -54,7 +54,7 @@ class OAuth2Client(BaseOAuthClient):
|
||||||
return self.source.consumer_secret
|
return self.source.consumer_secret
|
||||||
|
|
||||||
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||||
"Fetch access token from callback request."
|
"""Fetch access token from callback request."""
|
||||||
callback = self.request.build_absolute_uri(self.callback or self.request.path)
|
callback = self.request.build_absolute_uri(self.callback or self.request.path)
|
||||||
if not self.check_application_state():
|
if not self.check_application_state():
|
||||||
LOGGER.warning("Application state check failed.")
|
LOGGER.warning("Application state check failed.")
|
||||||
|
@ -87,7 +87,7 @@ class OAuth2Client(BaseOAuthClient):
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
def get_redirect_args(self) -> dict[str, str]:
|
def get_redirect_args(self) -> dict[str, str]:
|
||||||
"Get request parameters for redirect url."
|
"""Get request parameters for redirect url."""
|
||||||
callback = self.request.build_absolute_uri(self.callback)
|
callback = self.request.build_absolute_uri(self.callback)
|
||||||
client_id: str = self.get_client_id()
|
client_id: str = self.get_client_id()
|
||||||
args: dict[str, str] = {
|
args: dict[str, str] = {
|
||||||
|
@ -102,7 +102,7 @@ class OAuth2Client(BaseOAuthClient):
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||||
"Parse token and secret from raw token response."
|
"""Parse token and secret from raw token response."""
|
||||||
# Load as json first then parse as query string
|
# Load as json first then parse as query string
|
||||||
try:
|
try:
|
||||||
token_data = loads(raw_token)
|
token_data = loads(raw_token)
|
||||||
|
@ -112,7 +112,7 @@ class OAuth2Client(BaseOAuthClient):
|
||||||
return token_data
|
return token_data
|
||||||
|
|
||||||
def do_request(self, method: str, url: str, **kwargs) -> Response:
|
def do_request(self, method: str, url: str, **kwargs) -> Response:
|
||||||
"Build remote url request. Constructs necessary auth."
|
"""Build remote url request. Constructs necessary auth."""
|
||||||
if "token" in kwargs:
|
if "token" in kwargs:
|
||||||
token = kwargs.pop("token")
|
token = kwargs.pop("token")
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""saml sp views"""
|
"""saml sp views"""
|
||||||
from urllib.parse import ParseResult, parse_qsl, urlparse, urlunparse
|
from urllib.parse import parse_qsl, urlparse, urlunparse
|
||||||
|
|
||||||
from django.contrib.auth import logout
|
from django.contrib.auth import logout
|
||||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||||
|
@ -112,17 +112,8 @@ class InitiateView(View):
|
||||||
url_kwargs = dict(parse_qsl(sso_url.query))
|
url_kwargs = dict(parse_qsl(sso_url.query))
|
||||||
# ... and update it with the SAML args
|
# ... and update it with the SAML args
|
||||||
url_kwargs.update(auth_n_req.build_auth_n_detached())
|
url_kwargs.update(auth_n_req.build_auth_n_detached())
|
||||||
# Encode it back into a string
|
# Update the url
|
||||||
res = ParseResult(
|
final_url = urlunparse(sso_url._replace(query=urlencode(url_kwargs)))
|
||||||
scheme=sso_url.scheme,
|
|
||||||
netloc=sso_url.netloc,
|
|
||||||
path=sso_url.path,
|
|
||||||
params=sso_url.params,
|
|
||||||
query=urlencode(url_kwargs),
|
|
||||||
fragment=sso_url.fragment,
|
|
||||||
)
|
|
||||||
# and merge it back into a URL
|
|
||||||
final_url = urlunparse(res)
|
|
||||||
return redirect(final_url)
|
return redirect(final_url)
|
||||||
# As POST Binding we show a form
|
# As POST Binding we show a form
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -27,7 +27,7 @@ func (ws *WebServer) APISentryProxy(rw http.ResponseWriter, r *http.Request) {
|
||||||
rw.WriteHeader(http.StatusBadRequest)
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lines := strings.Split(string(fb.Bytes()), "\n")
|
lines := strings.Split(fb.String(), "\n")
|
||||||
if len(lines) < 1 {
|
if len(lines) < 1 {
|
||||||
rw.WriteHeader(http.StatusBadRequest)
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
|
Reference in New Issue