sources/oauth: fix redirect loop for source with non-configurable URLs

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-04-17 19:06:12 +02:00
parent 476e57daa2
commit d2dd7d1366
3 changed files with 23 additions and 4 deletions

View File

@ -9,6 +9,7 @@ from requests.models import Response
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik import __version__ from authentik import __version__
from authentik.events.models import Event, EventAction
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
LOGGER = get_logger() LOGGER = get_logger()
@ -59,7 +60,16 @@ class BaseOAuthClient:
args.update(additional) args.update(additional)
params = urlencode(args) params = urlencode(args)
LOGGER.info("redirect args", **args) LOGGER.info("redirect args", **args)
return f"{self.source.authorization_url}?{params}" base_url = self.source.authorization_url
if not self.source.type.urls_customizable:
base_url = self.source.type.authorization_url
if base_url == "":
Event.new(
EventAction.CONFIGURATION_ERROR,
source=self.source,
message="Source has an empty authorization URL.",
).save()
return f"{base_url}?{params}"
def parse_raw_token(self, raw_token: str) -> dict[str, Any]: def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response." "Parse token and secret from raw token response."

View File

@ -28,9 +28,12 @@ class OAuthClient(BaseOAuthClient):
if raw_token is not None and verifier is not None: if raw_token is not None and verifier is not None:
token = self.parse_raw_token(raw_token) token = self.parse_raw_token(raw_token)
try: try:
access_token_url: str = self.source.access_token_url
if not self.source.type.urls_customizable:
access_token_url = self.source.type.access_token_url or ""
response = self.do_request( response = self.do_request(
"post", "post",
self.source.access_token_url, access_token_url,
token=token, token=token,
headers=self._default_headers, headers=self._default_headers,
oauth_verifier=verifier, oauth_verifier=verifier,
@ -48,9 +51,12 @@ class OAuthClient(BaseOAuthClient):
"Fetch the OAuth request token. Only required for OAuth 1.0." "Fetch the OAuth request token. Only required for OAuth 1.0."
callback = self.request.build_absolute_uri(self.callback) callback = self.request.build_absolute_uri(self.callback)
try: try:
request_token_url: str = self.source.request_token_url
if not self.source.type.urls_customizable:
request_token_url = self.source.type.request_token_url or ""
response = self.do_request( response = self.do_request(
"post", "post",
self.source.request_token_url, request_token_url,
headers=self._default_headers, headers=self._default_headers,
oauth_callback=callback, oauth_callback=callback,
) )

View File

@ -56,9 +56,12 @@ class OAuth2Client(BaseOAuthClient):
LOGGER.warning("No code returned by the source") LOGGER.warning("No code returned by the source")
return None return None
try: try:
access_token_url = self.source.access_token_url
if not self.source.type.urls_customizable:
access_token_url = self.source.type.access_token_url or ""
response = self.session.request( response = self.session.request(
"post", "post",
self.source.access_token_url, access_token_url,
data=args, data=args,
headers=self._default_headers, headers=self._default_headers,
) )