diff --git a/authentik/stages/user_write/stage.py b/authentik/stages/user_write/stage.py index 30abfe25a..89fb5bab5 100644 --- a/authentik/stages/user_write/stage.py +++ b/authentik/stages/user_write/stage.py @@ -1,4 +1,5 @@ """Write stage logic""" +from typing import Any from django.contrib import messages from django.contrib.auth import update_session_auth_hash from django.db import transaction @@ -23,6 +24,23 @@ LOGGER = get_logger() class UserWriteStageView(StageView): """Finalise Enrollment flow by creating a user object.""" + def write_attribute(self, user: User, key: str, value: Any): + """Allow use of attributes.foo.bar when writing to a user, with full + recursion""" + parts = key.replace("_", ".").split(".") + if len(parts) < 1: + return + # Function will always be called with a key like attribute. + # this is just a sanity check to ensure that is removed + if parts[0] == "attribute": + parts = parts[1:] + attrs = user.attributes + for comp in parts[:-1]: + if comp not in attrs: + attrs[comp] = {} + attrs = attrs.get(comp) + attrs[parts[-1]] = value + def post(self, request: HttpRequest) -> HttpResponse: """Wrapper for post requests""" return self.get(request) @@ -71,10 +89,10 @@ class UserWriteStageView(StageView): # Otherwise we just save it as custom attribute, but only if the value is prefixed with # `attribute_`, to prevent accidentally saving values else: - if not key.startswith("attribute_"): + if not key.startswith("attribute.") and not key.startswith("attribute_"): LOGGER.debug("discarding key", key=key) continue - user.attributes[key.replace("attribute_", "", 1)] = value + self.write_attribute(user, key, value) # Extra check to prevent flows from saving a user with a blank username if user.username == "": LOGGER.warning("Aborting write to empty username", user=user) diff --git a/authentik/stages/user_write/tests.py b/authentik/stages/user_write/tests.py index eb951be14..986d907b5 100644 --- a/authentik/stages/user_write/tests.py +++ b/authentik/stages/user_write/tests.py @@ -85,7 +85,7 @@ class TestUserWriteStage(APITestCase): plan.context[PLAN_CONTEXT_PROMPT] = { "username": "test-user-new", "password": new_password, - "attribute_some-custom-attribute": "test", + "attribute.some.custom-attribute": "test", "some_ignored_attribute": "bar", } session = self.client.session @@ -108,7 +108,7 @@ class TestUserWriteStage(APITestCase): user_qs = User.objects.filter(username=plan.context[PLAN_CONTEXT_PROMPT]["username"]) self.assertTrue(user_qs.exists()) self.assertTrue(user_qs.first().check_password(new_password)) - self.assertEqual(user_qs.first().attributes["some-custom-attribute"], "test") + self.assertEqual(user_qs.first().attributes["some"]["custom-attribute"], "test") self.assertNotIn("some_ignored_attribute", user_qs.first().attributes) @patch(