api: make 401 messages clearer

closes #755

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-04-19 20:46:50 +02:00
parent 837d2f6fab
commit 464a1c0536
3 changed files with 26 additions and 15 deletions

View File

@ -4,6 +4,7 @@ from binascii import Error
from typing import Any, Optional, Union from typing import Any, Optional, Union
from rest_framework.authentication import BaseAuthentication, get_authorization_header from rest_framework.authentication import BaseAuthentication, get_authorization_header
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.request import Request from rest_framework.request import Request
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -14,7 +15,7 @@ LOGGER = get_logger()
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
def token_from_header(raw_header: bytes) -> Optional[Token]: def token_from_header(raw_header: bytes) -> Optional[Token]:
"""raw_header in the Format of `Basic dGVzdDp0ZXN0`""" """raw_header in the Format of `Bearer dGVzdDp0ZXN0`"""
auth_credentials = raw_header.decode() auth_credentials = raw_header.decode()
if auth_credentials == "": if auth_credentials == "":
return None return None
@ -25,28 +26,27 @@ def token_from_header(raw_header: bytes) -> Optional[Token]:
auth_type, body = plain.split() auth_type, body = plain.split()
auth_credentials = f"{auth_type} {b64encode(body.encode()).decode()}" auth_credentials = f"{auth_type} {b64encode(body.encode()).decode()}"
except (UnicodeDecodeError, Error): except (UnicodeDecodeError, Error):
return None raise AuthenticationFailed("Malformed header")
auth_type, auth_credentials = auth_credentials.split() auth_type, auth_credentials = auth_credentials.split()
if auth_type.lower() not in ["basic", "bearer"]: if auth_type.lower() not in ["basic", "bearer"]:
LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower()) LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower())
return None raise AuthenticationFailed("Unsupported authentication type")
password = auth_credentials password = auth_credentials
if auth_type.lower() == "basic": if auth_type.lower() == "basic":
try: try:
auth_credentials = b64decode(auth_credentials.encode()).decode() auth_credentials = b64decode(auth_credentials.encode()).decode()
except (UnicodeDecodeError, Error): except (UnicodeDecodeError, Error):
return None raise AuthenticationFailed("Malformed header")
# Accept credentials with username and without # Accept credentials with username and without
if ":" in auth_credentials: if ":" in auth_credentials:
_, password = auth_credentials.split(":") _, password = auth_credentials.split(":")
else: else:
password = auth_credentials password = auth_credentials
if password == "": # nosec if password == "": # nosec
return None raise AuthenticationFailed("Malformed header")
tokens = Token.filter_not_expired(key=password, intent=TokenIntents.INTENT_API) tokens = Token.filter_not_expired(key=password, intent=TokenIntents.INTENT_API)
if not tokens.exists(): if not tokens.exists():
LOGGER.debug("Token not found") raise AuthenticationFailed("Token invalid/expired")
return None
return tokens.first() return tokens.first()
@ -58,6 +58,7 @@ class AuthentikTokenAuthentication(BaseAuthentication):
auth = get_authorization_header(request) auth = get_authorization_header(request)
token = token_from_header(auth) token = token_from_header(auth)
# None is only returned when the header isn't set.
if not token: if not token:
return None return None

View File

@ -3,6 +3,7 @@ from base64 import b64encode
from django.test import TestCase from django.test import TestCase
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from rest_framework.exceptions import AuthenticationFailed
from authentik.api.auth import token_from_header from authentik.api.auth import token_from_header
from authentik.core.models import Token, TokenIntents from authentik.core.models import Token, TokenIntents
@ -28,17 +29,21 @@ class TestAPIAuth(TestCase):
def test_invalid_type(self): def test_invalid_type(self):
"""Test invalid type""" """Test invalid type"""
self.assertIsNone(token_from_header("foo bar".encode())) with self.assertRaises(AuthenticationFailed):
token_from_header("foo bar".encode())
def test_invalid_decode(self): def test_invalid_decode(self):
"""Test invalid bas64""" """Test invalid bas64"""
self.assertIsNone(token_from_header("Basic bar".encode())) with self.assertRaises(AuthenticationFailed):
token_from_header("Basic bar".encode())
def test_invalid_empty_password(self): def test_invalid_empty_password(self):
"""Test invalid with empty password""" """Test invalid with empty password"""
self.assertIsNone(token_from_header("Basic :".encode())) with self.assertRaises(AuthenticationFailed):
token_from_header("Basic :".encode())
def test_invalid_no_token(self): def test_invalid_no_token(self):
"""Test invalid with no token""" """Test invalid with no token"""
with self.assertRaises(AuthenticationFailed):
auth = b64encode(":abc".encode()).decode() auth = b64encode(":abc".encode()).decode()
self.assertIsNone(token_from_header(f"Basic :{auth}".encode())) self.assertIsNone(token_from_header(f"Basic :{auth}".encode()))

View File

@ -1,6 +1,7 @@
"""Channels base classes""" """Channels base classes"""
from channels.exceptions import DenyConnection from channels.exceptions import DenyConnection
from channels.generic.websocket import JsonWebsocketConsumer from channels.generic.websocket import JsonWebsocketConsumer
from rest_framework.exceptions import AuthenticationFailed
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.api.auth import token_from_header from authentik.api.auth import token_from_header
@ -22,9 +23,13 @@ class AuthJsonConsumer(JsonWebsocketConsumer):
raw_header = headers[b"authorization"] raw_header = headers[b"authorization"]
try:
token = token_from_header(raw_header) token = token_from_header(raw_header)
# token is only None when no header was given, in which case we deny too
if not token: if not token:
LOGGER.warning("Failed to authenticate") raise DenyConnection()
except AuthenticationFailed as exc:
LOGGER.warning("Failed to authenticate", exc=exc)
raise DenyConnection() raise DenyConnection()
self.user = token.user self.user = token.user