core: enforce unique on names where it makes sense (#4866)

enforce unique on names where it makes sense

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-03-07 23:52:34 +01:00 committed by GitHub
parent 36f92f01de
commit 67f3db1e03
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 172 additions and 69 deletions

View file

@ -9,6 +9,7 @@ from authentik.blueprints.tests import reconcile_app
from authentik.core.models import Group, User
from authentik.core.tasks import clean_expired_models
from authentik.events.monitored_tasks import TaskResultStatus
from authentik.lib.generators import generate_id
class TestAdminAPI(TestCase):
@ -16,8 +17,8 @@ class TestAdminAPI(TestCase):
def setUp(self) -> None:
super().setUp()
self.user = User.objects.create(username="test-user")
self.group = Group.objects.create(name="superusers", is_superuser=True)
self.user = User.objects.create(username=generate_id())
self.group = Group.objects.create(name=generate_id(), is_superuser=True)
self.group.users.add(self.user)
self.group.save()
self.client.force_login(self.user)

View file

@ -4,6 +4,7 @@ from guardian.shortcuts import assign_perm
from rest_framework.test import APITestCase
from authentik.core.models import Application, User
from authentik.lib.generators import generate_id
class TestAPIDecorators(APITestCase):
@ -16,7 +17,7 @@ class TestAPIDecorators(APITestCase):
def test_obj_perm_denied(self):
"""Test object perm denied"""
self.client.force_login(self.user)
app = Application.objects.create(name="denied", slug="denied")
app = Application.objects.create(name=generate_id(), slug=generate_id())
response = self.client.get(
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
)
@ -25,7 +26,7 @@ class TestAPIDecorators(APITestCase):
def test_other_perm_denied(self):
"""Test other perm denied"""
self.client.force_login(self.user)
app = Application.objects.create(name="denied", slug="denied")
app = Application.objects.create(name=generate_id(), slug=generate_id())
assign_perm("authentik_core.view_application", self.user, app)
response = self.client.get(
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})

View file

@ -0,0 +1,26 @@
# Generated by Django 4.1.7 on 2023-03-07 13:41
from django.db import migrations, models
from authentik.lib.migrations import fallback_names
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0025_alter_provider_authorization_flow"),
]
operations = [
migrations.RunPython(fallback_names("authentik_core", "propertymapping", "name")),
migrations.RunPython(fallback_names("authentik_core", "provider", "name")),
migrations.AlterField(
model_name="propertymapping",
name="name",
field=models.TextField(unique=True),
),
migrations.AlterField(
model_name="provider",
name="name",
field=models.TextField(unique=True),
),
]

View file

@ -243,7 +243,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
class Provider(SerializerModel):
"""Application-independent Provider instance. For example SAML2 Remote, OAuth2 Application"""
name = models.TextField()
name = models.TextField(unique=True)
authorization_flow = models.ForeignKey(
"authentik_flows.Flow",
@ -608,7 +608,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
"""User-defined key -> x mapping which can be used by providers to expose extra data."""
pm_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
name = models.TextField()
name = models.TextField(unique=True)
expression = models.TextField()
objects = InheritanceManager()

View file

@ -2,7 +2,6 @@
import uuid
from datetime import timedelta
from typing import Iterable
import django.db.models.deletion
from django.apps.registry import Apps
@ -13,6 +12,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor
import authentik.events.models
import authentik.lib.models
from authentik.events.models import EventAction, NotificationSeverity, TransportMode
from authentik.lib.migrations import progress_bar
def convert_user_to_json(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
@ -43,49 +43,6 @@ def token_view_to_secret_view(apps: Apps, schema_editor: BaseDatabaseSchemaEdito
Event.objects.using(db_alias).bulk_update(events, ["context", "action"])
# Taken from https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
def progress_bar(
iterable: Iterable,
prefix="Writing: ",
suffix=" finished",
decimals=1,
length=100,
fill="",
print_end="\r",
):
"""
Call in a loop to create terminal progress bar
@params:
iteration - Required : current iteration (Int)
total - Required : total iterations (Int)
prefix - Optional : prefix string (Str)
suffix - Optional : suffix string (Str)
decimals - Optional : positive number of decimals in percent complete (Int)
length - Optional : character length of bar (Int)
fill - Optional : bar fill character (Str)
print_end - Optional : end character (e.g. "\r", "\r\n") (Str)
"""
total = len(iterable)
if total < 1:
return
def print_progress_bar(iteration):
"""Progress Bar Printing Function"""
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
filledLength = int(length * iteration // total)
bar = fill * filledLength + "-" * (length - filledLength)
print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end)
# Initial Call
print_progress_bar(0)
# Update Progress Bar
for i, item in enumerate(iterable):
yield item
print_progress_bar(i + 1)
# Print New Line on Complete
print()
def update_expires(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
Event = apps.get_model("authentik_events", "event")

View file

@ -0,0 +1,58 @@
"""Migration helpers"""
from typing import Iterable
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def fallback_names(app: str, model: str, field: str):
"""Factory function that checks all instances of `app`.`model` instance's `field`
to prevent any duplicates"""
def migrator(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
klass = apps.get_model(app, model)
seen_names = []
for obj in klass.objects.using(db_alias).all():
value = getattr(obj, field)
if value not in seen_names:
seen_names.append(value)
continue
new_value = value + "_2"
setattr(obj, field, new_value)
obj.save()
return migrator
def progress_bar(iterable: Iterable):
"""Call in a loop to create terminal progress bar
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console"""
prefix = "Writing: "
suffix = " finished"
decimals = 1
length = 100
fill = ""
print_end = "\r"
total = len(iterable)
if total < 1:
return
def print_progress_bar(iteration):
"""Progress Bar Printing Function"""
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
filled_length = int(length * iteration // total)
bar = fill * filled_length + "-" * (length - filled_length)
print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end)
# Initial Call
print_progress_bar(0)
# Update Progress Bar
for i, item in enumerate(iterable):
yield item
print_progress_bar(i + 1)
# Print New Line on Complete
print()

View file

@ -0,0 +1,28 @@
# Generated by Django 4.1.7 on 2023-03-07 13:41
from django.db import migrations, models
from authentik.lib.migrations import fallback_names
class Migration(migrations.Migration):
dependencies = [
("authentik_outposts", "0018_kubernetesserviceconnection_verify_ssl"),
]
operations = [
migrations.RunPython(fallback_names("authentik_outposts", "outpost", "name")),
migrations.RunPython(
fallback_names("authentik_outposts", "outpostserviceconnection", "name")
),
migrations.AlterField(
model_name="outpost",
name="name",
field=models.TextField(unique=True),
),
migrations.AlterField(
model_name="outpostserviceconnection",
name="name",
field=models.TextField(unique=True),
),
]

View file

@ -113,7 +113,7 @@ class OutpostServiceConnection(models.Model):
"""Connection details for an Outpost Controller, like Docker or Kubernetes"""
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
name = models.TextField()
name = models.TextField(unique=True)
local = models.BooleanField(
default=False,
@ -239,7 +239,7 @@ class Outpost(SerializerModel, ManagedModel):
"""Outpost instance which manages a service user and token"""
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
name = models.TextField()
name = models.TextField(unique=True)
type = models.TextField(choices=OutpostType.choices, default=OutpostType.PROXY)
service_connection = InheritanceForeignKey(

View file

@ -4,6 +4,7 @@ from rest_framework.test import APITestCase
from authentik.core.models import PropertyMapping
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id
from authentik.outposts.api.outposts import OutpostSerializer
from authentik.outposts.models import OutpostType, default_outpost_config
from authentik.providers.ldap.models import LDAPProvider
@ -16,7 +17,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
def setUp(self) -> None:
super().setUp()
self.mapping = PropertyMapping.objects.create(
name="dummy", expression="""return {'foo': 'bar'}"""
name=generate_id(), expression="""return {'foo': 'bar'}"""
)
self.user = create_test_admin_user()
self.client.force_login(self.user)
@ -25,12 +26,12 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
"""Test Outpost validation"""
valid = OutpostSerializer(
data={
"name": "foo",
"name": generate_id(),
"type": OutpostType.PROXY,
"config": default_outpost_config(),
"providers": [
ProxyProvider.objects.create(
name="test", authorization_flow=create_test_flow()
name=generate_id(), authorization_flow=create_test_flow()
).pk
],
}
@ -38,12 +39,12 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
self.assertTrue(valid.is_valid())
invalid = OutpostSerializer(
data={
"name": "foo",
"name": generate_id(),
"type": OutpostType.PROXY,
"config": default_outpost_config(),
"providers": [
LDAPProvider.objects.create(
name="test", authorization_flow=create_test_flow()
name=generate_id(), authorization_flow=create_test_flow()
).pk
],
}
@ -60,15 +61,19 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
def test_outpost_config(self):
"""Test Outpost's config field"""
provider = ProxyProvider.objects.create(name="test", authorization_flow=create_test_flow())
invalid = OutpostSerializer(data={"name": "foo", "providers": [provider.pk], "config": ""})
provider = ProxyProvider.objects.create(
name=generate_id(), authorization_flow=create_test_flow()
)
invalid = OutpostSerializer(
data={"name": generate_id(), "providers": [provider.pk], "config": ""}
)
self.assertFalse(invalid.is_valid())
self.assertIn("config", invalid.errors)
valid = OutpostSerializer(
data={
"name": "foo",
"name": generate_id(),
"providers": [provider.pk],
"config": default_outpost_config("foo"),
"config": default_outpost_config(generate_id()),
"type": OutpostType.PROXY,
}
)

View file

@ -0,0 +1,20 @@
# Generated by Django 4.1.7 on 2023-03-07 13:41
from django.db import migrations, models
from authentik.lib.migrations import fallback_names
class Migration(migrations.Migration):
dependencies = [
("authentik_policies", "0009_alter_policy_name"),
]
operations = [
migrations.RunPython(fallback_names("authentik_policies", "policy", "name")),
migrations.AlterField(
model_name="policy",
name="name",
field=models.TextField(unique=True),
),
]

View file

@ -158,7 +158,7 @@ class Policy(SerializerModel, CreatedUpdatedModel):
policy_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
name = models.TextField()
name = models.TextField(unique=True)
execution_logging = models.BooleanField(
default=False,

View file

@ -2,7 +2,8 @@
from django.core.cache import cache
from django.test import TestCase
from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.engine import PolicyEngine
from authentik.policies.exceptions import PolicyEngineException
@ -17,11 +18,17 @@ class TestPolicyEngine(TestCase):
def setUp(self):
clear_policy_cache()
self.user = User.objects.create_user(username="policyuser")
self.policy_false = DummyPolicy.objects.create(result=False, wait_min=0, wait_max=1)
self.policy_true = DummyPolicy.objects.create(result=True, wait_min=0, wait_max=1)
self.policy_wrong_type = Policy.objects.create(name="wrong_type")
self.policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}")
self.user = create_test_admin_user()
self.policy_false = DummyPolicy.objects.create(
name=generate_id(), result=False, wait_min=0, wait_max=1
)
self.policy_true = DummyPolicy.objects.create(
name=generate_id(), result=True, wait_min=0, wait_max=1
)
self.policy_wrong_type = Policy.objects.create(name=generate_id())
self.policy_raises = ExpressionPolicy.objects.create(
name=generate_id(), expression="{{ 0/0 }}"
)
def test_engine_empty(self):
"""Ensure empty policy list passes"""

View file

@ -60,7 +60,7 @@ exclude_lines = [
show_missing = true
[tool.pylint.basic]
good-names = ["pk", "id", "i", "j", "k", "_"]
good-names = ["pk", "id", "i", "j", "k", "_", "bar"]
[tool.pylint.master]
disable = [