providers/oauth2: launch url: if URL parsing fails, return no launch URL (#5918)
* providers/oauth2: launch url: if URL parsing fails, return no launch URL Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * add test Signed-off-by: Jens Langhammer <jens@goauthentik.io> * only get provider launch URL when no url is set Signed-off-by: Jens Langhammer <jens@goauthentik.io> * only catch value error Signed-off-by: Jens Langhammer <jens@goauthentik.io> * format Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> Signed-off-by: Jens Langhammer <jens@goauthentik.io> Co-authored-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
587385587c
commit
0041cf88f4
|
@ -376,10 +376,10 @@ class Application(SerializerModel, PolicyBindingModel):
|
|||
def get_launch_url(self, user: Optional["User"] = None) -> Optional[str]:
|
||||
"""Get launch URL if set, otherwise attempt to get launch URL based on provider."""
|
||||
url = None
|
||||
if provider := self.get_provider():
|
||||
url = provider.launch_url
|
||||
if self.meta_launch_url:
|
||||
url = self.meta_launch_url
|
||||
elif provider := self.get_provider():
|
||||
url = provider.launch_url
|
||||
if user and url:
|
||||
if isinstance(user, SimpleLazyObject):
|
||||
user._setup()
|
||||
|
|
|
@ -17,6 +17,7 @@ from django.urls import reverse
|
|||
from django.utils.translation import gettext_lazy as _
|
||||
from jwt import encode
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
|
@ -26,6 +27,8 @@ from authentik.lib.utils.time import timedelta_string_validator
|
|||
from authentik.providers.oauth2.id_token import IDToken, SubModes
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def generate_client_secret() -> str:
|
||||
"""Generate client secret with adequate length"""
|
||||
|
@ -251,8 +254,12 @@ class OAuth2Provider(Provider):
|
|||
if self.redirect_uris == "":
|
||||
return None
|
||||
main_url = self.redirect_uris.split("\n", maxsplit=1)[0]
|
||||
launch_url = urlparse(main_url)._replace(path="")
|
||||
return urlunparse(launch_url)
|
||||
try:
|
||||
launch_url = urlparse(main_url)._replace(path="")
|
||||
return urlunparse(launch_url)
|
||||
except ValueError as exc:
|
||||
LOGGER.warning("Failed to format launch url", exc=exc)
|
||||
return None
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
"""Test OAuth2 API"""
|
||||
from json import loads
|
||||
from sys import version_info
|
||||
from unittest import skipUnless
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
@ -42,3 +44,14 @@ class TestAPI(APITestCase):
|
|||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content.decode())
|
||||
self.assertEqual(body["issuer"], "http://testserver/application/o/test/")
|
||||
|
||||
# https://github.com/goauthentik/authentik/pull/5918
|
||||
@skipUnless(version_info >= (3, 11, 4), "This behaviour is only Python 3.11.4 and up")
|
||||
def test_launch_url(self):
|
||||
"""Test launch_url"""
|
||||
self.provider.redirect_uris = (
|
||||
"https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/\n"
|
||||
)
|
||||
self.provider.save()
|
||||
self.provider.refresh_from_db()
|
||||
self.assertIsNone(self.provider.launch_url)
|
||||
|
|
Reference in a new issue