providers/oauth2: if a redirect_uri cannot be parsed as regex, compare strict (#3070)
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
4d8021c403
commit
0cad56ec73
|
@ -47,11 +47,11 @@ def create_test_tenant() -> Tenant:
|
|||
|
||||
def create_test_cert() -> CertificateKeyPair:
|
||||
"""Generate a certificate for testing"""
|
||||
CertificateKeyPair.objects.filter(name="goauthentik.io").delete()
|
||||
builder = CertificateBuilder()
|
||||
builder.common_name = "goauthentik.io"
|
||||
builder.build(
|
||||
subject_alt_names=["goauthentik.io"],
|
||||
validity_days=360,
|
||||
)
|
||||
builder.name = generate_id()
|
||||
return builder.save()
|
||||
|
|
|
@ -53,10 +53,7 @@ class CertificateBuilder:
|
|||
.subject_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(
|
||||
NameOID.COMMON_NAME,
|
||||
self.common_name,
|
||||
),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, self.common_name),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"),
|
||||
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"),
|
||||
]
|
||||
|
@ -65,10 +62,7 @@ class CertificateBuilder:
|
|||
.issuer_name(
|
||||
x509.Name(
|
||||
[
|
||||
x509.NameAttribute(
|
||||
NameOID.COMMON_NAME,
|
||||
f"authentik {__version__}",
|
||||
),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, f"authentik {__version__}"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@ from django.test import RequestFactory
|
|||
from django.urls import reverse
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.flows.challenge import ChallengeTypes
|
||||
from authentik.lib.generators import generate_id, generate_key
|
||||
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
|
||||
|
@ -39,7 +39,7 @@ class TestAuthorize(OAuthTestCase):
|
|||
def test_request(self):
|
||||
"""test request param"""
|
||||
OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://local.invalid/Foo",
|
||||
|
@ -59,7 +59,7 @@ class TestAuthorize(OAuthTestCase):
|
|||
def test_invalid_redirect_uri(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://local.invalid",
|
||||
|
@ -78,10 +78,55 @@ class TestAuthorize(OAuthTestCase):
|
|||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
|
||||
def test_invalid_redirect_uri_empty(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="",
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "+",
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
provider.refresh_from_db()
|
||||
self.assertEqual(provider.redirect_uris, "+")
|
||||
|
||||
def test_invalid_redirect_uri_regex(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://local.invalid?",
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://localhost",
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
|
||||
def test_redirect_uri_invalid_regex(self):
|
||||
"""test missing/invalid redirect URI (invalid regex)"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="+",
|
||||
|
@ -103,7 +148,7 @@ class TestAuthorize(OAuthTestCase):
|
|||
def test_empty_redirect_uri(self):
|
||||
"""test empty redirect URI (configure in provider)"""
|
||||
OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
)
|
||||
|
@ -123,7 +168,7 @@ class TestAuthorize(OAuthTestCase):
|
|||
def test_response_type(self):
|
||||
"""test response_type"""
|
||||
OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://local.invalid/Foo",
|
||||
|
@ -201,7 +246,7 @@ class TestAuthorize(OAuthTestCase):
|
|||
"""Test full authorization"""
|
||||
flow = create_test_flow()
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=flow,
|
||||
redirect_uris="foo://localhost",
|
||||
|
@ -237,12 +282,12 @@ class TestAuthorize(OAuthTestCase):
|
|||
"""Test full authorization"""
|
||||
flow = create_test_flow()
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=flow,
|
||||
redirect_uris="http://localhost",
|
||||
signing_key=create_test_cert(),
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
|
@ -281,12 +326,12 @@ class TestAuthorize(OAuthTestCase):
|
|||
"""Test full authorization (form_post response)"""
|
||||
flow = create_test_flow()
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=flow,
|
||||
redirect_uris="http://localhost",
|
||||
signing_key=create_test_cert(),
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
|
|
|
@ -5,7 +5,7 @@ from django.test import RequestFactory
|
|||
from django.urls import reverse
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.generators import generate_id, generate_key
|
||||
from authentik.providers.oauth2.constants import (
|
||||
|
@ -24,17 +24,17 @@ class TestToken(OAuthTestCase):
|
|||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.factory = RequestFactory()
|
||||
self.app = Application.objects.create(name="test", slug="test")
|
||||
self.app = Application.objects.create(name=generate_id(), slug="test")
|
||||
|
||||
def test_request_auth_code(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id=generate_id(),
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://testserver",
|
||||
signing_key=create_test_cert(),
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
|
@ -56,12 +56,12 @@ class TestToken(OAuthTestCase):
|
|||
def test_request_auth_code_invalid(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id=generate_id(),
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://testserver",
|
||||
signing_key=create_test_cert(),
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
request = self.factory.post(
|
||||
|
@ -79,12 +79,12 @@ class TestToken(OAuthTestCase):
|
|||
def test_request_refresh_token(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id=generate_id(),
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://local.invalid",
|
||||
signing_key=create_test_cert(),
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
|
@ -108,12 +108,12 @@ class TestToken(OAuthTestCase):
|
|||
def test_auth_code_view(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id=generate_id(),
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://local.invalid",
|
||||
signing_key=create_test_cert(),
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
# Needs to be assigned to an application for iss to be set
|
||||
self.app.provider = provider
|
||||
|
@ -150,12 +150,12 @@ class TestToken(OAuthTestCase):
|
|||
def test_refresh_token_view(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id=generate_id(),
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://local.invalid",
|
||||
signing_key=create_test_cert(),
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
# Needs to be assigned to an application for iss to be set
|
||||
self.app.provider = provider
|
||||
|
@ -199,12 +199,12 @@ class TestToken(OAuthTestCase):
|
|||
def test_refresh_token_view_invalid_origin(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id=generate_id(),
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://local.invalid",
|
||||
signing_key=create_test_cert(),
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
user = create_test_admin_user()
|
||||
|
@ -244,12 +244,12 @@ class TestToken(OAuthTestCase):
|
|||
def test_refresh_token_revoke(self):
|
||||
"""test request param"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name="test",
|
||||
name=generate_id(),
|
||||
client_id=generate_id(),
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://testserver",
|
||||
signing_key=create_test_cert(),
|
||||
signing_key=self.keypair,
|
||||
)
|
||||
# Needs to be assigned to an application for iss to be set
|
||||
self.app.provider = provider
|
||||
|
|
|
@ -2,12 +2,15 @@
|
|||
from django.test import TestCase
|
||||
from jwt import decode
|
||||
|
||||
from authentik.core.tests.utils import create_test_cert
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider, RefreshToken
|
||||
|
||||
|
||||
class OAuthTestCase(TestCase):
|
||||
"""OAuth test helpers"""
|
||||
|
||||
keypair: CertificateKeyPair
|
||||
required_jwt_keys = [
|
||||
"exp",
|
||||
"iat",
|
||||
|
@ -17,6 +20,11 @@ class OAuthTestCase(TestCase):
|
|||
"iss",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.keypair = create_test_cert()
|
||||
super().setUpClass()
|
||||
|
||||
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
|
||||
"""Validate that all required fields are set"""
|
||||
key, alg = provider.get_jwt_key()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from re import error as RegexError
|
||||
from re import escape, fullmatch
|
||||
from re import fullmatch
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlsplit, urlunsplit
|
||||
from uuid import uuid4
|
||||
|
@ -181,7 +181,7 @@ class OAuthAuthorizationParams:
|
|||
|
||||
if self.provider.redirect_uris == "":
|
||||
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
|
||||
self.provider.redirect_uris = escape(self.redirect_uri)
|
||||
self.provider.redirect_uris = self.redirect_uri
|
||||
self.provider.save()
|
||||
allowed_redirect_urls = self.provider.redirect_uris.split()
|
||||
|
||||
|
@ -194,14 +194,20 @@ class OAuthAuthorizationParams:
|
|||
try:
|
||||
if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls):
|
||||
LOGGER.warning(
|
||||
"Invalid redirect uri",
|
||||
"Invalid redirect uri (regex comparison)",
|
||||
redirect_uri=self.redirect_uri,
|
||||
expected=allowed_redirect_urls,
|
||||
)
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
||||
except RegexError as exc:
|
||||
LOGGER.warning("Invalid regular expression configured", exc=exc)
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
||||
LOGGER.info("Failed to parse regular expression, checking directly", exc=exc)
|
||||
if not any(x == self.redirect_uri for x in allowed_redirect_urls):
|
||||
LOGGER.warning(
|
||||
"Invalid redirect uri (strict comparison)",
|
||||
redirect_uri=self.redirect_uri,
|
||||
expected=allowed_redirect_urls,
|
||||
)
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
||||
if self.request:
|
||||
raise AuthorizeError(
|
||||
self.redirect_uri, "request_not_supported", self.grant_type, self.state
|
||||
|
|
|
@ -154,7 +154,7 @@ class TokenParams:
|
|||
try:
|
||||
if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls):
|
||||
LOGGER.warning(
|
||||
"Invalid redirect uri",
|
||||
"Invalid redirect uri (regex comparison)",
|
||||
redirect_uri=self.redirect_uri,
|
||||
expected=allowed_redirect_urls,
|
||||
)
|
||||
|
@ -167,13 +167,19 @@ class TokenParams:
|
|||
).from_http(request)
|
||||
raise TokenError("invalid_client")
|
||||
except RegexError as exc:
|
||||
LOGGER.warning("Invalid regular expression configured", exc=exc)
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message="Invalid redirect_uri RegEx configured",
|
||||
provider=self.provider,
|
||||
).from_http(request)
|
||||
raise TokenError("invalid_client")
|
||||
LOGGER.info("Failed to parse regular expression, checking directly", exc=exc)
|
||||
if not any(x == self.redirect_uri for x in allowed_redirect_urls):
|
||||
LOGGER.warning(
|
||||
"Invalid redirect uri (strict comparison)",
|
||||
redirect_uri=self.redirect_uri,
|
||||
expected=allowed_redirect_urls,
|
||||
)
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message="Invalid redirect_uri configured",
|
||||
provider=self.provider,
|
||||
).from_http(request)
|
||||
raise TokenError("invalid_client")
|
||||
|
||||
try:
|
||||
self.authorization_code = AuthorizationCode.objects.get(code=raw_code)
|
||||
|
|
Reference in a new issue