flows: revert default flow user change

closes #2483

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-03-14 21:59:27 +01:00
parent e03dd70f2f
commit dcaa8d6322
5 changed files with 86 additions and 39 deletions

View file

@ -124,8 +124,6 @@ class FlowPlanner:
) -> FlowPlan: ) -> FlowPlan:
"""Check each of the flows' policies, check policies for each stage with PolicyBinding """Check each of the flows' policies, check policies for each stage with PolicyBinding
and return ordered list""" and return ordered list"""
if not default_context:
default_context = {}
with Hub.current.start_span( with Hub.current.start_span(
op="authentik.flow.planner.plan", description=self.flow.slug op="authentik.flow.planner.plan", description=self.flow.slug
) as span: ) as span:
@ -139,14 +137,16 @@ class FlowPlanner:
# Bit of a workaround here, if there is a pending user set in the default context # Bit of a workaround here, if there is a pending user set in the default context
# we use that user for our cache key # we use that user for our cache key
# to make sure they don't get the generic response # to make sure they don't get the generic response
if PLAN_CONTEXT_PENDING_USER not in default_context: if default_context and PLAN_CONTEXT_PENDING_USER in default_context:
default_context[PLAN_CONTEXT_PENDING_USER] = request.user user = default_context[PLAN_CONTEXT_PENDING_USER]
user = default_context[PLAN_CONTEXT_PENDING_USER] else:
user = request.user
# First off, check the flow's direct policy bindings # First off, check the flow's direct policy bindings
# to make sure the user even has access to the flow # to make sure the user even has access to the flow
engine = PolicyEngine(self.flow, user, request) engine = PolicyEngine(self.flow, user, request)
span.set_data("default_context", cleanse_dict(default_context)) if default_context:
engine.request.context = default_context span.set_data("default_context", cleanse_dict(default_context))
engine.request.context = default_context
engine.build() engine.build()
result = engine.result result = engine.result
if not result.passing: if not result.passing:

View file

@ -1,5 +1,4 @@
"""Authenticator Validation""" """Authenticator Validation"""
from django.contrib.auth.models import AnonymousUser
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django_otp import devices_for_user from django_otp import devices_for_user
from rest_framework.fields import CharField, IntegerField, JSONField, ListField, UUIDField from rest_framework.fields import CharField, IntegerField, JSONField, ListField, UUIDField
@ -280,7 +279,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
def challenge_valid(self, response: AuthenticatorValidationChallengeResponse) -> HttpResponse: def challenge_valid(self, response: AuthenticatorValidationChallengeResponse) -> HttpResponse:
# All validation is done by the serializer # All validation is done by the serializer
user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER) user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
if not user or isinstance(user, AnonymousUser): if not user:
webauthn_device: WebAuthnDevice = response.data.get("webauthn", None) webauthn_device: WebAuthnDevice = response.data.get("webauthn", None)
if not webauthn_device: if not webauthn_device:
return self.executor.stage_ok() return self.executor.stage_ok()

View file

@ -3,7 +3,6 @@ from typing import Any
from django.contrib import messages from django.contrib import messages
from django.contrib.auth import update_session_auth_hash from django.contrib.auth import update_session_auth_hash
from django.contrib.auth.models import AnonymousUser
from django.db import transaction from django.db import transaction
from django.db.utils import IntegrityError from django.db.utils import IntegrityError
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
@ -26,15 +25,16 @@ LOGGER = get_logger()
class UserWriteStageView(StageView): class UserWriteStageView(StageView):
"""Finalise Enrollment flow by creating a user object.""" """Finalise Enrollment flow by creating a user object."""
def write_attribute(self, user: User, key: str, value: Any): @staticmethod
def write_attribute(user: User, key: str, value: Any):
"""Allow use of attributes.foo.bar when writing to a user, with full """Allow use of attributes.foo.bar when writing to a user, with full
recursion""" recursion"""
parts = key.replace("_", ".").split(".") parts = key.replace("_", ".").split(".")
if len(parts) < 1: # pragma: no cover if len(parts) < 1: # pragma: no cover
return return
# Function will always be called with a key like attribute. # Function will always be called with a key like attributes.
# this is just a sanity check to ensure that is removed # this is just a sanity check to ensure that is removed
if parts[0] == "attribute": if parts[0] == "attributes":
parts = parts[1:] parts = parts[1:]
attrs = user.attributes attrs = user.attributes
for comp in parts[:-1]: for comp in parts[:-1]:
@ -57,12 +57,7 @@ class UserWriteStageView(StageView):
return self.executor.stage_invalid() return self.executor.stage_invalid()
data = self.executor.plan.context[PLAN_CONTEXT_PROMPT] data = self.executor.plan.context[PLAN_CONTEXT_PROMPT]
user_created = False user_created = False
# check if pending user is set (default to anonymous user), if if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:
# it's an anonymous user then we need to create a new user.
if isinstance(
self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER, AnonymousUser()),
AnonymousUser,
):
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = User( self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = User(
is_active=not self.executor.current_stage.create_users_as_inactive is_active=not self.executor.current_stage.create_users_as_inactive
) )
@ -90,16 +85,20 @@ class UserWriteStageView(StageView):
setter = getattr(user, setter_name) setter = getattr(user, setter_name)
if callable(setter): if callable(setter):
setter(value) setter(value)
# For exact attributes match, update the dictionary in place
elif key == "attributes":
user.attributes.update(value)
# User has this key already # User has this key already
elif hasattr(user, key): elif hasattr(user, key) and not key.startswith("attributes."):
setattr(user, key, value) setattr(user, key, value)
# Otherwise we just save it as custom attribute, but only if the value is prefixed with # Otherwise we just save it as custom attribute, but only if the value is prefixed with
# `attribute_`, to prevent accidentally saving values # `attribute_`, to prevent accidentally saving values
else: else:
if not key.startswith("attribute.") and not key.startswith("attribute_"): if not key.startswith("attributes.") and not key.startswith("attributes_"):
LOGGER.debug("discarding key", key=key) LOGGER.debug("discarding key", key=key)
continue continue
self.write_attribute(user, key, value) UserWriteStageView.write_attribute(user, key, value)
print(user.attributes)
# Extra check to prevent flows from saving a user with a blank username # Extra check to prevent flows from saving a user with a blank username
if user.username == "": if user.username == "":
LOGGER.warning("Aborting write to empty username", user=user) LOGGER.warning("Aborting write to empty username", user=user)

View file

@ -16,6 +16,7 @@ from authentik.flows.tests.test_executor import TO_STAGE_RESPONSE_MOCK
from authentik.flows.views.executor import SESSION_KEY_PLAN from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
from authentik.stages.user_write.models import UserWriteStage from authentik.stages.user_write.models import UserWriteStage
from authentik.stages.user_write.stage import UserWriteStageView
class TestUserWriteStage(FlowTestCase): class TestUserWriteStage(FlowTestCase):
@ -77,7 +78,7 @@ class TestUserWriteStage(FlowTestCase):
plan.context[PLAN_CONTEXT_PROMPT] = { plan.context[PLAN_CONTEXT_PROMPT] = {
"username": "test-user-new", "username": "test-user-new",
"password": new_password, "password": new_password,
"attribute.some.custom-attribute": "test", "attributes.some.custom-attribute": "test",
"some_ignored_attribute": "bar", "some_ignored_attribute": "bar",
} }
session = self.client.session session = self.client.session
@ -172,3 +173,43 @@ class TestUserWriteStage(FlowTestCase):
self.flow, self.flow,
component="ak-stage-access-denied", component="ak-stage-access-denied",
) )
def test_write_attribute(self):
"""Test write_attribute"""
user = create_test_admin_user()
user.attributes = {
"foo": "bar",
"baz": {
"qwer": [
"quox",
]
},
}
user.save()
UserWriteStageView.write_attribute(user, "attributes.foo", "baz")
self.assertEqual(
user.attributes,
{
"foo": "baz",
"baz": {
"qwer": [
"quox",
]
},
},
)
UserWriteStageView.write_attribute(user, "attributes.foob.bar", "baz")
self.assertEqual(
user.attributes,
{
"foo": "baz",
"foob": {
"bar": "baz",
},
"baz": {
"qwer": [
"quox",
]
},
},
)

View file

@ -101,25 +101,33 @@ export class UserSettingsFlowExecutor extends LitElement implements StageHost {
if (!this.flowSlug) { if (!this.flowSlug) {
return; return;
} }
this.loading = true; new FlowsApi(DEFAULT_CONFIG).flowsInstancesExecuteRetrieve({
new FlowsApi(DEFAULT_CONFIG) slug: this.flowSlug || "",
.flowsExecutorGet({ }).then(() => {
flowSlug: this.flowSlug, this.nextChallenge();
query: window.location.search.substring(1), })
})
.then((challenge) => {
this.challenge = challenge;
})
.catch((e: Error | Response) => {
// Catch JSON or Update errors
this.errorMessage(e);
})
.finally(() => {
this.loading = false;
});
}); });
} }
nextChallenge(): void {
this.loading = true;
new FlowsApi(DEFAULT_CONFIG)
.flowsExecutorGet({
flowSlug: this.flowSlug || "",
query: window.location.search.substring(1),
})
.then((challenge) => {
this.challenge = challenge;
})
.catch((e: Error | Response) => {
// Catch JSON or Update errors
this.errorMessage(e);
})
.finally(() => {
this.loading = false;
});
}
async errorMessage(error: Error | Response): Promise<void> { async errorMessage(error: Error | Response): Promise<void> {
let body = ""; let body = "";
if (error instanceof Error) { if (error instanceof Error) {