Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt 2023-12-07 09:39:49 +01:00
parent 2724d1d85c
commit a4477a9bea
No known key found for this signature in database
GPG key ID: 9C3FA22FABF1AA8D
6 changed files with 13 additions and 10 deletions

View file

@ -2,10 +2,8 @@
from hmac import compare_digest from hmac import compare_digest
from django.http import HttpResponseNotFound from django.http import HttpResponseNotFound
from django_tenants.utils import get_tenant
from rest_framework import permissions from rest_framework import permissions
from rest_framework.authentication import get_authorization_header from rest_framework.authentication import get_authorization_header
from rest_framework.fields import ReadOnlyField
from rest_framework.filters import OrderingFilter, SearchFilter from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.generics import RetrieveUpdateAPIView from rest_framework.generics import RetrieveUpdateAPIView
from rest_framework.permissions import IsAdminUser from rest_framework.permissions import IsAdminUser

View file

@ -1,3 +1,4 @@
"""authentik tenants system checks"""
from django.core.checks import Error, register from django.core.checks import Error, register
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
@ -5,13 +6,15 @@ from authentik.lib.config import CONFIG
@register() @register()
def check_embedded_outpost_disabled(app_configs, **kwargs): def check_embedded_outpost_disabled(app_configs, **kwargs):
"""Check that when the tenants API is enabled, the embedded outpost is disabled"""
if CONFIG.get_bool("tenants.enabled", False) and not CONFIG.get_bool( if CONFIG.get_bool("tenants.enabled", False) and not CONFIG.get_bool(
"outposts.disable_embedded_outpost" "outposts.disable_embedded_outpost"
): ):
return [ return [
Error( Error(
"Embedded outpost must be disabled when tenants API is enabled.", "Embedded outpost must be disabled when tenants API is enabled.",
hint="Disable embedded outpost by setting outposts.disable_embedded_outpost to False, or disable the tenants API by setting tenants.enabled to False", hint="Disable embedded outpost by setting outposts.disable_embedded_outpost to "
"False, or disable the tenants API by setting tenants.enabled to False",
) )
] ]
return [] return []

View file

@ -1,3 +1,4 @@
"""authentik tenants management command utils"""
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.db import connection from django.db import connection
from django_tenants.utils import get_public_schema_name from django_tenants.utils import get_public_schema_name
@ -23,6 +24,7 @@ class TenantCommand(BaseCommand):
def handle(self, *args, **options): def handle(self, *args, **options):
verbosity = int(options.get("verbosity")) verbosity = int(options.get("verbosity"))
# pylint: disable=no-member
schema_name = options["schema_name"] or self.schema_name schema_name = options["schema_name"] or self.schema_name
connection.set_schema_to_public() connection.set_schema_to_public()
if verbosity >= 1: if verbosity >= 1:
@ -35,6 +37,7 @@ class TenantCommand(BaseCommand):
self.handle_per_tenant(*args, **options) self.handle_per_tenant(*args, **options)
def handle_per_tenant(self, *args, **options): def handle_per_tenant(self, *args, **options):
"""The actual logic of the command."""
raise NotImplementedError( raise NotImplementedError(
"subclasses of TenantCommand must provide a handle_per_tenant() method" "subclasses of TenantCommand must provide a handle_per_tenant() method"
) )

View file

@ -25,7 +25,8 @@ def _validate_schema_name(name):
if not VALID_SCHEMA_NAME.match(name): if not VALID_SCHEMA_NAME.match(name):
raise ValidationError( raise ValidationError(
_( _(
"Schema name must start with t_, only contain lowercase letters and numbers and be less than 63 characters." "Schema name must start with t_, only contain lowercase letters and numbers and "
"be less than 63 characters."
) )
) )

View file

@ -1,11 +1,9 @@
"""Test Settings API""" """Test Settings API"""
from json import loads
from django.urls import reverse from django.urls import reverse
from django_tenants.utils import get_public_schema_name from django_tenants.utils import get_public_schema_name
from authentik.core.tests.utils import create_test_admin_user from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.config import CONFIG
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.tenants.models import Domain, Tenant from authentik.tenants.models import Domain, Tenant
from authentik.tenants.tests.utils import TenantAPITestCase from authentik.tenants.tests.utils import TenantAPITestCase

View file

@ -4,7 +4,7 @@ from rest_framework.test import APITransactionTestCase
class TenantAPITestCase(APITransactionTestCase): class TenantAPITestCase(APITransactionTestCase):
# Overriden to force TRUNCATE CASCADE # Overridden to force TRUNCATE CASCADE
def _fixture_teardown(self): def _fixture_teardown(self):
for db_name in self._databases_names(include_mirrors=False): for db_name in self._databases_names(include_mirrors=False):
call_command( call_command(
@ -32,7 +32,7 @@ class TenantAPITestCase(APITransactionTestCase):
def assertSchemaExists(self, schema_name): def assertSchemaExists(self, schema_name):
with connection.cursor() as cursor: with connection.cursor() as cursor:
cursor.execute( cursor.execute(
f"SELECT * FROM information_schema.schemata WHERE schema_name = %(schema_name)s", "SELECT * FROM information_schema.schemata WHERE schema_name = %(schema_name)s", # nosec
{"schema_name": schema_name}, {"schema_name": schema_name},
) )
self.assertEqual(cursor.rowcount, 1) self.assertEqual(cursor.rowcount, 1)
@ -42,7 +42,7 @@ class TenantAPITestCase(APITransactionTestCase):
) )
expected_tables = cursor.rowcount expected_tables = cursor.rowcount
cursor.execute( cursor.execute(
f"SELECT * FROM information_schema.tables WHERE table_schema = %(schema_name)s", "SELECT * FROM information_schema.tables WHERE table_schema = %(schema_name)s", # nosec
{"schema_name": schema_name}, {"schema_name": schema_name},
) )
self.assertEqual(cursor.rowcount, expected_tables) self.assertEqual(cursor.rowcount, expected_tables)
@ -50,7 +50,7 @@ class TenantAPITestCase(APITransactionTestCase):
def assertSchemaDoesntExist(self, schema_name): def assertSchemaDoesntExist(self, schema_name):
with connection.cursor() as cursor: with connection.cursor() as cursor:
cursor.execute( cursor.execute(
f"SELECT * FROM information_schema.schemata WHERE schema_name = %(schema_name)s", "SELECT * FROM information_schema.schemata WHERE schema_name = %(schema_name)s", # nosec
{"schema_name": schema_name}, {"schema_name": schema_name},
) )
self.assertEqual(cursor.rowcount, 0) self.assertEqual(cursor.rowcount, 0)