stages/user_write: allow recursive writing to user.attributes
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
57e5acaf2f
commit
b4ee693a5c
|
@ -1,4 +1,5 @@
|
|||
"""Write stage logic"""
|
||||
from typing import Any
|
||||
from django.contrib import messages
|
||||
from django.contrib.auth import update_session_auth_hash
|
||||
from django.db import transaction
|
||||
|
@ -23,6 +24,23 @@ LOGGER = get_logger()
|
|||
class UserWriteStageView(StageView):
|
||||
"""Finalise Enrollment flow by creating a user object."""
|
||||
|
||||
def write_attribute(self, user: User, key: str, value: Any):
|
||||
"""Allow use of attributes.foo.bar when writing to a user, with full
|
||||
recursion"""
|
||||
parts = key.replace("_", ".").split(".")
|
||||
if len(parts) < 1:
|
||||
return
|
||||
# Function will always be called with a key like attribute.
|
||||
# this is just a sanity check to ensure that is removed
|
||||
if parts[0] == "attribute":
|
||||
parts = parts[1:]
|
||||
attrs = user.attributes
|
||||
for comp in parts[:-1]:
|
||||
if comp not in attrs:
|
||||
attrs[comp] = {}
|
||||
attrs = attrs.get(comp)
|
||||
attrs[parts[-1]] = value
|
||||
|
||||
def post(self, request: HttpRequest) -> HttpResponse:
|
||||
"""Wrapper for post requests"""
|
||||
return self.get(request)
|
||||
|
@ -71,10 +89,10 @@ class UserWriteStageView(StageView):
|
|||
# Otherwise we just save it as custom attribute, but only if the value is prefixed with
|
||||
# `attribute_`, to prevent accidentally saving values
|
||||
else:
|
||||
if not key.startswith("attribute_"):
|
||||
if not key.startswith("attribute.") and not key.startswith("attribute_"):
|
||||
LOGGER.debug("discarding key", key=key)
|
||||
continue
|
||||
user.attributes[key.replace("attribute_", "", 1)] = value
|
||||
self.write_attribute(user, key, value)
|
||||
# Extra check to prevent flows from saving a user with a blank username
|
||||
if user.username == "":
|
||||
LOGGER.warning("Aborting write to empty username", user=user)
|
||||
|
|
|
@ -85,7 +85,7 @@ class TestUserWriteStage(APITestCase):
|
|||
plan.context[PLAN_CONTEXT_PROMPT] = {
|
||||
"username": "test-user-new",
|
||||
"password": new_password,
|
||||
"attribute_some-custom-attribute": "test",
|
||||
"attribute.some.custom-attribute": "test",
|
||||
"some_ignored_attribute": "bar",
|
||||
}
|
||||
session = self.client.session
|
||||
|
@ -108,7 +108,7 @@ class TestUserWriteStage(APITestCase):
|
|||
user_qs = User.objects.filter(username=plan.context[PLAN_CONTEXT_PROMPT]["username"])
|
||||
self.assertTrue(user_qs.exists())
|
||||
self.assertTrue(user_qs.first().check_password(new_password))
|
||||
self.assertEqual(user_qs.first().attributes["some-custom-attribute"], "test")
|
||||
self.assertEqual(user_qs.first().attributes["some"]["custom-attribute"], "test")
|
||||
self.assertNotIn("some_ignored_attribute", user_qs.first().attributes)
|
||||
|
||||
@patch(
|
||||
|
|
Reference in a new issue