providers/oauth2: if a redirect_uri cannot be parsed as regex, compare strict (#3070)

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens L 2022-06-10 23:32:57 +02:00 committed by GitHub
parent 4d8021c403
commit 0cad56ec73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 108 additions and 49 deletions

View File

@ -47,11 +47,11 @@ def create_test_tenant() -> Tenant:
def create_test_cert() -> CertificateKeyPair: def create_test_cert() -> CertificateKeyPair:
"""Generate a certificate for testing""" """Generate a certificate for testing"""
CertificateKeyPair.objects.filter(name="goauthentik.io").delete()
builder = CertificateBuilder() builder = CertificateBuilder()
builder.common_name = "goauthentik.io" builder.common_name = "goauthentik.io"
builder.build( builder.build(
subject_alt_names=["goauthentik.io"], subject_alt_names=["goauthentik.io"],
validity_days=360, validity_days=360,
) )
builder.name = generate_id()
return builder.save() return builder.save()

View File

@ -53,10 +53,7 @@ class CertificateBuilder:
.subject_name( .subject_name(
x509.Name( x509.Name(
[ [
x509.NameAttribute( x509.NameAttribute(NameOID.COMMON_NAME, self.common_name),
NameOID.COMMON_NAME,
self.common_name,
),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"),
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"), x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"),
] ]
@ -65,10 +62,7 @@ class CertificateBuilder:
.issuer_name( .issuer_name(
x509.Name( x509.Name(
[ [
x509.NameAttribute( x509.NameAttribute(NameOID.COMMON_NAME, f"authentik {__version__}"),
NameOID.COMMON_NAME,
f"authentik {__version__}",
),
] ]
) )
) )

View File

@ -3,7 +3,7 @@ from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.challenge import ChallengeTypes from authentik.flows.challenge import ChallengeTypes
from authentik.lib.generators import generate_id, generate_key from authentik.lib.generators import generate_id, generate_key
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
@ -39,7 +39,7 @@ class TestAuthorize(OAuthTestCase):
def test_request(self): def test_request(self):
"""test request param""" """test request param"""
OAuth2Provider.objects.create( OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id="test", client_id="test",
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid/Foo", redirect_uris="http://local.invalid/Foo",
@ -59,7 +59,7 @@ class TestAuthorize(OAuthTestCase):
def test_invalid_redirect_uri(self): def test_invalid_redirect_uri(self):
"""test missing/invalid redirect URI""" """test missing/invalid redirect URI"""
OAuth2Provider.objects.create( OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id="test", client_id="test",
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid", redirect_uris="http://local.invalid",
@ -78,10 +78,55 @@ class TestAuthorize(OAuthTestCase):
) )
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
def test_invalid_redirect_uri_empty(self):
"""test missing/invalid redirect URI"""
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="",
)
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
request = self.factory.get(
"/",
data={
"response_type": "code",
"client_id": "test",
"redirect_uri": "+",
},
)
OAuthAuthorizationParams.from_request(request)
provider.refresh_from_db()
self.assertEqual(provider.redirect_uris, "+")
def test_invalid_redirect_uri_regex(self): def test_invalid_redirect_uri_regex(self):
"""test missing/invalid redirect URI""" """test missing/invalid redirect URI"""
OAuth2Provider.objects.create( OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid?",
)
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError):
request = self.factory.get(
"/",
data={
"response_type": "code",
"client_id": "test",
"redirect_uri": "http://localhost",
},
)
OAuthAuthorizationParams.from_request(request)
def test_redirect_uri_invalid_regex(self):
"""test missing/invalid redirect URI (invalid regex)"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test", client_id="test",
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="+", redirect_uris="+",
@ -103,7 +148,7 @@ class TestAuthorize(OAuthTestCase):
def test_empty_redirect_uri(self): def test_empty_redirect_uri(self):
"""test empty redirect URI (configure in provider)""" """test empty redirect URI (configure in provider)"""
OAuth2Provider.objects.create( OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id="test", client_id="test",
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
) )
@ -123,7 +168,7 @@ class TestAuthorize(OAuthTestCase):
def test_response_type(self): def test_response_type(self):
"""test response_type""" """test response_type"""
OAuth2Provider.objects.create( OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id="test", client_id="test",
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid/Foo", redirect_uris="http://local.invalid/Foo",
@ -201,7 +246,7 @@ class TestAuthorize(OAuthTestCase):
"""Test full authorization""" """Test full authorization"""
flow = create_test_flow() flow = create_test_flow()
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id="test", client_id="test",
authorization_flow=flow, authorization_flow=flow,
redirect_uris="foo://localhost", redirect_uris="foo://localhost",
@ -237,12 +282,12 @@ class TestAuthorize(OAuthTestCase):
"""Test full authorization""" """Test full authorization"""
flow = create_test_flow() flow = create_test_flow()
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id="test", client_id="test",
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=flow, authorization_flow=flow,
redirect_uris="http://localhost", redirect_uris="http://localhost",
signing_key=create_test_cert(), signing_key=self.keypair,
) )
Application.objects.create(name="app", slug="app", provider=provider) Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id() state = generate_id()
@ -281,12 +326,12 @@ class TestAuthorize(OAuthTestCase):
"""Test full authorization (form_post response)""" """Test full authorization (form_post response)"""
flow = create_test_flow() flow = create_test_flow()
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id="test", client_id="test",
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=flow, authorization_flow=flow,
redirect_uris="http://localhost", redirect_uris="http://localhost",
signing_key=create_test_cert(), signing_key=self.keypair,
) )
Application.objects.create(name="app", slug="app", provider=provider) Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id() state = generate_id()

View File

@ -5,7 +5,7 @@ from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id, generate_key from authentik.lib.generators import generate_id, generate_key
from authentik.providers.oauth2.constants import ( from authentik.providers.oauth2.constants import (
@ -24,17 +24,17 @@ class TestToken(OAuthTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self.factory = RequestFactory() self.factory = RequestFactory()
self.app = Application.objects.create(name="test", slug="test") self.app = Application.objects.create(name=generate_id(), slug="test")
def test_request_auth_code(self): def test_request_auth_code(self):
"""test request param""" """test request param"""
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id=generate_id(), client_id=generate_id(),
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="http://testserver", redirect_uris="http://testserver",
signing_key=create_test_cert(), signing_key=self.keypair,
) )
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
user = create_test_admin_user() user = create_test_admin_user()
@ -56,12 +56,12 @@ class TestToken(OAuthTestCase):
def test_request_auth_code_invalid(self): def test_request_auth_code_invalid(self):
"""test request param""" """test request param"""
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id=generate_id(), client_id=generate_id(),
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="http://testserver", redirect_uris="http://testserver",
signing_key=create_test_cert(), signing_key=self.keypair,
) )
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
request = self.factory.post( request = self.factory.post(
@ -79,12 +79,12 @@ class TestToken(OAuthTestCase):
def test_request_refresh_token(self): def test_request_refresh_token(self):
"""test request param""" """test request param"""
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id=generate_id(), client_id=generate_id(),
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid", redirect_uris="http://local.invalid",
signing_key=create_test_cert(), signing_key=self.keypair,
) )
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
user = create_test_admin_user() user = create_test_admin_user()
@ -108,12 +108,12 @@ class TestToken(OAuthTestCase):
def test_auth_code_view(self): def test_auth_code_view(self):
"""test request param""" """test request param"""
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id=generate_id(), client_id=generate_id(),
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid", redirect_uris="http://local.invalid",
signing_key=create_test_cert(), signing_key=self.keypair,
) )
# Needs to be assigned to an application for iss to be set # Needs to be assigned to an application for iss to be set
self.app.provider = provider self.app.provider = provider
@ -150,12 +150,12 @@ class TestToken(OAuthTestCase):
def test_refresh_token_view(self): def test_refresh_token_view(self):
"""test request param""" """test request param"""
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id=generate_id(), client_id=generate_id(),
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid", redirect_uris="http://local.invalid",
signing_key=create_test_cert(), signing_key=self.keypair,
) )
# Needs to be assigned to an application for iss to be set # Needs to be assigned to an application for iss to be set
self.app.provider = provider self.app.provider = provider
@ -199,12 +199,12 @@ class TestToken(OAuthTestCase):
def test_refresh_token_view_invalid_origin(self): def test_refresh_token_view_invalid_origin(self):
"""test request param""" """test request param"""
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id=generate_id(), client_id=generate_id(),
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid", redirect_uris="http://local.invalid",
signing_key=create_test_cert(), signing_key=self.keypair,
) )
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
user = create_test_admin_user() user = create_test_admin_user()
@ -244,12 +244,12 @@ class TestToken(OAuthTestCase):
def test_refresh_token_revoke(self): def test_refresh_token_revoke(self):
"""test request param""" """test request param"""
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id=generate_id(), client_id=generate_id(),
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris="http://testserver", redirect_uris="http://testserver",
signing_key=create_test_cert(), signing_key=self.keypair,
) )
# Needs to be assigned to an application for iss to be set # Needs to be assigned to an application for iss to be set
self.app.provider = provider self.app.provider = provider

View File

@ -2,12 +2,15 @@
from django.test import TestCase from django.test import TestCase
from jwt import decode from jwt import decode
from authentik.core.tests.utils import create_test_cert
from authentik.crypto.models import CertificateKeyPair
from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider, RefreshToken from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider, RefreshToken
class OAuthTestCase(TestCase): class OAuthTestCase(TestCase):
"""OAuth test helpers""" """OAuth test helpers"""
keypair: CertificateKeyPair
required_jwt_keys = [ required_jwt_keys = [
"exp", "exp",
"iat", "iat",
@ -17,6 +20,11 @@ class OAuthTestCase(TestCase):
"iss", "iss",
] ]
@classmethod
def setUpClass(cls) -> None:
cls.keypair = create_test_cert()
super().setUpClass()
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider): def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
"""Validate that all required fields are set""" """Validate that all required fields are set"""
key, alg = provider.get_jwt_key() key, alg = provider.get_jwt_key()

View File

@ -2,7 +2,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import timedelta from datetime import timedelta
from re import error as RegexError from re import error as RegexError
from re import escape, fullmatch from re import fullmatch
from typing import Optional from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlsplit, urlunsplit from urllib.parse import parse_qs, urlencode, urlparse, urlsplit, urlunsplit
from uuid import uuid4 from uuid import uuid4
@ -181,7 +181,7 @@ class OAuthAuthorizationParams:
if self.provider.redirect_uris == "": if self.provider.redirect_uris == "":
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri) LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
self.provider.redirect_uris = escape(self.redirect_uri) self.provider.redirect_uris = self.redirect_uri
self.provider.save() self.provider.save()
allowed_redirect_urls = self.provider.redirect_uris.split() allowed_redirect_urls = self.provider.redirect_uris.split()
@ -194,14 +194,20 @@ class OAuthAuthorizationParams:
try: try:
if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls): if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls):
LOGGER.warning( LOGGER.warning(
"Invalid redirect uri", "Invalid redirect uri (regex comparison)",
redirect_uri=self.redirect_uri, redirect_uri=self.redirect_uri,
expected=allowed_redirect_urls, expected=allowed_redirect_urls,
) )
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
except RegexError as exc: except RegexError as exc:
LOGGER.warning("Invalid regular expression configured", exc=exc) LOGGER.info("Failed to parse regular expression, checking directly", exc=exc)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) if not any(x == self.redirect_uri for x in allowed_redirect_urls):
LOGGER.warning(
"Invalid redirect uri (strict comparison)",
redirect_uri=self.redirect_uri,
expected=allowed_redirect_urls,
)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
if self.request: if self.request:
raise AuthorizeError( raise AuthorizeError(
self.redirect_uri, "request_not_supported", self.grant_type, self.state self.redirect_uri, "request_not_supported", self.grant_type, self.state

View File

@ -154,7 +154,7 @@ class TokenParams:
try: try:
if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls): if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls):
LOGGER.warning( LOGGER.warning(
"Invalid redirect uri", "Invalid redirect uri (regex comparison)",
redirect_uri=self.redirect_uri, redirect_uri=self.redirect_uri,
expected=allowed_redirect_urls, expected=allowed_redirect_urls,
) )
@ -167,13 +167,19 @@ class TokenParams:
).from_http(request) ).from_http(request)
raise TokenError("invalid_client") raise TokenError("invalid_client")
except RegexError as exc: except RegexError as exc:
LOGGER.warning("Invalid regular expression configured", exc=exc) LOGGER.info("Failed to parse regular expression, checking directly", exc=exc)
Event.new( if not any(x == self.redirect_uri for x in allowed_redirect_urls):
EventAction.CONFIGURATION_ERROR, LOGGER.warning(
message="Invalid redirect_uri RegEx configured", "Invalid redirect uri (strict comparison)",
provider=self.provider, redirect_uri=self.redirect_uri,
).from_http(request) expected=allowed_redirect_urls,
raise TokenError("invalid_client") )
Event.new(
EventAction.CONFIGURATION_ERROR,
message="Invalid redirect_uri configured",
provider=self.provider,
).from_http(request)
raise TokenError("invalid_client")
try: try:
self.authorization_code = AuthorizationCode.objects.get(code=raw_code) self.authorization_code = AuthorizationCode.objects.get(code=raw_code)