diff --git a/authentik/api/authentication.py b/authentik/api/authentication.py index 9566c3d50..3e7faaf85 100644 --- a/authentik/api/authentication.py +++ b/authentik/api/authentication.py @@ -1,6 +1,4 @@ """API Authentication""" -from base64 import b64decode -from binascii import Error from typing import Any, Optional from django.conf import settings @@ -16,38 +14,34 @@ from authentik.outposts.models import Outpost LOGGER = get_logger() -# pylint: disable=too-many-return-statements -def bearer_auth(raw_header: bytes) -> Optional[User]: - """raw_header in the Format of `Bearer dGVzdDp0ZXN0`""" - auth_credentials = raw_header.decode() +def validate_auth(header: bytes) -> str: + """Validate that the header is in a correct format, + returns type and credentials""" + auth_credentials = header.decode().strip() if auth_credentials == "" or " " not in auth_credentials: return None auth_type, _, auth_credentials = auth_credentials.partition(" ") - if auth_type.lower() not in ["basic", "bearer"]: + if auth_type.lower() != "bearer": LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower()) raise AuthenticationFailed("Unsupported authentication type") - password = auth_credentials - if auth_type.lower() == "basic": - try: - auth_credentials = b64decode(auth_credentials.encode()).decode() - except (UnicodeDecodeError, Error): - raise AuthenticationFailed("Malformed header") - # Accept credentials with username and without - if ":" in auth_credentials: - _, _, password = auth_credentials.partition(":") - else: - password = auth_credentials - if password == "": # nosec + if auth_credentials == "": # nosec raise AuthenticationFailed("Malformed header") - tokens = Token.filter_not_expired(key=password, intent=TokenIntents.INTENT_API) - if not tokens.exists(): - user = token_secret_key(password) - if not user: - raise AuthenticationFailed("Token invalid/expired") - return user + return auth_credentials + + +def bearer_auth(raw_header: bytes) -> Optional[User]: + """raw_header in the Format of `Bearer ....`""" + auth_credentials = validate_auth(raw_header) + # first, check traditional tokens + token = Token.filter_not_expired(key=auth_credentials, intent=TokenIntents.INTENT_API).first() if hasattr(LOCAL, "authentik"): LOCAL.authentik[KEY_AUTH_VIA] = "api_token" - return tokens.first().user + if token: + return token.user + user = token_secret_key(auth_credentials) + if user: + return user + raise AuthenticationFailed("Token invalid/expired") def token_secret_key(value: str) -> Optional[User]: diff --git a/authentik/api/tests/test_auth.py b/authentik/api/tests/test_auth.py index b965487c9..ef49d3dee 100644 --- a/authentik/api/tests/test_auth.py +++ b/authentik/api/tests/test_auth.py @@ -14,12 +14,6 @@ from authentik.outposts.managed import OutpostManager class TestAPIAuth(TestCase): """Test API Authentication""" - def test_valid_basic(self): - """Test valid token""" - token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user()) - auth = b64encode(f":{token.key}".encode()).decode() - self.assertEqual(bearer_auth(f"Basic {auth}".encode()), token.user) - def test_valid_bearer(self): """Test valid token""" token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user()) @@ -30,16 +24,6 @@ class TestAPIAuth(TestCase): with self.assertRaises(AuthenticationFailed): bearer_auth("foo bar".encode()) - def test_invalid_decode(self): - """Test invalid bas64""" - with self.assertRaises(AuthenticationFailed): - bearer_auth("Basic bar".encode()) - - def test_invalid_empty_password(self): - """Test invalid with empty password""" - with self.assertRaises(AuthenticationFailed): - bearer_auth("Basic :".encode()) - def test_invalid_no_token(self): """Test invalid with no token""" with self.assertRaises(AuthenticationFailed):