From bb8b87fcb331b2c68c4aab1d4bfe07c9d36ddf25 Mon Sep 17 00:00:00 2001 From: Jens L Date: Sun, 30 Apr 2023 19:43:24 +0300 Subject: [PATCH] providers/scim: improve compatibility (#5425) * providers/scim: improve compatibility Signed-off-by: Jens Langhammer * fix lint and tests Signed-off-by: Jens Langhammer --------- Signed-off-by: Jens Langhammer --- authentik/providers/scim/clients/base.py | 23 +-------- authentik/providers/scim/clients/group.py | 48 ++++++++++++++----- authentik/providers/scim/clients/schema.py | 45 +++++++++++++++-- .../providers/scim/management/__init__.py | 0 .../scim/management/commands/__init__.py | 0 .../scim/management/commands/scim_sync.py | 23 +++++++++ authentik/providers/scim/tasks.py | 6 ++- .../providers/scim/tests/test_membership.py | 15 +++--- 8 files changed, 116 insertions(+), 44 deletions(-) create mode 100644 authentik/providers/scim/management/__init__.py create mode 100644 authentik/providers/scim/management/commands/__init__.py create mode 100644 authentik/providers/scim/management/commands/scim_sync.py diff --git a/authentik/providers/scim/clients/base.py b/authentik/providers/scim/clients/base.py index 8399b862f..c841ecfa0 100644 --- a/authentik/providers/scim/clients/base.py +++ b/authentik/providers/scim/clients/base.py @@ -2,19 +2,12 @@ from typing import Generic, TypeVar from pydantic import ValidationError -from pydanticscim.service_provider import ( - Bulk, - ChangePassword, - Filter, - Patch, - ServiceProviderConfiguration, - Sort, -) from requests import RequestException, Session from structlog.stdlib import get_logger from authentik.lib.utils.http import get_http_session from authentik.providers.scim.clients.exceptions import ResourceMissing, SCIMRequestException +from authentik.providers.scim.clients.schema import ServiceProviderConfiguration from authentik.providers.scim.models import SCIMProvider T = TypeVar("T") @@ -22,18 +15,6 @@ T = TypeVar("T") SchemaType = TypeVar("SchemaType") -def default_service_provider_config() -> ServiceProviderConfiguration: - """Fallback service provider configuration""" - return ServiceProviderConfiguration( - patch=Patch(supported=False), - bulk=Bulk(supported=False), - filter=Filter(supported=False), - changePassword=ChangePassword(supported=False), - sort=Sort(supported=False), - authenticationSchemes=[], - ) - - class SCIMClient(Generic[T, SchemaType]): """SCIM Client""" @@ -85,7 +66,7 @@ class SCIMClient(Generic[T, SchemaType]): def get_service_provider_config(self): """Get Service provider config""" - default_config = default_service_provider_config() + default_config = ServiceProviderConfiguration.default() try: return ServiceProviderConfiguration.parse_obj( self._request("GET", "/ServiceProviderConfig") diff --git a/authentik/providers/scim/clients/group.py b/authentik/providers/scim/clients/group.py index 656d9edd6..716e89985 100644 --- a/authentik/providers/scim/clients/group.py +++ b/authentik/providers/scim/clients/group.py @@ -2,7 +2,7 @@ from deepmerge import always_merger from pydantic import ValidationError from pydanticscim.group import GroupMember -from pydanticscim.responses import PatchOp, PatchOperation, PatchRequest +from pydanticscim.responses import PatchOp, PatchOperation from authentik.core.exceptions import PropertyMappingExpressionException from authentik.core.models import Group @@ -10,8 +10,13 @@ from authentik.events.models import Event, EventAction from authentik.lib.utils.errors import exception_to_string from authentik.policies.utils import delete_none_keys from authentik.providers.scim.clients.base import SCIMClient -from authentik.providers.scim.clients.exceptions import ResourceMissing, StopSync +from authentik.providers.scim.clients.exceptions import ( + ResourceMissing, + SCIMRequestException, + StopSync, +) from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema +from authentik.providers.scim.clients.schema import PatchRequest from authentik.providers.scim.models import SCIMGroup, SCIMMapping, SCIMUser @@ -104,13 +109,20 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): """Update existing group""" scim_group = self.to_scim(group) scim_group.id = connection.id - return self._request( - "PUT", - f"/Groups/{scim_group.id}", - data=scim_group.json( - exclude_unset=True, - ), - ) + try: + return self._request( + "PUT", + f"/Groups/{scim_group.id}", + data=scim_group.json( + exclude_unset=True, + ), + ) + except SCIMRequestException: + # Some providers don't support PUT on groups, so this is mainly a fix for the initial + # sync, send patch add requests for all the users the group currently has + # TODO: send patch request for group name + users = list(group.users.order_by("id").values_list("id", flat=True)) + return self._patch_add_users(group, users) def _patch( self, @@ -118,7 +130,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): *ops: PatchOperation, ): req = PatchRequest(Operations=ops) - self._request("PATCH", f"/Groups/{group_id}", data=req.json(exclude_unset=True)) + self._request("PATCH", f"/Groups/{group_id}", data=req.json()) def update_group(self, group: Group, action: PatchOp, users_set: set[int]): """Update a group, either using PUT to replace it or PATCH if supported""" @@ -127,7 +139,17 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): return self._patch_add_users(group, users_set) if action == PatchOp.remove: return self._patch_remove_users(group, users_set) - return self.write(group) + try: + return self.write(group) + except SCIMRequestException as exc: + if self._config.is_fallback: + # Assume that provider does not support PUT and also doesn't support + # ServiceProviderConfig, so try PATCH as a fallback + if action == PatchOp.add: + return self._patch_add_users(group, users_set) + if action == PatchOp.remove: + return self._patch_remove_users(group, users_set) + raise exc def _patch_add_users(self, group: Group, users_set: set[int]): """Add users in users_set to group""" @@ -144,6 +166,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): "id", flat=True ) ) + if len(user_ids) < 1: + return self._patch( scim_group.id, PatchOperation( @@ -168,6 +192,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): "id", flat=True ) ) + if len(user_ids) < 1: + return self._patch( scim_group.id, PatchOperation( diff --git a/authentik/providers/scim/clients/schema.py b/authentik/providers/scim/clients/schema.py index f4c94157a..4236e7524 100644 --- a/authentik/providers/scim/clients/schema.py +++ b/authentik/providers/scim/clients/schema.py @@ -1,17 +1,54 @@ """Custom SCIM schemas""" from typing import Optional -from pydanticscim.group import Group as SCIMGroupSchema -from pydanticscim.user import User as SCIMUserSchema +from pydanticscim.group import Group as BaseGroup +from pydanticscim.responses import PatchRequest as BasePatchRequest +from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch +from pydanticscim.service_provider import ( + ServiceProviderConfiguration as BaseServiceProviderConfiguration, +) +from pydanticscim.service_provider import Sort +from pydanticscim.user import User as BaseUser -class User(SCIMUserSchema): +class User(BaseUser): """Modified User schema with added externalId field""" externalId: Optional[str] = None -class Group(SCIMGroupSchema): +class Group(BaseGroup): """Modified Group schema with added externalId field""" externalId: Optional[str] = None + + +class ServiceProviderConfiguration(BaseServiceProviderConfiguration): + """ServiceProviderConfig with fallback""" + + _is_fallback: Optional[bool] = False + + @property + def is_fallback(self) -> bool: + """Check if this service provider config was retrieved from the API endpoint + or a fallback was used""" + return self._is_fallback + + @staticmethod + def default() -> "ServiceProviderConfiguration": + """Get default configuration, which doesn't support any optional features as fallback""" + return ServiceProviderConfiguration( + patch=Patch(supported=False), + bulk=Bulk(supported=False), + filter=Filter(supported=False), + changePassword=ChangePassword(supported=False), + sort=Sort(supported=False), + authenticationSchemes=[], + _is_fallback=True, + ) + + +class PatchRequest(BasePatchRequest): + """PatchRequest which correctly sets schemas""" + + schemas: tuple[str] = ["urn:ietf:params:scim:api:messages:2.0:PatchOp"] diff --git a/authentik/providers/scim/management/__init__.py b/authentik/providers/scim/management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/authentik/providers/scim/management/commands/__init__.py b/authentik/providers/scim/management/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/authentik/providers/scim/management/commands/scim_sync.py b/authentik/providers/scim/management/commands/scim_sync.py new file mode 100644 index 000000000..0fbca3dce --- /dev/null +++ b/authentik/providers/scim/management/commands/scim_sync.py @@ -0,0 +1,23 @@ +"""SCIM Sync""" +from django.core.management.base import BaseCommand +from structlog.stdlib import get_logger + +from authentik.providers.scim.models import SCIMProvider +from authentik.providers.scim.tasks import scim_sync + +LOGGER = get_logger() + + +class Command(BaseCommand): + """Run sync for an SCIM Provider""" + + def add_arguments(self, parser): + parser.add_argument("providers", nargs="+", type=str) + + def handle(self, **options): + for provider_name in options["providers"]: + provider = SCIMProvider.objects.filter(name=provider_name).first() + if not provider: + LOGGER.warning("Provider does not exist", name=provider_name) + continue + scim_sync.delay(provider.pk).get() diff --git a/authentik/providers/scim/tasks.py b/authentik/providers/scim/tasks.py index 18673d3ad..fd2bea645 100644 --- a/authentik/providers/scim/tasks.py +++ b/authentik/providers/scim/tasks.py @@ -94,7 +94,8 @@ def scim_sync_users(page: int, provider_pk: int): } ) ) - except StopSync: + except StopSync as exc: + LOGGER.warning("Stopping sync", exc=exc) break return messages @@ -126,7 +127,8 @@ def scim_sync_group(page: int, provider_pk: int): } ) ) - except StopSync: + except StopSync as exc: + LOGGER.warning("Stopping sync", exc=exc) break return messages diff --git a/authentik/providers/scim/tests/test_membership.py b/authentik/providers/scim/tests/test_membership.py index 73754a569..de0844450 100644 --- a/authentik/providers/scim/tests/test_membership.py +++ b/authentik/providers/scim/tests/test_membership.py @@ -6,7 +6,7 @@ from requests_mock import Mocker from authentik.blueprints.tests import apply_blueprint from authentik.core.models import Group, User from authentik.lib.generators import generate_id -from authentik.providers.scim.clients.base import default_service_provider_config +from authentik.providers.scim.clients.schema import ServiceProviderConfiguration from authentik.providers.scim.models import SCIMMapping, SCIMProvider from authentik.providers.scim.tasks import scim_sync @@ -39,7 +39,7 @@ class SCIMMembershipTests(TestCase): def test_member_add(self): """Test member add""" - config = default_service_provider_config() + config = ServiceProviderConfiguration.default() config.patch.supported = True user_scim_id = generate_id() group_scim_id = generate_id() @@ -117,13 +117,14 @@ class SCIMMembershipTests(TestCase): "path": "members", "value": [{"value": user_scim_id}], } - ] + ], + "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], }, ) def test_member_remove(self): """Test member remove""" - config = default_service_provider_config() + config = ServiceProviderConfiguration.default() config.patch.supported = True user_scim_id = generate_id() group_scim_id = generate_id() @@ -201,7 +202,8 @@ class SCIMMembershipTests(TestCase): "path": "members", "value": [{"value": user_scim_id}], } - ] + ], + "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], }, ) @@ -227,6 +229,7 @@ class SCIMMembershipTests(TestCase): "path": "members", "value": [{"value": user_scim_id}], } - ] + ], + "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], }, )