diff --git a/authentik/flows/planner.py b/authentik/flows/planner.py index d1417363f..2f3442914 100644 --- a/authentik/flows/planner.py +++ b/authentik/flows/planner.py @@ -124,8 +124,6 @@ class FlowPlanner: ) -> FlowPlan: """Check each of the flows' policies, check policies for each stage with PolicyBinding and return ordered list""" - if not default_context: - default_context = {} with Hub.current.start_span( op="authentik.flow.planner.plan", description=self.flow.slug ) as span: @@ -139,14 +137,16 @@ class FlowPlanner: # Bit of a workaround here, if there is a pending user set in the default context # we use that user for our cache key # to make sure they don't get the generic response - if PLAN_CONTEXT_PENDING_USER not in default_context: - default_context[PLAN_CONTEXT_PENDING_USER] = request.user - user = default_context[PLAN_CONTEXT_PENDING_USER] + if default_context and PLAN_CONTEXT_PENDING_USER in default_context: + user = default_context[PLAN_CONTEXT_PENDING_USER] + else: + user = request.user # First off, check the flow's direct policy bindings # to make sure the user even has access to the flow engine = PolicyEngine(self.flow, user, request) - span.set_data("default_context", cleanse_dict(default_context)) - engine.request.context = default_context + if default_context: + span.set_data("default_context", cleanse_dict(default_context)) + engine.request.context = default_context engine.build() result = engine.result if not result.passing: diff --git a/authentik/stages/authenticator_validate/stage.py b/authentik/stages/authenticator_validate/stage.py index 40294be04..2308fcbb0 100644 --- a/authentik/stages/authenticator_validate/stage.py +++ b/authentik/stages/authenticator_validate/stage.py @@ -1,5 +1,4 @@ """Authenticator Validation""" -from django.contrib.auth.models import AnonymousUser from django.http import HttpRequest, HttpResponse from django_otp import devices_for_user from rest_framework.fields import CharField, IntegerField, JSONField, ListField, UUIDField @@ -280,7 +279,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): def challenge_valid(self, response: AuthenticatorValidationChallengeResponse) -> HttpResponse: # All validation is done by the serializer user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER) - if not user or isinstance(user, AnonymousUser): + if not user: webauthn_device: WebAuthnDevice = response.data.get("webauthn", None) if not webauthn_device: return self.executor.stage_ok() diff --git a/authentik/stages/user_write/stage.py b/authentik/stages/user_write/stage.py index 51d774085..369478edf 100644 --- a/authentik/stages/user_write/stage.py +++ b/authentik/stages/user_write/stage.py @@ -3,7 +3,6 @@ from typing import Any from django.contrib import messages from django.contrib.auth import update_session_auth_hash -from django.contrib.auth.models import AnonymousUser from django.db import transaction from django.db.utils import IntegrityError from django.http import HttpRequest, HttpResponse @@ -26,15 +25,16 @@ LOGGER = get_logger() class UserWriteStageView(StageView): """Finalise Enrollment flow by creating a user object.""" - def write_attribute(self, user: User, key: str, value: Any): + @staticmethod + def write_attribute(user: User, key: str, value: Any): """Allow use of attributes.foo.bar when writing to a user, with full recursion""" parts = key.replace("_", ".").split(".") if len(parts) < 1: # pragma: no cover return - # Function will always be called with a key like attribute. + # Function will always be called with a key like attributes. # this is just a sanity check to ensure that is removed - if parts[0] == "attribute": + if parts[0] == "attributes": parts = parts[1:] attrs = user.attributes for comp in parts[:-1]: @@ -57,12 +57,7 @@ class UserWriteStageView(StageView): return self.executor.stage_invalid() data = self.executor.plan.context[PLAN_CONTEXT_PROMPT] user_created = False - # check if pending user is set (default to anonymous user), if - # it's an anonymous user then we need to create a new user. - if isinstance( - self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER, AnonymousUser()), - AnonymousUser, - ): + if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context: self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = User( is_active=not self.executor.current_stage.create_users_as_inactive ) @@ -90,16 +85,20 @@ class UserWriteStageView(StageView): setter = getattr(user, setter_name) if callable(setter): setter(value) + # For exact attributes match, update the dictionary in place + elif key == "attributes": + user.attributes.update(value) # User has this key already - elif hasattr(user, key): + elif hasattr(user, key) and not key.startswith("attributes."): setattr(user, key, value) # Otherwise we just save it as custom attribute, but only if the value is prefixed with # `attribute_`, to prevent accidentally saving values else: - if not key.startswith("attribute.") and not key.startswith("attribute_"): + if not key.startswith("attributes.") and not key.startswith("attributes_"): LOGGER.debug("discarding key", key=key) continue - self.write_attribute(user, key, value) + UserWriteStageView.write_attribute(user, key, value) + print(user.attributes) # Extra check to prevent flows from saving a user with a blank username if user.username == "": LOGGER.warning("Aborting write to empty username", user=user) diff --git a/authentik/stages/user_write/tests.py b/authentik/stages/user_write/tests.py index 7c3abcedb..349362833 100644 --- a/authentik/stages/user_write/tests.py +++ b/authentik/stages/user_write/tests.py @@ -16,6 +16,7 @@ from authentik.flows.tests.test_executor import TO_STAGE_RESPONSE_MOCK from authentik.flows.views.executor import SESSION_KEY_PLAN from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT from authentik.stages.user_write.models import UserWriteStage +from authentik.stages.user_write.stage import UserWriteStageView class TestUserWriteStage(FlowTestCase): @@ -77,7 +78,7 @@ class TestUserWriteStage(FlowTestCase): plan.context[PLAN_CONTEXT_PROMPT] = { "username": "test-user-new", "password": new_password, - "attribute.some.custom-attribute": "test", + "attributes.some.custom-attribute": "test", "some_ignored_attribute": "bar", } session = self.client.session @@ -172,3 +173,43 @@ class TestUserWriteStage(FlowTestCase): self.flow, component="ak-stage-access-denied", ) + + def test_write_attribute(self): + """Test write_attribute""" + user = create_test_admin_user() + user.attributes = { + "foo": "bar", + "baz": { + "qwer": [ + "quox", + ] + }, + } + user.save() + UserWriteStageView.write_attribute(user, "attributes.foo", "baz") + self.assertEqual( + user.attributes, + { + "foo": "baz", + "baz": { + "qwer": [ + "quox", + ] + }, + }, + ) + UserWriteStageView.write_attribute(user, "attributes.foob.bar", "baz") + self.assertEqual( + user.attributes, + { + "foo": "baz", + "foob": { + "bar": "baz", + }, + "baz": { + "qwer": [ + "quox", + ] + }, + }, + ) diff --git a/web/src/user/user-settings/details/UserSettingsFlowExecutor.ts b/web/src/user/user-settings/details/UserSettingsFlowExecutor.ts index 838c9a91a..a9e9f1cca 100644 --- a/web/src/user/user-settings/details/UserSettingsFlowExecutor.ts +++ b/web/src/user/user-settings/details/UserSettingsFlowExecutor.ts @@ -101,25 +101,33 @@ export class UserSettingsFlowExecutor extends LitElement implements StageHost { if (!this.flowSlug) { return; } - this.loading = true; - new FlowsApi(DEFAULT_CONFIG) - .flowsExecutorGet({ - flowSlug: this.flowSlug, - query: window.location.search.substring(1), - }) - .then((challenge) => { - this.challenge = challenge; - }) - .catch((e: Error | Response) => { - // Catch JSON or Update errors - this.errorMessage(e); - }) - .finally(() => { - this.loading = false; - }); + new FlowsApi(DEFAULT_CONFIG).flowsInstancesExecuteRetrieve({ + slug: this.flowSlug || "", + }).then(() => { + this.nextChallenge(); + }) }); } + nextChallenge(): void { + this.loading = true; + new FlowsApi(DEFAULT_CONFIG) + .flowsExecutorGet({ + flowSlug: this.flowSlug || "", + query: window.location.search.substring(1), + }) + .then((challenge) => { + this.challenge = challenge; + }) + .catch((e: Error | Response) => { + // Catch JSON or Update errors + this.errorMessage(e); + }) + .finally(() => { + this.loading = false; + }); + } + async errorMessage(error: Error | Response): Promise { let body = ""; if (error instanceof Error) {