Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt 2023-12-07 09:39:49 +01:00
parent 2724d1d85c
commit a4477a9bea
No known key found for this signature in database
GPG key ID: 9C3FA22FABF1AA8D
6 changed files with 13 additions and 10 deletions

View file

@ -2,10 +2,8 @@
from hmac import compare_digest
from django.http import HttpResponseNotFound
from django_tenants.utils import get_tenant
from rest_framework import permissions
from rest_framework.authentication import get_authorization_header
from rest_framework.fields import ReadOnlyField
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.generics import RetrieveUpdateAPIView
from rest_framework.permissions import IsAdminUser

View file

@ -1,3 +1,4 @@
"""authentik tenants system checks"""
from django.core.checks import Error, register
from authentik.lib.config import CONFIG
@ -5,13 +6,15 @@ from authentik.lib.config import CONFIG
@register()
def check_embedded_outpost_disabled(app_configs, **kwargs):
"""Check that when the tenants API is enabled, the embedded outpost is disabled"""
if CONFIG.get_bool("tenants.enabled", False) and not CONFIG.get_bool(
"outposts.disable_embedded_outpost"
):
return [
Error(
"Embedded outpost must be disabled when tenants API is enabled.",
hint="Disable embedded outpost by setting outposts.disable_embedded_outpost to False, or disable the tenants API by setting tenants.enabled to False",
hint="Disable embedded outpost by setting outposts.disable_embedded_outpost to "
"False, or disable the tenants API by setting tenants.enabled to False",
)
]
return []

View file

@ -1,3 +1,4 @@
"""authentik tenants management command utils"""
from django.core.management.base import BaseCommand
from django.db import connection
from django_tenants.utils import get_public_schema_name
@ -23,6 +24,7 @@ class TenantCommand(BaseCommand):
def handle(self, *args, **options):
verbosity = int(options.get("verbosity"))
# pylint: disable=no-member
schema_name = options["schema_name"] or self.schema_name
connection.set_schema_to_public()
if verbosity >= 1:
@ -35,6 +37,7 @@ class TenantCommand(BaseCommand):
self.handle_per_tenant(*args, **options)
def handle_per_tenant(self, *args, **options):
"""The actual logic of the command."""
raise NotImplementedError(
"subclasses of TenantCommand must provide a handle_per_tenant() method"
)

View file

@ -25,7 +25,8 @@ def _validate_schema_name(name):
if not VALID_SCHEMA_NAME.match(name):
raise ValidationError(
_(
"Schema name must start with t_, only contain lowercase letters and numbers and be less than 63 characters."
"Schema name must start with t_, only contain lowercase letters and numbers and "
"be less than 63 characters."
)
)

View file

@ -1,11 +1,9 @@
"""Test Settings API"""
from json import loads
from django.urls import reverse
from django_tenants.utils import get_public_schema_name
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.config import CONFIG
from authentik.lib.generators import generate_id
from authentik.tenants.models import Domain, Tenant
from authentik.tenants.tests.utils import TenantAPITestCase

View file

@ -4,7 +4,7 @@ from rest_framework.test import APITransactionTestCase
class TenantAPITestCase(APITransactionTestCase):
# Overriden to force TRUNCATE CASCADE
# Overridden to force TRUNCATE CASCADE
def _fixture_teardown(self):
for db_name in self._databases_names(include_mirrors=False):
call_command(
@ -32,7 +32,7 @@ class TenantAPITestCase(APITransactionTestCase):
def assertSchemaExists(self, schema_name):
with connection.cursor() as cursor:
cursor.execute(
f"SELECT * FROM information_schema.schemata WHERE schema_name = %(schema_name)s",
"SELECT * FROM information_schema.schemata WHERE schema_name = %(schema_name)s", # nosec
{"schema_name": schema_name},
)
self.assertEqual(cursor.rowcount, 1)
@ -42,7 +42,7 @@ class TenantAPITestCase(APITransactionTestCase):
)
expected_tables = cursor.rowcount
cursor.execute(
f"SELECT * FROM information_schema.tables WHERE table_schema = %(schema_name)s",
"SELECT * FROM information_schema.tables WHERE table_schema = %(schema_name)s", # nosec
{"schema_name": schema_name},
)
self.assertEqual(cursor.rowcount, expected_tables)
@ -50,7 +50,7 @@ class TenantAPITestCase(APITransactionTestCase):
def assertSchemaDoesntExist(self, schema_name):
with connection.cursor() as cursor:
cursor.execute(
f"SELECT * FROM information_schema.schemata WHERE schema_name = %(schema_name)s",
"SELECT * FROM information_schema.schemata WHERE schema_name = %(schema_name)s", # nosec
{"schema_name": schema_name},
)
self.assertEqual(cursor.rowcount, 0)