stages/user_write: allow recursive writing to user.attributes
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
57e5acaf2f
commit
b4ee693a5c
|
@ -1,4 +1,5 @@
|
||||||
"""Write stage logic"""
|
"""Write stage logic"""
|
||||||
|
from typing import Any
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
from django.contrib.auth import update_session_auth_hash
|
from django.contrib.auth import update_session_auth_hash
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
|
@ -23,6 +24,23 @@ LOGGER = get_logger()
|
||||||
class UserWriteStageView(StageView):
|
class UserWriteStageView(StageView):
|
||||||
"""Finalise Enrollment flow by creating a user object."""
|
"""Finalise Enrollment flow by creating a user object."""
|
||||||
|
|
||||||
|
def write_attribute(self, user: User, key: str, value: Any):
|
||||||
|
"""Allow use of attributes.foo.bar when writing to a user, with full
|
||||||
|
recursion"""
|
||||||
|
parts = key.replace("_", ".").split(".")
|
||||||
|
if len(parts) < 1:
|
||||||
|
return
|
||||||
|
# Function will always be called with a key like attribute.
|
||||||
|
# this is just a sanity check to ensure that is removed
|
||||||
|
if parts[0] == "attribute":
|
||||||
|
parts = parts[1:]
|
||||||
|
attrs = user.attributes
|
||||||
|
for comp in parts[:-1]:
|
||||||
|
if comp not in attrs:
|
||||||
|
attrs[comp] = {}
|
||||||
|
attrs = attrs.get(comp)
|
||||||
|
attrs[parts[-1]] = value
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
def post(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Wrapper for post requests"""
|
"""Wrapper for post requests"""
|
||||||
return self.get(request)
|
return self.get(request)
|
||||||
|
@ -71,10 +89,10 @@ class UserWriteStageView(StageView):
|
||||||
# Otherwise we just save it as custom attribute, but only if the value is prefixed with
|
# Otherwise we just save it as custom attribute, but only if the value is prefixed with
|
||||||
# `attribute_`, to prevent accidentally saving values
|
# `attribute_`, to prevent accidentally saving values
|
||||||
else:
|
else:
|
||||||
if not key.startswith("attribute_"):
|
if not key.startswith("attribute.") and not key.startswith("attribute_"):
|
||||||
LOGGER.debug("discarding key", key=key)
|
LOGGER.debug("discarding key", key=key)
|
||||||
continue
|
continue
|
||||||
user.attributes[key.replace("attribute_", "", 1)] = value
|
self.write_attribute(user, key, value)
|
||||||
# Extra check to prevent flows from saving a user with a blank username
|
# Extra check to prevent flows from saving a user with a blank username
|
||||||
if user.username == "":
|
if user.username == "":
|
||||||
LOGGER.warning("Aborting write to empty username", user=user)
|
LOGGER.warning("Aborting write to empty username", user=user)
|
||||||
|
|
|
@ -85,7 +85,7 @@ class TestUserWriteStage(APITestCase):
|
||||||
plan.context[PLAN_CONTEXT_PROMPT] = {
|
plan.context[PLAN_CONTEXT_PROMPT] = {
|
||||||
"username": "test-user-new",
|
"username": "test-user-new",
|
||||||
"password": new_password,
|
"password": new_password,
|
||||||
"attribute_some-custom-attribute": "test",
|
"attribute.some.custom-attribute": "test",
|
||||||
"some_ignored_attribute": "bar",
|
"some_ignored_attribute": "bar",
|
||||||
}
|
}
|
||||||
session = self.client.session
|
session = self.client.session
|
||||||
|
@ -108,7 +108,7 @@ class TestUserWriteStage(APITestCase):
|
||||||
user_qs = User.objects.filter(username=plan.context[PLAN_CONTEXT_PROMPT]["username"])
|
user_qs = User.objects.filter(username=plan.context[PLAN_CONTEXT_PROMPT]["username"])
|
||||||
self.assertTrue(user_qs.exists())
|
self.assertTrue(user_qs.exists())
|
||||||
self.assertTrue(user_qs.first().check_password(new_password))
|
self.assertTrue(user_qs.first().check_password(new_password))
|
||||||
self.assertEqual(user_qs.first().attributes["some-custom-attribute"], "test")
|
self.assertEqual(user_qs.first().attributes["some"]["custom-attribute"], "test")
|
||||||
self.assertNotIn("some_ignored_attribute", user_qs.first().attributes)
|
self.assertNotIn("some_ignored_attribute", user_qs.first().attributes)
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
|
|
Reference in New Issue