sources/ldap: improve scalability (#6056)

* sources/ldap: improve scalability

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix lint

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use cache instead of call signature for page data

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-06-28 17:13:42 +02:00 committed by GitHub
parent a987846c76
commit e712225ced
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 167 additions and 62 deletions

View file

@ -70,8 +70,10 @@ class TaskInfo:
return cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*"))
@staticmethod
def by_name(name: str) -> Optional["TaskInfo"]:
def by_name(name: str) -> Optional["TaskInfo"] | Optional[list["TaskInfo"]]:
"""Get TaskInfo Object by name"""
if "*" in name:
return cache.get_many(cache.keys(CACHE_KEY_PREFIX + name)).values()
return cache.get(CACHE_KEY_PREFIX + name, None)
def delete(self):

View file

@ -118,10 +118,9 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
"""Get source's sync status"""
source = self.get_object()
results = []
for sync_class in SYNC_CLASSES:
sync_name = sync_class.__name__.replace("LDAPSynchronizer", "").lower()
task = TaskInfo.by_name(f"ldap_sync:{source.slug}:{sync_name}")
if task:
tasks = TaskInfo.by_name(f"ldap_sync:{source.slug}:*")
if tasks:
for task in tasks:
results.append(task)
return Response(TaskSerializer(results, many=True).data)
@ -143,7 +142,7 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
source = self.get_object()
all_objects = {}
for sync_class in SYNC_CLASSES:
class_name = sync_class.__name__.replace("LDAPSynchronizer", "").lower()
class_name = sync_class.name()
all_objects.setdefault(class_name, [])
for obj in sync_class(source).get_objects(size_limit=10):
obj: dict

View file

@ -2,9 +2,8 @@
from django.core.management.base import BaseCommand
from structlog.stdlib import get_logger
from authentik.lib.utils.reflection import class_to_path
from authentik.sources.ldap.models import LDAPSource
from authentik.sources.ldap.tasks import SYNC_CLASSES, ldap_sync
from authentik.sources.ldap.tasks import ldap_sync_single
LOGGER = get_logger()
@ -21,7 +20,4 @@ class Command(BaseCommand):
if not source:
LOGGER.warning("Source does not exist", slug=source_slug)
continue
for sync_class in SYNC_CLASSES:
LOGGER.info("Starting sync", cls=sync_class)
# pylint: disable=no-value-for-parameter
ldap_sync(source.pk, class_to_path(sync_class))
ldap_sync_single(source)

View file

@ -12,13 +12,9 @@ from authentik.core.models import User
from authentik.core.signals import password_changed
from authentik.events.models import Event, EventAction
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.lib.utils.reflection import class_to_path
from authentik.sources.ldap.models import LDAPSource
from authentik.sources.ldap.password import LDAPPasswordChanger
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tasks import ldap_sync
from authentik.sources.ldap.tasks import ldap_sync_single
from authentik.stages.prompt.signals import password_validate
LOGGER = get_logger()
@ -35,12 +31,7 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_):
# and the mappings are created with an m2m event
if not instance.property_mappings.exists() or not instance.property_mappings_group.exists():
return
for sync_class in [
UserLDAPSynchronizer,
GroupLDAPSynchronizer,
MembershipLDAPSynchronizer,
]:
ldap_sync.delay(instance.pk, class_to_path(sync_class))
ldap_sync_single.delay(instance.pk)
@receiver(password_validate)

View file

@ -1,9 +1,10 @@
"""Sync LDAP Users and groups into authentik"""
from typing import Any, Generator
from django.conf import settings
from django.db.models.base import Model
from django.db.models.query import QuerySet
from ldap3 import Connection
from ldap3 import DEREF_ALWAYS, SUBTREE, Connection
from structlog.stdlib import BoundLogger, get_logger
from authentik.core.exceptions import PropertyMappingExpressionException
@ -29,6 +30,24 @@ class BaseLDAPSynchronizer:
self._messages = []
self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__)
@staticmethod
def name() -> str:
"""UI name for the type of object this class synchronizes"""
raise NotImplementedError
def sync_full(self):
"""Run full sync, this function should only be used in tests"""
if not settings.TEST: # noqa
raise RuntimeError(
f"{self.__class__.__name__}.sync_full() should only be used in tests"
)
for page in self.get_objects():
self.sync(page)
def sync(self, page_data: list) -> int:
"""Sync function, implemented in subclass"""
raise NotImplementedError()
@property
def messages(self) -> list[str]:
"""Get all UI messages"""
@ -60,9 +79,47 @@ class BaseLDAPSynchronizer:
"""Get objects from LDAP, implemented in subclass"""
raise NotImplementedError()
def sync(self) -> int:
"""Sync function, implemented in subclass"""
raise NotImplementedError()
# pylint: disable=too-many-arguments
def search_paginator(
self,
search_base,
search_filter,
search_scope=SUBTREE,
dereference_aliases=DEREF_ALWAYS,
attributes=None,
size_limit=0,
time_limit=0,
types_only=False,
get_operational_attributes=False,
controls=None,
paged_size=5,
paged_criticality=False,
):
"""Search in pages, returns each page"""
cookie = True
while cookie:
self._connection.search(
search_base,
search_filter,
search_scope,
dereference_aliases,
attributes,
size_limit,
time_limit,
types_only,
get_operational_attributes,
controls,
paged_size,
paged_criticality,
None if cookie is True else cookie,
)
try:
cookie = self._connection.result["controls"]["1.2.840.113556.1.4.319"]["value"][
"cookie"
]
except KeyError:
cookie = None
yield self._connection.response
def _flatten(self, value: Any) -> Any:
"""Flatten `value` if its a list"""

View file

@ -13,8 +13,12 @@ from authentik.sources.ldap.sync.base import LDAP_UNIQUENESS, BaseLDAPSynchroniz
class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
"""Sync LDAP Users and groups into authentik"""
@staticmethod
def name() -> str:
return "groups"
def get_objects(self, **kwargs) -> Generator:
return self._connection.extend.standard.paged_search(
return self.search_paginator(
search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter,
search_scope=SUBTREE,
@ -22,13 +26,13 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
**kwargs,
)
def sync(self) -> int:
def sync(self, page_data: list) -> int:
"""Iterate over all LDAP Groups and create authentik_core.Group instances"""
if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source")
return -1
group_count = 0
for group in self.get_objects():
for group in page_data:
if "attributes" not in group:
continue
attributes = group.get("attributes", {})

View file

@ -19,8 +19,12 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
super().__init__(source)
self.group_cache: dict[str, Group] = {}
@staticmethod
def name() -> str:
return "membership"
def get_objects(self, **kwargs) -> Generator:
return self._connection.extend.standard.paged_search(
return self.search_paginator(
search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter,
search_scope=SUBTREE,
@ -32,13 +36,13 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
**kwargs,
)
def sync(self) -> int:
def sync(self, page_data: list) -> int:
"""Iterate over all Users and assign Groups using memberOf Field"""
if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source")
return -1
membership_count = 0
for group in self.get_objects():
for group in page_data:
if "attributes" not in group:
continue
members = group.get("attributes", {}).get(self._source.group_membership_field, [])

View file

@ -15,8 +15,12 @@ from authentik.sources.ldap.sync.vendor.ms_ad import MicrosoftActiveDirectory
class UserLDAPSynchronizer(BaseLDAPSynchronizer):
"""Sync LDAP Users into authentik"""
@staticmethod
def name() -> str:
return "users"
def get_objects(self, **kwargs) -> Generator:
return self._connection.extend.standard.paged_search(
return self.search_paginator(
search_base=self.base_dn_users,
search_filter=self._source.user_object_filter,
search_scope=SUBTREE,
@ -24,13 +28,13 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
**kwargs,
)
def sync(self) -> int:
def sync(self, page_data: list) -> int:
"""Iterate over all LDAP Users and create authentik_core.User instances"""
if not self._source.sync_users:
self.message("User syncing is disabled for this Source")
return -1
user_count = 0
for user in self.get_objects():
for user in page_data:
if "attributes" not in user:
continue
attributes = user.get("attributes", {})

View file

@ -11,6 +11,10 @@ from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer
class FreeIPA(BaseLDAPSynchronizer):
"""FreeIPA-specific LDAP"""
@staticmethod
def name() -> str:
return "freeipa"
def get_objects(self, **kwargs) -> Generator:
yield None

View file

@ -42,6 +42,10 @@ class UserAccountControl(IntFlag):
class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
"""Microsoft-specific LDAP"""
@staticmethod
def name() -> str:
return "microsoft_ad"
def get_objects(self, **kwargs) -> Generator:
yield None

View file

@ -1,4 +1,8 @@
"""LDAP Sync tasks"""
from uuid import uuid4
from celery import chain, group
from django.core.cache import cache
from ldap3.core.exceptions import LDAPException
from structlog.stdlib import get_logger
@ -8,6 +12,7 @@ from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.root.celery import CELERY_APP
from authentik.sources.ldap.models import LDAPSource
from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
@ -18,14 +23,43 @@ SYNC_CLASSES = [
GroupLDAPSynchronizer,
MembershipLDAPSynchronizer,
]
CACHE_KEY_PREFIX = "goauthentik.io/sources/ldap/page/"
@CELERY_APP.task()
def ldap_sync_all():
"""Sync all sources"""
for source in LDAPSource.objects.filter(enabled=True):
for sync_class in SYNC_CLASSES:
ldap_sync.delay(source.pk, class_to_path(sync_class))
ldap_sync_single(source)
@CELERY_APP.task()
def ldap_sync_single(source: LDAPSource):
"""Sync a single source"""
task = chain(
# User and group sync can happen at once, they have no dependencies on each other
group(
ldap_sync_paginator(source, UserLDAPSynchronizer)
+ ldap_sync_paginator(source, GroupLDAPSynchronizer),
),
# Membership sync needs to run afterwards
group(
ldap_sync_paginator(source, MembershipLDAPSynchronizer),
),
)
task()
def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list:
"""Return a list of task signatures with LDAP pagination data"""
sync_inst: BaseLDAPSynchronizer = sync(source)
signatures = []
for page in sync_inst.get_objects():
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
cache.set(page_cache_key, page)
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
signatures.append(page_sync)
return signatures
@CELERY_APP.task(
@ -34,7 +68,7 @@ def ldap_sync_all():
soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
)
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str):
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source"""
self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours"))
try:
@ -43,11 +77,16 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str):
# Because the source couldn't be found, we don't have a UID
# to set the state with
return
sync = path_to_class(sync_class)
self.set_uid(f"{source.slug}:{sync.__name__.replace('LDAPSynchronizer', '').lower()}")
sync: type[BaseLDAPSynchronizer] = path_to_class(sync_class)
uid = page_cache_key.replace(CACHE_KEY_PREFIX, "")
self.set_uid(f"{source.slug}:{sync.name()}:{uid}")
try:
sync_inst = sync(source)
count = sync_inst.sync()
sync_inst: BaseLDAPSynchronizer = sync(source)
page = cache.get(page_cache_key)
if not page:
return
cache.touch(page_cache_key)
count = sync_inst.sync(page)
messages = sync_inst.messages
messages.append(f"Synced {count} objects.")
self.set_status(
@ -56,6 +95,7 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str):
messages,
)
)
cache.delete(page_cache_key)
except LDAPException as exc:
# No explicit event is created here as .set_status with an error will do that
LOGGER.warning(exception_to_string(exc))

View file

@ -43,7 +43,7 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=raw_conn)
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync()
user_sync.sync_full()
user = User.objects.get(username="user0_sn")
# auth_user_by_bind = Mock(return_value=user)
@ -71,7 +71,7 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync()
user_sync.sync_full()
user = User.objects.get(username="user0_sn")
auth_user_by_bind = Mock(return_value=user)
@ -98,7 +98,7 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync()
user_sync.sync_full()
user = User.objects.get(username="user0_sn")
auth_user_by_bind = Mock(return_value=user)

View file

@ -51,7 +51,7 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync()
user_sync.sync_full()
self.assertFalse(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists())
events = Event.objects.filter(
@ -87,7 +87,7 @@ class LDAPSyncTests(TestCase):
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync()
user_sync.sync_full()
user = User.objects.filter(username="user0_sn").first()
self.assertEqual(user.attributes["foo"], "bar")
self.assertFalse(user.is_active)
@ -106,7 +106,7 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync()
user_sync.sync_full()
self.assertTrue(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists())
@ -128,9 +128,9 @@ class LDAPSyncTests(TestCase):
self.source.sync_parent_group = parent_group
self.source.save()
group_sync = GroupLDAPSynchronizer(self.source)
group_sync.sync()
group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync.sync()
membership_sync.sync_full()
group: Group = Group.objects.filter(name="test-group").first()
self.assertIsNotNone(group)
self.assertEqual(group.parent, parent_group)
@ -152,9 +152,9 @@ class LDAPSyncTests(TestCase):
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save()
group_sync = GroupLDAPSynchronizer(self.source)
group_sync.sync()
group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync.sync()
membership_sync.sync_full()
group = Group.objects.filter(name="group1")
self.assertTrue(group.exists())
@ -177,11 +177,11 @@ class LDAPSyncTests(TestCase):
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save()
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync()
user_sync.sync_full()
group_sync = GroupLDAPSynchronizer(self.source)
group_sync.sync()
group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync.sync()
membership_sync.sync_full()
# Test if membership mapping based on memberUid works.
posix_group = Group.objects.filter(name="group-posix").first()
self.assertTrue(posix_group.users.filter(name="user-posix").exists())

View file

@ -63,7 +63,7 @@ class TestSourceLDAPSamba(SeleniumTestCase):
source.property_mappings_group.set(
LDAPPropertyMapping.objects.filter(name="goauthentik.io/sources/ldap/default-name")
)
UserLDAPSynchronizer(source).sync()
UserLDAPSynchronizer(source).sync_full()
self.assertTrue(User.objects.filter(username="bob").exists())
self.assertTrue(User.objects.filter(username="james").exists())
self.assertTrue(User.objects.filter(username="john").exists())
@ -94,9 +94,9 @@ class TestSourceLDAPSamba(SeleniumTestCase):
source.property_mappings_group.set(
LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/default-name")
)
GroupLDAPSynchronizer(source).sync()
UserLDAPSynchronizer(source).sync()
MembershipLDAPSynchronizer(source).sync()
GroupLDAPSynchronizer(source).sync_full()
UserLDAPSynchronizer(source).sync_full()
MembershipLDAPSynchronizer(source).sync_full()
self.assertIsNotNone(User.objects.get(username="bob"))
self.assertIsNotNone(User.objects.get(username="james"))
self.assertIsNotNone(User.objects.get(username="john"))
@ -137,7 +137,7 @@ class TestSourceLDAPSamba(SeleniumTestCase):
source.property_mappings_group.set(
LDAPPropertyMapping.objects.filter(name="goauthentik.io/sources/ldap/default-name")
)
UserLDAPSynchronizer(source).sync()
UserLDAPSynchronizer(source).sync_full()
username = "bob"
password = generate_id()
result = self.container.exec_run(
@ -160,7 +160,7 @@ class TestSourceLDAPSamba(SeleniumTestCase):
)
self.assertEqual(result.exit_code, 0)
# Sync again
UserLDAPSynchronizer(source).sync()
UserLDAPSynchronizer(source).sync_full()
user.refresh_from_db()
# Since password in samba was checked, it should be invalidated here too
self.assertFalse(user.has_usable_password())