diff --git a/passbook/flows/transfer/exporter.py b/passbook/flows/transfer/exporter.py index 44ca05b75..9e97f3ac8 100644 --- a/passbook/flows/transfer/exporter.py +++ b/passbook/flows/transfer/exporter.py @@ -1,6 +1,9 @@ """Flow exporter""" from json import dumps -from typing import Iterator +from typing import Iterator, List +from uuid import UUID + +from django.db.models import Q from passbook.flows.models import Flow, FlowStageBinding, Stage from passbook.flows.transfer.common import DataclassEncoder, FlowBundle, FlowBundleEntry @@ -15,11 +18,24 @@ class FlowExporter: with_policies: bool with_stage_prompts: bool + pbm_uuids: List[UUID] + def __init__(self, flow: Flow): self.flow = flow self.with_policies = True self.with_stage_prompts = True + def _prepare_pbm(self): + self.pbm_uuids = [self.flow.pbm_uuid] + for stage_subclass in Stage.__subclasses__(): + if issubclass(stage_subclass, PolicyBindingModel): + self.pbm_uuids += stage_subclass.objects.filter( + flow=self.flow + ).values_list("pbm_uuid", flat=True) + self.pbm_uuids += FlowStageBinding.objects.filter(target=self.flow).values_list( + "pbm_uuid", flat=True + ) + def walk_stages(self) -> Iterator[FlowBundleEntry]: """Convert all stages attached to self.flow into FlowBundleEntry objects""" stages = ( @@ -37,21 +53,24 @@ class FlowExporter: yield FlowBundleEntry.from_model(binding, "target", "stage", "order") def walk_policies(self) -> Iterator[FlowBundleEntry]: - """Walk over all policies and their respective bindings""" - pbm_uuids = [self.flow.pbm_uuid] - for stage_subclass in Stage.__subclasses__(): - if issubclass(stage_subclass, PolicyBindingModel): - pbm_uuids += stage_subclass.objects.filter(flow=self.flow).values_list( - "pbm_uuid", flat=True - ) - pbm_uuids += FlowStageBinding.objects.filter(target=self.flow).values_list( - "pbm_uuid", flat=True + """Walk over all policies. This is done at the beginning of the export for stages that have + a direct foreign key to a policy.""" + # Special case for PromptStage as that has a direct M2M to policy, we have to ensure + # all policies referenced in there we also include here + prompt_stages = PromptStage.objects.filter(flow=self.flow).values_list( + "pk", flat=True ) - # Add policy objects first, so they are created first - policies = Policy.objects.filter(bindings__in=pbm_uuids).select_related() + query = Q(bindings__in=self.pbm_uuids) | Q(promptstage__in=prompt_stages) + policies = Policy.objects.filter(query).select_related() for policy in policies: yield FlowBundleEntry.from_model(policy) - bindings = PolicyBinding.objects.filter(target__in=pbm_uuids).select_related() + + def walk_policy_bindings(self) -> Iterator[FlowBundleEntry]: + """Walk over all policybindings relative to us. This is run at the end of the export, as + we are sure all objects exist now.""" + bindings = PolicyBinding.objects.filter( + target__in=self.pbm_uuids + ).select_related() for binding in bindings: yield FlowBundleEntry.from_model(binding, "policy", "target", "order") @@ -64,14 +83,18 @@ class FlowExporter: def export(self) -> FlowBundle: """Create a list of all objects including the flow""" + if self.with_policies: + self._prepare_pbm() bundle = FlowBundle() bundle.entries.append(FlowBundleEntry.from_model(self.flow, "slug")) if self.with_stage_prompts: bundle.entries.extend(self.walk_stage_prompts()) + if self.with_policies: + bundle.entries.extend(self.walk_policies()) bundle.entries.extend(self.walk_stages()) bundle.entries.extend(self.walk_stage_bindings()) if self.with_policies: - bundle.entries.extend(self.walk_policies()) + bundle.entries.extend(self.walk_policy_bindings()) return bundle def export_to_string(self) -> str: