diff --git a/authentik/events/monitored_tasks.py b/authentik/events/monitored_tasks.py index 5db4febef..3781c30d5 100644 --- a/authentik/events/monitored_tasks.py +++ b/authentik/events/monitored_tasks.py @@ -70,8 +70,10 @@ class TaskInfo: return cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*")) @staticmethod - def by_name(name: str) -> Optional["TaskInfo"]: + def by_name(name: str) -> Optional["TaskInfo"] | Optional[list["TaskInfo"]]: """Get TaskInfo Object by name""" + if "*" in name: + return cache.get_many(cache.keys(CACHE_KEY_PREFIX + name)).values() return cache.get(CACHE_KEY_PREFIX + name, None) def delete(self): diff --git a/authentik/sources/ldap/api.py b/authentik/sources/ldap/api.py index 0a8849345..f81b4ad4e 100644 --- a/authentik/sources/ldap/api.py +++ b/authentik/sources/ldap/api.py @@ -118,10 +118,9 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet): """Get source's sync status""" source = self.get_object() results = [] - for sync_class in SYNC_CLASSES: - sync_name = sync_class.__name__.replace("LDAPSynchronizer", "").lower() - task = TaskInfo.by_name(f"ldap_sync:{source.slug}:{sync_name}") - if task: + tasks = TaskInfo.by_name(f"ldap_sync:{source.slug}:*") + if tasks: + for task in tasks: results.append(task) return Response(TaskSerializer(results, many=True).data) @@ -143,7 +142,7 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet): source = self.get_object() all_objects = {} for sync_class in SYNC_CLASSES: - class_name = sync_class.__name__.replace("LDAPSynchronizer", "").lower() + class_name = sync_class.name() all_objects.setdefault(class_name, []) for obj in sync_class(source).get_objects(size_limit=10): obj: dict diff --git a/authentik/sources/ldap/management/commands/ldap_sync.py b/authentik/sources/ldap/management/commands/ldap_sync.py index 3e52fc5c7..15ea1c45a 100644 --- a/authentik/sources/ldap/management/commands/ldap_sync.py +++ b/authentik/sources/ldap/management/commands/ldap_sync.py @@ -2,9 +2,8 @@ from django.core.management.base import BaseCommand from structlog.stdlib import get_logger -from authentik.lib.utils.reflection import class_to_path from authentik.sources.ldap.models import LDAPSource -from authentik.sources.ldap.tasks import SYNC_CLASSES, ldap_sync +from authentik.sources.ldap.tasks import ldap_sync_single LOGGER = get_logger() @@ -21,7 +20,4 @@ class Command(BaseCommand): if not source: LOGGER.warning("Source does not exist", slug=source_slug) continue - for sync_class in SYNC_CLASSES: - LOGGER.info("Starting sync", cls=sync_class) - # pylint: disable=no-value-for-parameter - ldap_sync(source.pk, class_to_path(sync_class)) + ldap_sync_single(source) diff --git a/authentik/sources/ldap/signals.py b/authentik/sources/ldap/signals.py index 81d8c692e..a5f7ea037 100644 --- a/authentik/sources/ldap/signals.py +++ b/authentik/sources/ldap/signals.py @@ -12,13 +12,9 @@ from authentik.core.models import User from authentik.core.signals import password_changed from authentik.events.models import Event, EventAction from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER -from authentik.lib.utils.reflection import class_to_path from authentik.sources.ldap.models import LDAPSource from authentik.sources.ldap.password import LDAPPasswordChanger -from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer -from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer -from authentik.sources.ldap.sync.users import UserLDAPSynchronizer -from authentik.sources.ldap.tasks import ldap_sync +from authentik.sources.ldap.tasks import ldap_sync_single from authentik.stages.prompt.signals import password_validate LOGGER = get_logger() @@ -35,12 +31,7 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_): # and the mappings are created with an m2m event if not instance.property_mappings.exists() or not instance.property_mappings_group.exists(): return - for sync_class in [ - UserLDAPSynchronizer, - GroupLDAPSynchronizer, - MembershipLDAPSynchronizer, - ]: - ldap_sync.delay(instance.pk, class_to_path(sync_class)) + ldap_sync_single.delay(instance.pk) @receiver(password_validate) diff --git a/authentik/sources/ldap/sync/base.py b/authentik/sources/ldap/sync/base.py index 4ab18d179..235c7be26 100644 --- a/authentik/sources/ldap/sync/base.py +++ b/authentik/sources/ldap/sync/base.py @@ -1,9 +1,10 @@ """Sync LDAP Users and groups into authentik""" from typing import Any, Generator +from django.conf import settings from django.db.models.base import Model from django.db.models.query import QuerySet -from ldap3 import Connection +from ldap3 import DEREF_ALWAYS, SUBTREE, Connection from structlog.stdlib import BoundLogger, get_logger from authentik.core.exceptions import PropertyMappingExpressionException @@ -29,6 +30,24 @@ class BaseLDAPSynchronizer: self._messages = [] self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__) + @staticmethod + def name() -> str: + """UI name for the type of object this class synchronizes""" + raise NotImplementedError + + def sync_full(self): + """Run full sync, this function should only be used in tests""" + if not settings.TEST: # noqa + raise RuntimeError( + f"{self.__class__.__name__}.sync_full() should only be used in tests" + ) + for page in self.get_objects(): + self.sync(page) + + def sync(self, page_data: list) -> int: + """Sync function, implemented in subclass""" + raise NotImplementedError() + @property def messages(self) -> list[str]: """Get all UI messages""" @@ -60,9 +79,47 @@ class BaseLDAPSynchronizer: """Get objects from LDAP, implemented in subclass""" raise NotImplementedError() - def sync(self) -> int: - """Sync function, implemented in subclass""" - raise NotImplementedError() + # pylint: disable=too-many-arguments + def search_paginator( + self, + search_base, + search_filter, + search_scope=SUBTREE, + dereference_aliases=DEREF_ALWAYS, + attributes=None, + size_limit=0, + time_limit=0, + types_only=False, + get_operational_attributes=False, + controls=None, + paged_size=5, + paged_criticality=False, + ): + """Search in pages, returns each page""" + cookie = True + while cookie: + self._connection.search( + search_base, + search_filter, + search_scope, + dereference_aliases, + attributes, + size_limit, + time_limit, + types_only, + get_operational_attributes, + controls, + paged_size, + paged_criticality, + None if cookie is True else cookie, + ) + try: + cookie = self._connection.result["controls"]["1.2.840.113556.1.4.319"]["value"][ + "cookie" + ] + except KeyError: + cookie = None + yield self._connection.response def _flatten(self, value: Any) -> Any: """Flatten `value` if its a list""" diff --git a/authentik/sources/ldap/sync/groups.py b/authentik/sources/ldap/sync/groups.py index 2508bd979..a10e0a904 100644 --- a/authentik/sources/ldap/sync/groups.py +++ b/authentik/sources/ldap/sync/groups.py @@ -13,8 +13,12 @@ from authentik.sources.ldap.sync.base import LDAP_UNIQUENESS, BaseLDAPSynchroniz class GroupLDAPSynchronizer(BaseLDAPSynchronizer): """Sync LDAP Users and groups into authentik""" + @staticmethod + def name() -> str: + return "groups" + def get_objects(self, **kwargs) -> Generator: - return self._connection.extend.standard.paged_search( + return self.search_paginator( search_base=self.base_dn_groups, search_filter=self._source.group_object_filter, search_scope=SUBTREE, @@ -22,13 +26,13 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): **kwargs, ) - def sync(self) -> int: + def sync(self, page_data: list) -> int: """Iterate over all LDAP Groups and create authentik_core.Group instances""" if not self._source.sync_groups: self.message("Group syncing is disabled for this Source") return -1 group_count = 0 - for group in self.get_objects(): + for group in page_data: if "attributes" not in group: continue attributes = group.get("attributes", {}) diff --git a/authentik/sources/ldap/sync/membership.py b/authentik/sources/ldap/sync/membership.py index 1bb0c8515..432715c5a 100644 --- a/authentik/sources/ldap/sync/membership.py +++ b/authentik/sources/ldap/sync/membership.py @@ -19,8 +19,12 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): super().__init__(source) self.group_cache: dict[str, Group] = {} + @staticmethod + def name() -> str: + return "membership" + def get_objects(self, **kwargs) -> Generator: - return self._connection.extend.standard.paged_search( + return self.search_paginator( search_base=self.base_dn_groups, search_filter=self._source.group_object_filter, search_scope=SUBTREE, @@ -32,13 +36,13 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): **kwargs, ) - def sync(self) -> int: + def sync(self, page_data: list) -> int: """Iterate over all Users and assign Groups using memberOf Field""" if not self._source.sync_groups: self.message("Group syncing is disabled for this Source") return -1 membership_count = 0 - for group in self.get_objects(): + for group in page_data: if "attributes" not in group: continue members = group.get("attributes", {}).get(self._source.group_membership_field, []) diff --git a/authentik/sources/ldap/sync/users.py b/authentik/sources/ldap/sync/users.py index 053fd6142..c55d14517 100644 --- a/authentik/sources/ldap/sync/users.py +++ b/authentik/sources/ldap/sync/users.py @@ -15,8 +15,12 @@ from authentik.sources.ldap.sync.vendor.ms_ad import MicrosoftActiveDirectory class UserLDAPSynchronizer(BaseLDAPSynchronizer): """Sync LDAP Users into authentik""" + @staticmethod + def name() -> str: + return "users" + def get_objects(self, **kwargs) -> Generator: - return self._connection.extend.standard.paged_search( + return self.search_paginator( search_base=self.base_dn_users, search_filter=self._source.user_object_filter, search_scope=SUBTREE, @@ -24,13 +28,13 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): **kwargs, ) - def sync(self) -> int: + def sync(self, page_data: list) -> int: """Iterate over all LDAP Users and create authentik_core.User instances""" if not self._source.sync_users: self.message("User syncing is disabled for this Source") return -1 user_count = 0 - for user in self.get_objects(): + for user in page_data: if "attributes" not in user: continue attributes = user.get("attributes", {}) diff --git a/authentik/sources/ldap/sync/vendor/freeipa.py b/authentik/sources/ldap/sync/vendor/freeipa.py index f71f9778e..a56b569f1 100644 --- a/authentik/sources/ldap/sync/vendor/freeipa.py +++ b/authentik/sources/ldap/sync/vendor/freeipa.py @@ -11,6 +11,10 @@ from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer class FreeIPA(BaseLDAPSynchronizer): """FreeIPA-specific LDAP""" + @staticmethod + def name() -> str: + return "freeipa" + def get_objects(self, **kwargs) -> Generator: yield None diff --git a/authentik/sources/ldap/sync/vendor/ms_ad.py b/authentik/sources/ldap/sync/vendor/ms_ad.py index a78b5fddb..c14e9f944 100644 --- a/authentik/sources/ldap/sync/vendor/ms_ad.py +++ b/authentik/sources/ldap/sync/vendor/ms_ad.py @@ -42,6 +42,10 @@ class UserAccountControl(IntFlag): class MicrosoftActiveDirectory(BaseLDAPSynchronizer): """Microsoft-specific LDAP""" + @staticmethod + def name() -> str: + return "microsoft_ad" + def get_objects(self, **kwargs) -> Generator: yield None diff --git a/authentik/sources/ldap/tasks.py b/authentik/sources/ldap/tasks.py index 980018a9e..04348188c 100644 --- a/authentik/sources/ldap/tasks.py +++ b/authentik/sources/ldap/tasks.py @@ -1,4 +1,8 @@ """LDAP Sync tasks""" +from uuid import uuid4 + +from celery import chain, group +from django.core.cache import cache from ldap3.core.exceptions import LDAPException from structlog.stdlib import get_logger @@ -8,6 +12,7 @@ from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.reflection import class_to_path, path_to_class from authentik.root.celery import CELERY_APP from authentik.sources.ldap.models import LDAPSource +from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer @@ -18,14 +23,43 @@ SYNC_CLASSES = [ GroupLDAPSynchronizer, MembershipLDAPSynchronizer, ] +CACHE_KEY_PREFIX = "goauthentik.io/sources/ldap/page/" @CELERY_APP.task() def ldap_sync_all(): """Sync all sources""" for source in LDAPSource.objects.filter(enabled=True): - for sync_class in SYNC_CLASSES: - ldap_sync.delay(source.pk, class_to_path(sync_class)) + ldap_sync_single(source) + + +@CELERY_APP.task() +def ldap_sync_single(source: LDAPSource): + """Sync a single source""" + task = chain( + # User and group sync can happen at once, they have no dependencies on each other + group( + ldap_sync_paginator(source, UserLDAPSynchronizer) + + ldap_sync_paginator(source, GroupLDAPSynchronizer), + ), + # Membership sync needs to run afterwards + group( + ldap_sync_paginator(source, MembershipLDAPSynchronizer), + ), + ) + task() + + +def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: + """Return a list of task signatures with LDAP pagination data""" + sync_inst: BaseLDAPSynchronizer = sync(source) + signatures = [] + for page in sync_inst.get_objects(): + page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) + cache.set(page_cache_key, page) + page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) + signatures.append(page_sync) + return signatures @CELERY_APP.task( @@ -34,7 +68,7 @@ def ldap_sync_all(): soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), ) -def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str): +def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): """Synchronization of an LDAP Source""" self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours")) try: @@ -43,11 +77,16 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str): # Because the source couldn't be found, we don't have a UID # to set the state with return - sync = path_to_class(sync_class) - self.set_uid(f"{source.slug}:{sync.__name__.replace('LDAPSynchronizer', '').lower()}") + sync: type[BaseLDAPSynchronizer] = path_to_class(sync_class) + uid = page_cache_key.replace(CACHE_KEY_PREFIX, "") + self.set_uid(f"{source.slug}:{sync.name()}:{uid}") try: - sync_inst = sync(source) - count = sync_inst.sync() + sync_inst: BaseLDAPSynchronizer = sync(source) + page = cache.get(page_cache_key) + if not page: + return + cache.touch(page_cache_key) + count = sync_inst.sync(page) messages = sync_inst.messages messages.append(f"Synced {count} objects.") self.set_status( @@ -56,6 +95,7 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str): messages, ) ) + cache.delete(page_cache_key) except LDAPException as exc: # No explicit event is created here as .set_status with an error will do that LOGGER.warning(exception_to_string(exc)) diff --git a/authentik/sources/ldap/tests/test_auth.py b/authentik/sources/ldap/tests/test_auth.py index 715edc139..d764715b1 100644 --- a/authentik/sources/ldap/tests/test_auth.py +++ b/authentik/sources/ldap/tests/test_auth.py @@ -43,7 +43,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=raw_conn) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): user_sync = UserLDAPSynchronizer(self.source) - user_sync.sync() + user_sync.sync_full() user = User.objects.get(username="user0_sn") # auth_user_by_bind = Mock(return_value=user) @@ -71,7 +71,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): user_sync = UserLDAPSynchronizer(self.source) - user_sync.sync() + user_sync.sync_full() user = User.objects.get(username="user0_sn") auth_user_by_bind = Mock(return_value=user) @@ -98,7 +98,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): user_sync = UserLDAPSynchronizer(self.source) - user_sync.sync() + user_sync.sync_full() user = User.objects.get(username="user0_sn") auth_user_by_bind = Mock(return_value=user) diff --git a/authentik/sources/ldap/tests/test_sync.py b/authentik/sources/ldap/tests/test_sync.py index c5fd09fc4..5382b27c6 100644 --- a/authentik/sources/ldap/tests/test_sync.py +++ b/authentik/sources/ldap/tests/test_sync.py @@ -51,7 +51,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): user_sync = UserLDAPSynchronizer(self.source) - user_sync.sync() + user_sync.sync_full() self.assertFalse(User.objects.filter(username="user0_sn").exists()) self.assertFalse(User.objects.filter(username="user1_sn").exists()) events = Event.objects.filter( @@ -87,7 +87,7 @@ class LDAPSyncTests(TestCase): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): user_sync = UserLDAPSynchronizer(self.source) - user_sync.sync() + user_sync.sync_full() user = User.objects.filter(username="user0_sn").first() self.assertEqual(user.attributes["foo"], "bar") self.assertFalse(user.is_active) @@ -106,7 +106,7 @@ class LDAPSyncTests(TestCase): connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): user_sync = UserLDAPSynchronizer(self.source) - user_sync.sync() + user_sync.sync_full() self.assertTrue(User.objects.filter(username="user0_sn").exists()) self.assertFalse(User.objects.filter(username="user1_sn").exists()) @@ -128,9 +128,9 @@ class LDAPSyncTests(TestCase): self.source.sync_parent_group = parent_group self.source.save() group_sync = GroupLDAPSynchronizer(self.source) - group_sync.sync() + group_sync.sync_full() membership_sync = MembershipLDAPSynchronizer(self.source) - membership_sync.sync() + membership_sync.sync_full() group: Group = Group.objects.filter(name="test-group").first() self.assertIsNotNone(group) self.assertEqual(group.parent, parent_group) @@ -152,9 +152,9 @@ class LDAPSyncTests(TestCase): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): self.source.save() group_sync = GroupLDAPSynchronizer(self.source) - group_sync.sync() + group_sync.sync_full() membership_sync = MembershipLDAPSynchronizer(self.source) - membership_sync.sync() + membership_sync.sync_full() group = Group.objects.filter(name="group1") self.assertTrue(group.exists()) @@ -177,11 +177,11 @@ class LDAPSyncTests(TestCase): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): self.source.save() user_sync = UserLDAPSynchronizer(self.source) - user_sync.sync() + user_sync.sync_full() group_sync = GroupLDAPSynchronizer(self.source) - group_sync.sync() + group_sync.sync_full() membership_sync = MembershipLDAPSynchronizer(self.source) - membership_sync.sync() + membership_sync.sync_full() # Test if membership mapping based on memberUid works. posix_group = Group.objects.filter(name="group-posix").first() self.assertTrue(posix_group.users.filter(name="user-posix").exists()) diff --git a/tests/e2e/test_source_ldap_samba.py b/tests/e2e/test_source_ldap_samba.py index 3b3d7f9bf..771138912 100644 --- a/tests/e2e/test_source_ldap_samba.py +++ b/tests/e2e/test_source_ldap_samba.py @@ -63,7 +63,7 @@ class TestSourceLDAPSamba(SeleniumTestCase): source.property_mappings_group.set( LDAPPropertyMapping.objects.filter(name="goauthentik.io/sources/ldap/default-name") ) - UserLDAPSynchronizer(source).sync() + UserLDAPSynchronizer(source).sync_full() self.assertTrue(User.objects.filter(username="bob").exists()) self.assertTrue(User.objects.filter(username="james").exists()) self.assertTrue(User.objects.filter(username="john").exists()) @@ -94,9 +94,9 @@ class TestSourceLDAPSamba(SeleniumTestCase): source.property_mappings_group.set( LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/default-name") ) - GroupLDAPSynchronizer(source).sync() - UserLDAPSynchronizer(source).sync() - MembershipLDAPSynchronizer(source).sync() + GroupLDAPSynchronizer(source).sync_full() + UserLDAPSynchronizer(source).sync_full() + MembershipLDAPSynchronizer(source).sync_full() self.assertIsNotNone(User.objects.get(username="bob")) self.assertIsNotNone(User.objects.get(username="james")) self.assertIsNotNone(User.objects.get(username="john")) @@ -137,7 +137,7 @@ class TestSourceLDAPSamba(SeleniumTestCase): source.property_mappings_group.set( LDAPPropertyMapping.objects.filter(name="goauthentik.io/sources/ldap/default-name") ) - UserLDAPSynchronizer(source).sync() + UserLDAPSynchronizer(source).sync_full() username = "bob" password = generate_id() result = self.container.exec_run( @@ -160,7 +160,7 @@ class TestSourceLDAPSamba(SeleniumTestCase): ) self.assertEqual(result.exit_code, 0) # Sync again - UserLDAPSynchronizer(source).sync() + UserLDAPSynchronizer(source).sync_full() user.refresh_from_db() # Since password in samba was checked, it should be invalidated here too self.assertFalse(user.has_usable_password())