diff --git a/passbook/core/signals.py b/passbook/core/signals.py index 01299f90e..74b6b49f1 100644 --- a/passbook/core/signals.py +++ b/passbook/core/signals.py @@ -1,31 +1,7 @@ """passbook core signals""" -from django.core.cache import cache from django.core.signals import Signal -from django.db.models.signals import post_save -from django.dispatch import receiver -from structlog import get_logger - -LOGGER = get_logger() user_signed_up = Signal(providing_args=["request", "user"]) invitation_created = Signal(providing_args=["request", "invitation"]) invitation_used = Signal(providing_args=["request", "invitation", "user"]) password_changed = Signal(providing_args=["user", "password"]) - - -@receiver(post_save) -# pylint: disable=unused-argument -def invalidate_policy_cache(sender, instance, **_): - """Invalidate Policy cache when policy is updated""" - from passbook.policies.models import Policy, PolicyBinding - from passbook.policies.process import cache_key - - if isinstance(instance, Policy): - LOGGER.debug("Invalidating policy cache", policy=instance) - total = 0 - for binding in PolicyBinding.objects.filter(policy=instance): - prefix = cache_key(binding) + "*" - keys = cache.keys(prefix) - total += len(keys) - cache.delete_many(keys) - LOGGER.debug("Deleted keys", len=total) diff --git a/passbook/crypto/migrations/0002_create_self_signed_kp.py b/passbook/crypto/migrations/0002_create_self_signed_kp.py index 645db1dfe..66239b816 100644 --- a/passbook/crypto/migrations/0002_create_self_signed_kp.py +++ b/passbook/crypto/migrations/0002_create_self_signed_kp.py @@ -20,9 +20,7 @@ def create_self_signed(apps, schema_editor): class Migration(migrations.Migration): dependencies = [ - ('passbook_crypto', '0001_initial'), + ("passbook_crypto", "0001_initial"), ] - operations = [ - migrations.RunPython(create_self_signed) - ] + operations = [migrations.RunPython(create_self_signed)] diff --git a/passbook/policies/apps.py b/passbook/policies/apps.py index 5795355b6..946f84609 100644 --- a/passbook/policies/apps.py +++ b/passbook/policies/apps.py @@ -1,4 +1,6 @@ """passbook policies app config""" +from importlib import import_module + from django.apps import AppConfig @@ -8,3 +10,7 @@ class PassbookPoliciesConfig(AppConfig): name = "passbook.policies" label = "passbook_policies" verbose_name = "passbook Policies" + + def ready(self): + """Load source_types from config file""" + import_module("passbook.policies.signals") diff --git a/passbook/policies/engine.py b/passbook/policies/engine.py index c465b5ae5..34dc6d54d 100644 --- a/passbook/policies/engine.py +++ b/passbook/policies/engine.py @@ -73,17 +73,16 @@ class PolicyEngine: """Build task group""" for binding in self._iter_bindings(): self._check_policy_type(binding.policy) - policy = binding.policy - key = cache_key(binding, self.request.user) + key = cache_key(binding, self.request) cached_policy = cache.get(key, None) if cached_policy and self.use_cache: - LOGGER.debug("P_ENG: Taking result from cache", policy=policy, cache_key=key) + LOGGER.debug("P_ENG: Taking result from cache", policy=binding.policy, cache_key=key) self.__cached_policies.append(cached_policy) continue - LOGGER.debug("P_ENG: Evaluating policy", policy=policy) + LOGGER.debug("P_ENG: Evaluating policy", policy=binding.policy) our_end, task_end = Pipe(False) task = PolicyProcess(binding, self.request, task_end) - LOGGER.debug("P_ENG: Starting Process", policy=policy) + LOGGER.debug("P_ENG: Starting Process", policy=binding.policy) task.start() self.__processes.append( PolicyProcessInfo(process=task, connection=our_end, binding=binding) diff --git a/passbook/policies/process.py b/passbook/policies/process.py index 1fb906c9f..3d5ce2525 100644 --- a/passbook/policies/process.py +++ b/passbook/policies/process.py @@ -14,11 +14,13 @@ from passbook.policies.types import PolicyRequest, PolicyResult LOGGER = get_logger() -def cache_key(binding: PolicyBinding, user: Optional[User] = None) -> str: +def cache_key(binding: PolicyBinding, request: PolicyRequest) -> str: """Generate Cache key for policy""" prefix = f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}" - if user: - prefix += f"#{user.pk}" + if request.http_request: + prefix += f"_{request.http_request.session.session_key}" + if request.user: + prefix += f"#{request.user.pk}" return prefix @@ -65,7 +67,7 @@ class PolicyProcess(Process): passing=policy_result.passing, user=self.request.user, ) - key = cache_key(self.binding, self.request.user) + key = cache_key(self.binding, self.request) cache.set(key, policy_result) LOGGER.debug("P_ENG(proc): Cached policy evaluation", key=key) return policy_result diff --git a/passbook/policies/signals.py b/passbook/policies/signals.py new file mode 100644 index 000000000..0f1b1a6d6 --- /dev/null +++ b/passbook/policies/signals.py @@ -0,0 +1,25 @@ +"""passbook policy signals""" +from django.core.cache import cache +from django.db.models.signals import post_save +from django.dispatch import receiver +from structlog import get_logger + +LOGGER = get_logger() + + +@receiver(post_save) +# pylint: disable=unused-argument +def invalidate_policy_cache(sender, instance, **_): + """Invalidate Policy cache when policy is updated""" + from passbook.policies.models import Policy, PolicyBinding + from passbook.policies.process import cache_key + + if isinstance(instance, Policy): + LOGGER.debug("Invalidating policy cache", policy=instance) + total = 0 + for binding in PolicyBinding.objects.filter(policy=instance): + prefix = f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}" + "*" + keys = cache.keys(prefix) + total += len(keys) + cache.delete_many(keys) + LOGGER.debug("Deleted keys", len=total) diff --git a/passbook/stages/user_write/tests.py b/passbook/stages/user_write/tests.py index 5bad06809..d37012207 100644 --- a/passbook/stages/user_write/tests.py +++ b/passbook/stages/user_write/tests.py @@ -72,6 +72,7 @@ class TestUserWriteStage(TestCase): plan.context[PLAN_CONTEXT_PROMPT] = { "username": "test-user-new", "password": new_password, + "some-custom-attribute": "test", } session = self.client.session session[SESSION_KEY_PLAN] = plan @@ -88,6 +89,7 @@ class TestUserWriteStage(TestCase): ) self.assertTrue(user_qs.exists()) self.assertTrue(user_qs.first().check_password(new_password)) + self.assertEqual(user_qs.first().attributes["some-custom-attribute"], "test") def test_without_data(self): """Test without data results in error"""