stages/user_write: allow recursive writing to user.attributes

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-10-07 18:57:19 +02:00
parent 57e5acaf2f
commit b4ee693a5c
2 changed files with 22 additions and 4 deletions

View File

@ -1,4 +1,5 @@
"""Write stage logic""" """Write stage logic"""
from typing import Any
from django.contrib import messages from django.contrib import messages
from django.contrib.auth import update_session_auth_hash from django.contrib.auth import update_session_auth_hash
from django.db import transaction from django.db import transaction
@ -23,6 +24,23 @@ LOGGER = get_logger()
class UserWriteStageView(StageView): class UserWriteStageView(StageView):
"""Finalise Enrollment flow by creating a user object.""" """Finalise Enrollment flow by creating a user object."""
def write_attribute(self, user: User, key: str, value: Any):
"""Allow use of attributes.foo.bar when writing to a user, with full
recursion"""
parts = key.replace("_", ".").split(".")
if len(parts) < 1:
return
# Function will always be called with a key like attribute.
# this is just a sanity check to ensure that is removed
if parts[0] == "attribute":
parts = parts[1:]
attrs = user.attributes
for comp in parts[:-1]:
if comp not in attrs:
attrs[comp] = {}
attrs = attrs.get(comp)
attrs[parts[-1]] = value
def post(self, request: HttpRequest) -> HttpResponse: def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests""" """Wrapper for post requests"""
return self.get(request) return self.get(request)
@ -71,10 +89,10 @@ class UserWriteStageView(StageView):
# Otherwise we just save it as custom attribute, but only if the value is prefixed with # Otherwise we just save it as custom attribute, but only if the value is prefixed with
# `attribute_`, to prevent accidentally saving values # `attribute_`, to prevent accidentally saving values
else: else:
if not key.startswith("attribute_"): if not key.startswith("attribute.") and not key.startswith("attribute_"):
LOGGER.debug("discarding key", key=key) LOGGER.debug("discarding key", key=key)
continue continue
user.attributes[key.replace("attribute_", "", 1)] = value self.write_attribute(user, key, value)
# Extra check to prevent flows from saving a user with a blank username # Extra check to prevent flows from saving a user with a blank username
if user.username == "": if user.username == "":
LOGGER.warning("Aborting write to empty username", user=user) LOGGER.warning("Aborting write to empty username", user=user)

View File

@ -85,7 +85,7 @@ class TestUserWriteStage(APITestCase):
plan.context[PLAN_CONTEXT_PROMPT] = { plan.context[PLAN_CONTEXT_PROMPT] = {
"username": "test-user-new", "username": "test-user-new",
"password": new_password, "password": new_password,
"attribute_some-custom-attribute": "test", "attribute.some.custom-attribute": "test",
"some_ignored_attribute": "bar", "some_ignored_attribute": "bar",
} }
session = self.client.session session = self.client.session
@ -108,7 +108,7 @@ class TestUserWriteStage(APITestCase):
user_qs = User.objects.filter(username=plan.context[PLAN_CONTEXT_PROMPT]["username"]) user_qs = User.objects.filter(username=plan.context[PLAN_CONTEXT_PROMPT]["username"])
self.assertTrue(user_qs.exists()) self.assertTrue(user_qs.exists())
self.assertTrue(user_qs.first().check_password(new_password)) self.assertTrue(user_qs.first().check_password(new_password))
self.assertEqual(user_qs.first().attributes["some-custom-attribute"], "test") self.assertEqual(user_qs.first().attributes["some"]["custom-attribute"], "test")
self.assertNotIn("some_ignored_attribute", user_qs.first().attributes) self.assertNotIn("some_ignored_attribute", user_qs.first().attributes)
@patch( @patch(