providers/scim: improve compatibility (#5425)

* providers/scim: improve compatibility

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix lint and tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-04-30 19:43:24 +03:00 committed by GitHub
parent f36a5a053f
commit bb8b87fcb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 116 additions and 44 deletions

View File

@ -2,19 +2,12 @@
from typing import Generic, TypeVar from typing import Generic, TypeVar
from pydantic import ValidationError from pydantic import ValidationError
from pydanticscim.service_provider import (
Bulk,
ChangePassword,
Filter,
Patch,
ServiceProviderConfiguration,
Sort,
)
from requests import RequestException, Session from requests import RequestException, Session
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.providers.scim.clients.exceptions import ResourceMissing, SCIMRequestException from authentik.providers.scim.clients.exceptions import ResourceMissing, SCIMRequestException
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMProvider from authentik.providers.scim.models import SCIMProvider
T = TypeVar("T") T = TypeVar("T")
@ -22,18 +15,6 @@ T = TypeVar("T")
SchemaType = TypeVar("SchemaType") SchemaType = TypeVar("SchemaType")
def default_service_provider_config() -> ServiceProviderConfiguration:
"""Fallback service provider configuration"""
return ServiceProviderConfiguration(
patch=Patch(supported=False),
bulk=Bulk(supported=False),
filter=Filter(supported=False),
changePassword=ChangePassword(supported=False),
sort=Sort(supported=False),
authenticationSchemes=[],
)
class SCIMClient(Generic[T, SchemaType]): class SCIMClient(Generic[T, SchemaType]):
"""SCIM Client""" """SCIM Client"""
@ -85,7 +66,7 @@ class SCIMClient(Generic[T, SchemaType]):
def get_service_provider_config(self): def get_service_provider_config(self):
"""Get Service provider config""" """Get Service provider config"""
default_config = default_service_provider_config() default_config = ServiceProviderConfiguration.default()
try: try:
return ServiceProviderConfiguration.parse_obj( return ServiceProviderConfiguration.parse_obj(
self._request("GET", "/ServiceProviderConfig") self._request("GET", "/ServiceProviderConfig")

View File

@ -2,7 +2,7 @@
from deepmerge import always_merger from deepmerge import always_merger
from pydantic import ValidationError from pydantic import ValidationError
from pydanticscim.group import GroupMember from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp, PatchOperation, PatchRequest from pydanticscim.responses import PatchOp, PatchOperation
from authentik.core.exceptions import PropertyMappingExpressionException from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import Group from authentik.core.models import Group
@ -10,8 +10,13 @@ from authentik.events.models import Event, EventAction
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.policies.utils import delete_none_keys from authentik.policies.utils import delete_none_keys
from authentik.providers.scim.clients.base import SCIMClient from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import ResourceMissing, StopSync from authentik.providers.scim.clients.exceptions import (
ResourceMissing,
SCIMRequestException,
StopSync,
)
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.clients.schema import PatchRequest
from authentik.providers.scim.models import SCIMGroup, SCIMMapping, SCIMUser from authentik.providers.scim.models import SCIMGroup, SCIMMapping, SCIMUser
@ -104,6 +109,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
"""Update existing group""" """Update existing group"""
scim_group = self.to_scim(group) scim_group = self.to_scim(group)
scim_group.id = connection.id scim_group.id = connection.id
try:
return self._request( return self._request(
"PUT", "PUT",
f"/Groups/{scim_group.id}", f"/Groups/{scim_group.id}",
@ -111,6 +117,12 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
exclude_unset=True, exclude_unset=True,
), ),
) )
except SCIMRequestException:
# Some providers don't support PUT on groups, so this is mainly a fix for the initial
# sync, send patch add requests for all the users the group currently has
# TODO: send patch request for group name
users = list(group.users.order_by("id").values_list("id", flat=True))
return self._patch_add_users(group, users)
def _patch( def _patch(
self, self,
@ -118,7 +130,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
*ops: PatchOperation, *ops: PatchOperation,
): ):
req = PatchRequest(Operations=ops) req = PatchRequest(Operations=ops)
self._request("PATCH", f"/Groups/{group_id}", data=req.json(exclude_unset=True)) self._request("PATCH", f"/Groups/{group_id}", data=req.json())
def update_group(self, group: Group, action: PatchOp, users_set: set[int]): def update_group(self, group: Group, action: PatchOp, users_set: set[int]):
"""Update a group, either using PUT to replace it or PATCH if supported""" """Update a group, either using PUT to replace it or PATCH if supported"""
@ -127,7 +139,17 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
return self._patch_add_users(group, users_set) return self._patch_add_users(group, users_set)
if action == PatchOp.remove: if action == PatchOp.remove:
return self._patch_remove_users(group, users_set) return self._patch_remove_users(group, users_set)
try:
return self.write(group) return self.write(group)
except SCIMRequestException as exc:
if self._config.is_fallback:
# Assume that provider does not support PUT and also doesn't support
# ServiceProviderConfig, so try PATCH as a fallback
if action == PatchOp.add:
return self._patch_add_users(group, users_set)
if action == PatchOp.remove:
return self._patch_remove_users(group, users_set)
raise exc
def _patch_add_users(self, group: Group, users_set: set[int]): def _patch_add_users(self, group: Group, users_set: set[int]):
"""Add users in users_set to group""" """Add users in users_set to group"""
@ -144,6 +166,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
"id", flat=True "id", flat=True
) )
) )
if len(user_ids) < 1:
return
self._patch( self._patch(
scim_group.id, scim_group.id,
PatchOperation( PatchOperation(
@ -168,6 +192,8 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
"id", flat=True "id", flat=True
) )
) )
if len(user_ids) < 1:
return
self._patch( self._patch(
scim_group.id, scim_group.id,
PatchOperation( PatchOperation(

View File

@ -1,17 +1,54 @@
"""Custom SCIM schemas""" """Custom SCIM schemas"""
from typing import Optional from typing import Optional
from pydanticscim.group import Group as SCIMGroupSchema from pydanticscim.group import Group as BaseGroup
from pydanticscim.user import User as SCIMUserSchema from pydanticscim.responses import PatchRequest as BasePatchRequest
from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch
from pydanticscim.service_provider import (
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
)
from pydanticscim.service_provider import Sort
from pydanticscim.user import User as BaseUser
class User(SCIMUserSchema): class User(BaseUser):
"""Modified User schema with added externalId field""" """Modified User schema with added externalId field"""
externalId: Optional[str] = None externalId: Optional[str] = None
class Group(SCIMGroupSchema): class Group(BaseGroup):
"""Modified Group schema with added externalId field""" """Modified Group schema with added externalId field"""
externalId: Optional[str] = None externalId: Optional[str] = None
class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
"""ServiceProviderConfig with fallback"""
_is_fallback: Optional[bool] = False
@property
def is_fallback(self) -> bool:
"""Check if this service provider config was retrieved from the API endpoint
or a fallback was used"""
return self._is_fallback
@staticmethod
def default() -> "ServiceProviderConfiguration":
"""Get default configuration, which doesn't support any optional features as fallback"""
return ServiceProviderConfiguration(
patch=Patch(supported=False),
bulk=Bulk(supported=False),
filter=Filter(supported=False),
changePassword=ChangePassword(supported=False),
sort=Sort(supported=False),
authenticationSchemes=[],
_is_fallback=True,
)
class PatchRequest(BasePatchRequest):
"""PatchRequest which correctly sets schemas"""
schemas: tuple[str] = ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]

View File

@ -0,0 +1,23 @@
"""SCIM Sync"""
from django.core.management.base import BaseCommand
from structlog.stdlib import get_logger
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync
LOGGER = get_logger()
class Command(BaseCommand):
"""Run sync for an SCIM Provider"""
def add_arguments(self, parser):
parser.add_argument("providers", nargs="+", type=str)
def handle(self, **options):
for provider_name in options["providers"]:
provider = SCIMProvider.objects.filter(name=provider_name).first()
if not provider:
LOGGER.warning("Provider does not exist", name=provider_name)
continue
scim_sync.delay(provider.pk).get()

View File

@ -94,7 +94,8 @@ def scim_sync_users(page: int, provider_pk: int):
} }
) )
) )
except StopSync: except StopSync as exc:
LOGGER.warning("Stopping sync", exc=exc)
break break
return messages return messages
@ -126,7 +127,8 @@ def scim_sync_group(page: int, provider_pk: int):
} }
) )
) )
except StopSync: except StopSync as exc:
LOGGER.warning("Stopping sync", exc=exc)
break break
return messages return messages

View File

@ -6,7 +6,7 @@ from requests_mock import Mocker
from authentik.blueprints.tests import apply_blueprint from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.base import default_service_provider_config from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMMapping, SCIMProvider from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync from authentik.providers.scim.tasks import scim_sync
@ -39,7 +39,7 @@ class SCIMMembershipTests(TestCase):
def test_member_add(self): def test_member_add(self):
"""Test member add""" """Test member add"""
config = default_service_provider_config() config = ServiceProviderConfiguration.default()
config.patch.supported = True config.patch.supported = True
user_scim_id = generate_id() user_scim_id = generate_id()
group_scim_id = generate_id() group_scim_id = generate_id()
@ -117,13 +117,14 @@ class SCIMMembershipTests(TestCase):
"path": "members", "path": "members",
"value": [{"value": user_scim_id}], "value": [{"value": user_scim_id}],
} }
] ],
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
}, },
) )
def test_member_remove(self): def test_member_remove(self):
"""Test member remove""" """Test member remove"""
config = default_service_provider_config() config = ServiceProviderConfiguration.default()
config.patch.supported = True config.patch.supported = True
user_scim_id = generate_id() user_scim_id = generate_id()
group_scim_id = generate_id() group_scim_id = generate_id()
@ -201,7 +202,8 @@ class SCIMMembershipTests(TestCase):
"path": "members", "path": "members",
"value": [{"value": user_scim_id}], "value": [{"value": user_scim_id}],
} }
] ],
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
}, },
) )
@ -227,6 +229,7 @@ class SCIMMembershipTests(TestCase):
"path": "members", "path": "members",
"value": [{"value": user_scim_id}], "value": [{"value": user_scim_id}],
} }
] ],
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
}, },
) )