diff --git a/authentik/tenants/api.py b/authentik/tenants/api.py index d384d93b5..7304be653 100644 --- a/authentik/tenants/api.py +++ b/authentik/tenants/api.py @@ -1,7 +1,7 @@ """Serializer for tenants models""" from hmac import compare_digest -from django.http import Http404 +from django.http import HttpResponseNotFound from django_tenants.utils import get_tenant from rest_framework import permissions from rest_framework.authentication import get_authorization_header @@ -19,13 +19,15 @@ from authentik.lib.config import CONFIG from authentik.tenants.models import Domain, Tenant -class TenantManagementKeyPermission(permissions.BasePermission): - """Authentication based on tenant_management_key""" +class TenantApiKeyPermission(permissions.BasePermission): + """Authentication based on tenants.api_key""" def has_permission(self, request: Request, view: View) -> bool: + key = CONFIG.get("tenants.api_key", "") + if not key: + return False token = validate_auth(get_authorization_header(request)) - key = CONFIG.get("tenants.api_key") - if compare_digest("", key): + if token is None: return False return compare_digest(token, key) @@ -53,12 +55,13 @@ class TenantViewSet(ModelViewSet): "domains__domain", ] ordering = ["schema_name"] - permission_classes = [TenantManagementKeyPermission] + authentication_classes = [] + permission_classes = [TenantApiKeyPermission] filter_backends = [OrderingFilter, SearchFilter] def dispatch(self, request, *args, **kwargs): if not CONFIG.get_bool("tenants.enabled", True): - return Http404() + return HttpResponseNotFound() return super().dispatch(request, *args, **kwargs) @@ -81,9 +84,15 @@ class DomainViewSet(ModelViewSet): "tenant__schema_name", ] ordering = ["domain"] - permission_classes = [TenantManagementKeyPermission] + authentication_classes = [] + permission_classes = [TenantApiKeyPermission] filter_backends = [OrderingFilter, SearchFilter] + def dispatch(self, request, *args, **kwargs): + if not CONFIG.get_bool("tenants.enabled", True): + return HttpResponseNotFound() + return super().dispatch(request, *args, **kwargs) + class SettingsSerializer(ModelSerializer): """Settings Serializer""" diff --git a/authentik/tenants/migrations/0001_initial.py b/authentik/tenants/migrations/0001_initial.py index b9cdf685d..d0feb55bd 100644 --- a/authentik/tenants/migrations/0001_initial.py +++ b/authentik/tenants/migrations/0001_initial.py @@ -6,6 +6,7 @@ import django.db.models.deletion import django_tenants.postgresql_backend.base from django.db import migrations, models +import authentik.tenants.models from authentik.lib.config import CONFIG @@ -42,7 +43,7 @@ class Migration(migrations.Migration): db_index=True, max_length=63, unique=True, - validators=[django_tenants.postgresql_backend.base._check_schema_name], + validators=[authentik.tenants.models._validate_schema_name], ), ), ( diff --git a/authentik/tenants/models.py b/authentik/tenants/models.py index dc9ba507f..843cda256 100644 --- a/authentik/tenants/models.py +++ b/authentik/tenants/models.py @@ -1,7 +1,9 @@ """Tenant models""" +import re from uuid import uuid4 from django.apps import apps +from django.core.exceptions import ValidationError from django.db import models from django.db.utils import IntegrityError from django.dispatch import receiver @@ -16,10 +18,25 @@ from authentik.lib.models import SerializerModel LOGGER = get_logger() +VALID_SCHEMA_NAME = re.compile(r"^t_[a-z0-9]{1,61}$") + + +def _validate_schema_name(name): + if not VALID_SCHEMA_NAME.match(name): + raise ValidationError( + _( + "Schema name must start with t_, only contain lowercase letters and numbers and be less than 63 characters." + ) + ) + + class Tenant(TenantMixin, SerializerModel): """Tenant""" tenant_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) + schema_name = models.CharField( + max_length=63, unique=True, db_index=True, validators=[_validate_schema_name] + ) name = models.TextField() auto_create_schema = True diff --git a/authentik/tenants/tests/__init__.py b/authentik/tenants/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/authentik/tenants/tests/test_api.py b/authentik/tenants/tests/test_api.py new file mode 100644 index 000000000..f7e9b295a --- /dev/null +++ b/authentik/tenants/tests/test_api.py @@ -0,0 +1,118 @@ +"""Test Tenant API""" +from json import loads + +from django.core.management import call_command +from django.db import connection +from django.urls import reverse +from rest_framework.test import APILiveServerTestCase, APITestCase, APITransactionTestCase + +from authentik.lib.config import CONFIG +from authentik.lib.generators import generate_id + +TENANTS_API_KEY = generate_id() +HEADERS = {"Authorization": f"Bearer {TENANTS_API_KEY}"} + + +class TestAPI(APITransactionTestCase): + """Test api view""" + + def _fixture_teardown(self): + for db_name in self._databases_names(include_mirrors=False): + call_command( + "flush", + verbosity=0, + interactive=False, + database=db_name, + reset_sequences=False, + allow_cascade=True, + inhibit_post_migrate=False, + ) + + def setUp(self): + call_command("migrate_schemas", schema="template", tenant=True) + + def assertSchemaExists(self, schema_name): + with connection.cursor() as cursor: + cursor.execute( + f"SELECT * FROM information_schema.schemata WHERE schema_name = '{schema_name}';" + ) + self.assertEqual(cursor.rowcount, 1) + + cursor.execute( + "SELECT * FROM information_schema.tables WHERE table_schema = 'template';" + ) + expected_tables = cursor.rowcount + cursor.execute( + f"SELECT * FROM information_schema.tables WHERE table_schema = '{schema_name}';" + ) + self.assertEqual(cursor.rowcount, expected_tables) + + def assertSchemaDoesntExist(self, schema_name): + with connection.cursor() as cursor: + cursor.execute( + f"SELECT * FROM information_schema.schemata WHERE schema_name = '{schema_name}';" + ) + self.assertEqual(cursor.rowcount, 0) + + @CONFIG.patch("outposts.disable_embedded_outpost", True) + @CONFIG.patch("tenants.enabled", True) + @CONFIG.patch("tenants.api_key", TENANTS_API_KEY) + def test_tenant_create_delete(self): + """Test Tenant creation API Endpoint""" + response = self.client.post( + reverse( + "authentik_api:tenant-list", + ), + data={"name": generate_id(), "schema_name": "t_" + generate_id(length=63 - 2).lower()}, + headers=HEADERS, + ) + self.assertEqual(response.status_code, 201) + body = loads(response.content.decode()) + + self.assertSchemaExists(body["schema_name"]) + + response = self.client.delete( + reverse( + "authentik_api:tenant-detail", + kwargs={"pk": body["tenant_uuid"]}, + ), + headers=HEADERS, + ) + self.assertEqual(response.status_code, 204) + self.assertSchemaDoesntExist(body["schema_name"]) + + @CONFIG.patch("outposts.disable_embedded_outpost", True) + @CONFIG.patch("tenants.enabled", True) + @CONFIG.patch("tenants.api_key", TENANTS_API_KEY) + def test_unauthenticated(self): + """Test Tenant creation API Endpoint""" + response = self.client.get( + reverse( + "authentik_api:tenant-list", + ), + ) + self.assertEqual(response.status_code, 403) + + @CONFIG.patch("outposts.disable_embedded_outpost", True) + @CONFIG.patch("tenants.enabled", True) + @CONFIG.patch("tenants.api_key", "") + def test_no_api_key_configured(self): + """Test Tenant creation API Endpoint""" + response = self.client.get( + reverse( + "authentik_api:tenant-list", + ), + ) + self.assertEqual(response.status_code, 403) + + @CONFIG.patch("tenants.enabled", False) + @CONFIG.patch("tenants.api_key", TENANTS_API_KEY) + def test_api_disabled(self): + """Test Tenant creation API Endpoint""" + response = self.client.get( + reverse( + "authentik_api:tenant-list", + ), + headers=HEADERS, + ) + self.assertEqual(response.status_code, 404) diff --git a/lifecycle/system_migrations/tenant_to_brand.py b/lifecycle/system_migrations/tenant_to_brand.py index 5f288b224..7a9da6f74 100644 --- a/lifecycle/system_migrations/tenant_to_brand.py +++ b/lifecycle/system_migrations/tenant_to_brand.py @@ -13,7 +13,7 @@ COMMIT; class Migration(BaseMigration): def needs_migration(self) -> bool: self.cur.execute( - "select * from information_schema.tables where table_name =" " 'django_migrations';" + "select * from information_schema.tables where table_name = 'django_migrations';" ) # No migration table, assume new installation if not bool(self.cur.rowcount):