flows: fix exporting and importing for models with multiple unique fields
This commit is contained in:
parent
268de20872
commit
dd017e7190
|
@ -0,0 +1,111 @@
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"entries": [
|
||||||
|
{
|
||||||
|
"identifiers": {
|
||||||
|
"slug": "default-authentication-flow",
|
||||||
|
"pk": "563ece21-e9a4-47e5-a264-23ffd923e393"
|
||||||
|
},
|
||||||
|
"model": "passbook_flows.flow",
|
||||||
|
"attrs": {
|
||||||
|
"name": "Default Authentication Flow",
|
||||||
|
"title": "Welcome to passbook!",
|
||||||
|
"designation": "authentication"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identifiers": {
|
||||||
|
"pk": "69d41125-3987-499b-8d74-ef27b54b88c8",
|
||||||
|
"name": "default-authentication-login"
|
||||||
|
},
|
||||||
|
"model": "passbook_stages_user_login.userloginstage",
|
||||||
|
"attrs": {
|
||||||
|
"session_duration": 0
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identifiers": {
|
||||||
|
"pk": "5f594f27-0def-488d-9855-fe604eb13de5",
|
||||||
|
"name": "default-authentication-identification"
|
||||||
|
},
|
||||||
|
"model": "passbook_stages_identification.identificationstage",
|
||||||
|
"attrs": {
|
||||||
|
"user_fields": [
|
||||||
|
"email",
|
||||||
|
"username"
|
||||||
|
],
|
||||||
|
"template": "stages/identification/login.html",
|
||||||
|
"enrollment_flow": null,
|
||||||
|
"recovery_flow": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identifiers": {
|
||||||
|
"pk": "37f709c3-8817-45e8-9a93-80a925d293c2",
|
||||||
|
"name": "default-authentication-flow-totp"
|
||||||
|
},
|
||||||
|
"model": "passbook_stages_otp_validate.otpvalidatestage",
|
||||||
|
"attrs": {}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identifiers": {
|
||||||
|
"pk": "d8affa62-500c-4c5c-a01f-5835e1ffdf40",
|
||||||
|
"name": "default-authentication-password"
|
||||||
|
},
|
||||||
|
"model": "passbook_stages_password.passwordstage",
|
||||||
|
"attrs": {
|
||||||
|
"backends": [
|
||||||
|
"django.contrib.auth.backends.ModelBackend"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identifiers": {
|
||||||
|
"pk": "a3056482-b692-4e3a-93f1-7351c6a351c7",
|
||||||
|
"target": "563ece21-e9a4-47e5-a264-23ffd923e393",
|
||||||
|
"stage": "5f594f27-0def-488d-9855-fe604eb13de5",
|
||||||
|
"order": 0
|
||||||
|
},
|
||||||
|
"model": "passbook_flows.flowstagebinding",
|
||||||
|
"attrs": {
|
||||||
|
"re_evaluate_policies": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identifiers": {
|
||||||
|
"pk": "4e8538cf-3e18-4a68-82ae-6df6725fa2e6",
|
||||||
|
"target": "563ece21-e9a4-47e5-a264-23ffd923e393",
|
||||||
|
"stage": "d8affa62-500c-4c5c-a01f-5835e1ffdf40",
|
||||||
|
"order": 1
|
||||||
|
},
|
||||||
|
"model": "passbook_flows.flowstagebinding",
|
||||||
|
"attrs": {
|
||||||
|
"re_evaluate_policies": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identifiers": {
|
||||||
|
"pk": "688aec6f-5622-42c6-83a5-d22072d7e798",
|
||||||
|
"target": "563ece21-e9a4-47e5-a264-23ffd923e393",
|
||||||
|
"stage": "37f709c3-8817-45e8-9a93-80a925d293c2",
|
||||||
|
"order": 2
|
||||||
|
},
|
||||||
|
"model": "passbook_flows.flowstagebinding",
|
||||||
|
"attrs": {
|
||||||
|
"re_evaluate_policies": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identifiers": {
|
||||||
|
"pk": "f3fede3a-a9b5-4232-9ec7-be7ff4194b27",
|
||||||
|
"target": "563ece21-e9a4-47e5-a264-23ffd923e393",
|
||||||
|
"stage": "69d41125-3987-499b-8d74-ef27b54b88c8",
|
||||||
|
"order": 3
|
||||||
|
},
|
||||||
|
"model": "passbook_flows.flowstagebinding",
|
||||||
|
"attrs": {
|
||||||
|
"re_evaluate_policies": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
|
@ -39,7 +39,7 @@ class FlowForm(forms.ModelForm):
|
||||||
class FlowStageBindingForm(forms.ModelForm):
|
class FlowStageBindingForm(forms.ModelForm):
|
||||||
"""FlowStageBinding Form"""
|
"""FlowStageBinding Form"""
|
||||||
|
|
||||||
stage = GroupedModelChoiceField(queryset=Stage.objects.all().select_subclasses(),)
|
stage = GroupedModelChoiceField(queryset=Stage.objects.all().select_subclasses())
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
"""Apply flow from commandline"""
|
||||||
|
from django.core.management.base import BaseCommand, no_translations
|
||||||
|
|
||||||
|
from passbook.flows.transfer.importer import FlowImporter
|
||||||
|
|
||||||
|
|
||||||
|
class Command(BaseCommand):
|
||||||
|
"""Apply flow from commandline"""
|
||||||
|
|
||||||
|
@no_translations
|
||||||
|
def handle(self, *args, **options):
|
||||||
|
"""Apply all flows in order, abort when one fails to import"""
|
||||||
|
for flow_path in options.get("flows", []):
|
||||||
|
with open(flow_path, "r") as flow_file:
|
||||||
|
importer = FlowImporter(flow_file.read())
|
||||||
|
valid = importer.validate()
|
||||||
|
if not valid:
|
||||||
|
raise ValueError("Flow invalid")
|
||||||
|
importer.apply()
|
||||||
|
|
||||||
|
def add_arguments(self, parser):
|
||||||
|
parser.add_argument("flows", nargs="+", type=str)
|
|
@ -6,13 +6,13 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||||
|
|
||||||
def add_title_for_defaults(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
def add_title_for_defaults(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
slug_title_map = {
|
slug_title_map = {
|
||||||
"default-authentication-flow": "Default Authentication Flow",
|
"default-authentication-flow": "Welcome to passbook!",
|
||||||
"default-invalidation-flow": "Default Invalidation Flow",
|
"default-invalidation-flow": "Default Invalidation Flow",
|
||||||
"default-source-enrollment": "Default Source Enrollment Flow",
|
"default-source-enrollment": "Welcome to passbook!",
|
||||||
"default-source-authentication": "Default Source Authentication Flow",
|
"default-source-authentication": "Welcome to passbook!",
|
||||||
"default-provider-authorization-implicit-consent": "Default Provider Authorization Flow (implicit consent)",
|
"default-provider-authorization-implicit-consent": "Default Provider Authorization Flow (implicit consent)",
|
||||||
"default-provider-authorization-explicit-consent": "Default Provider Authorization Flow (explicit consent)",
|
"default-provider-authorization-explicit-consent": "Default Provider Authorization Flow (explicit consent)",
|
||||||
"default-password-change": "Default Password Change Flow",
|
"default-password-change": "Change password",
|
||||||
}
|
}
|
||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
Flow = apps.get_model("passbook_flows", "Flow")
|
Flow = apps.get_model("passbook_flows", "Flow")
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Generated by Django 3.1.1 on 2020-09-05 21:42
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("passbook_flows", "0012_auto_20200830_1056"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="stage", name="name", field=models.TextField(unique=True),
|
||||||
|
),
|
||||||
|
]
|
|
@ -46,7 +46,7 @@ class Stage(SerializerModel):
|
||||||
|
|
||||||
stage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
stage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||||
|
|
||||||
name = models.TextField()
|
name = models.TextField(unique=True)
|
||||||
|
|
||||||
objects = InheritanceManager()
|
objects = InheritanceManager()
|
||||||
|
|
||||||
|
@ -170,7 +170,7 @@ class FlowStageBinding(SerializerModel, PolicyBindingModel):
|
||||||
return FlowStageBindingSerializer
|
return FlowStageBindingSerializer
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"Flow Binding {self.target} -> {self.stage}"
|
return f"'{self.target}' -> '{self.stage}' # {self.order}"
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Test flow transfer"""
|
"""Test flow transfer"""
|
||||||
from json import dumps
|
from json import dumps
|
||||||
|
|
||||||
|
from django.db import transaction
|
||||||
from django.test import TransactionTestCase
|
from django.test import TransactionTestCase
|
||||||
|
|
||||||
from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding
|
from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding
|
||||||
|
@ -21,12 +22,13 @@ class TestFlowTransfer(TransactionTestCase):
|
||||||
importer = FlowImporter('{"version": 3}')
|
importer = FlowImporter('{"version": 3}')
|
||||||
self.assertFalse(importer.validate())
|
self.assertFalse(importer.validate())
|
||||||
importer = FlowImporter(
|
importer = FlowImporter(
|
||||||
'{"version": 1,"entries":[{"identifier":"","attrs":{},"model": "passbook_core.User"}]}'
|
'{"version": 1,"entries":[{"identifiers":{},"attrs":{},"model": "passbook_core.User"}]}'
|
||||||
)
|
)
|
||||||
self.assertFalse(importer.validate())
|
self.assertFalse(importer.validate())
|
||||||
|
|
||||||
def test_export_validate_import(self):
|
def test_export_validate_import(self):
|
||||||
"""Test export and validate it"""
|
"""Test export and validate it"""
|
||||||
|
sid = transaction.savepoint()
|
||||||
login_stage = UserLoginStage.objects.create(name="default-authentication-login")
|
login_stage = UserLoginStage.objects.create(name="default-authentication-login")
|
||||||
|
|
||||||
flow = Flow.objects.create(
|
flow = Flow.objects.create(
|
||||||
|
@ -40,42 +42,55 @@ class TestFlowTransfer(TransactionTestCase):
|
||||||
|
|
||||||
exporter = FlowExporter(flow)
|
exporter = FlowExporter(flow)
|
||||||
export = exporter.export()
|
export = exporter.export()
|
||||||
|
|
||||||
|
transaction.savepoint_rollback(sid)
|
||||||
|
|
||||||
self.assertEqual(len(export.entries), 3)
|
self.assertEqual(len(export.entries), 3)
|
||||||
export_json = dumps(export, cls=DataclassEncoder)
|
export_json = dumps(export, cls=DataclassEncoder)
|
||||||
importer = FlowImporter(export_json)
|
importer = FlowImporter(export_json)
|
||||||
self.assertTrue(importer.validate())
|
self.assertTrue(importer.validate())
|
||||||
flow.delete()
|
|
||||||
login_stage.delete()
|
|
||||||
self.assertTrue(importer.apply())
|
self.assertTrue(importer.apply())
|
||||||
|
|
||||||
self.assertTrue(Flow.objects.filter(slug="test").exists())
|
self.assertTrue(Flow.objects.filter(slug="test").exists())
|
||||||
|
|
||||||
def test_export_validate_import_policies(self):
|
def test_export_validate_import_policies(self):
|
||||||
"""Test export and validate it"""
|
"""Test export and validate it"""
|
||||||
|
sid = transaction.savepoint()
|
||||||
|
|
||||||
flow_policy = ExpressionPolicy.objects.create(
|
flow_policy = ExpressionPolicy.objects.create(
|
||||||
name="default-source-authentication-if-sso", expression="return True",
|
name="default-source-authentication-if-sso", expression="return True",
|
||||||
)
|
)
|
||||||
flow = Flow.objects.create(
|
flow = Flow.objects.create(
|
||||||
slug="default-source-authentication",
|
slug="default-source-authentication-test",
|
||||||
designation=FlowDesignation.AUTHENTICATION,
|
designation=FlowDesignation.AUTHENTICATION,
|
||||||
name="Welcome to passbook!",
|
name="Welcome to passbook!",
|
||||||
)
|
)
|
||||||
PolicyBinding.objects.create(policy=flow_policy, target=flow, order=0)
|
PolicyBinding.objects.create(policy=flow_policy, target=flow, order=0)
|
||||||
|
|
||||||
user_login = UserLoginStage.objects.create(
|
user_login = UserLoginStage.objects.create(
|
||||||
name="default-source-authentication-login"
|
name="default-source-authentication-login-test"
|
||||||
)
|
)
|
||||||
FlowStageBinding.objects.create(target=flow, stage=user_login, order=0)
|
FlowStageBinding.objects.create(target=flow, stage=user_login, order=0)
|
||||||
|
|
||||||
exporter = FlowExporter(flow)
|
exporter = FlowExporter(flow)
|
||||||
export = exporter.export()
|
export = exporter.export()
|
||||||
|
|
||||||
|
transaction.savepoint_rollback(sid)
|
||||||
|
|
||||||
export_json = dumps(export, cls=DataclassEncoder)
|
export_json = dumps(export, cls=DataclassEncoder)
|
||||||
importer = FlowImporter(export_json)
|
importer = FlowImporter(export_json)
|
||||||
self.assertTrue(importer.validate())
|
self.assertTrue(importer.validate())
|
||||||
self.assertTrue(importer.apply())
|
self.assertTrue(importer.apply())
|
||||||
|
self.assertTrue(
|
||||||
|
UserLoginStage.objects.filter(
|
||||||
|
name="default-source-authentication-login-test"
|
||||||
|
).exists()
|
||||||
|
)
|
||||||
|
|
||||||
def test_export_validate_import_prompt(self):
|
def test_export_validate_import_prompt(self):
|
||||||
"""Test export and validate it"""
|
"""Test export and validate it"""
|
||||||
|
sid = transaction.savepoint()
|
||||||
|
|
||||||
# First stage fields
|
# First stage fields
|
||||||
username_prompt = Prompt.objects.create(
|
username_prompt = Prompt.objects.create(
|
||||||
field_key="username", label="Username", order=0, type=FieldTypes.TEXT
|
field_key="username", label="Username", order=0, type=FieldTypes.TEXT
|
||||||
|
@ -90,13 +105,13 @@ class TestFlowTransfer(TransactionTestCase):
|
||||||
type=FieldTypes.PASSWORD,
|
type=FieldTypes.PASSWORD,
|
||||||
)
|
)
|
||||||
# Stages
|
# Stages
|
||||||
first_stage = PromptStage.objects.create(name="prompt-stage-first")
|
first_stage = PromptStage.objects.create(name="prompt-stage-first-test")
|
||||||
first_stage.fields.set([username_prompt, password, password_repeat])
|
first_stage.fields.set([username_prompt, password, password_repeat])
|
||||||
first_stage.save()
|
first_stage.save()
|
||||||
|
|
||||||
# Password checking policy
|
# Password checking policy
|
||||||
password_policy = ExpressionPolicy.objects.create(
|
password_policy = ExpressionPolicy.objects.create(
|
||||||
name="policy-enrollment-password-equals",
|
name="policy-enrollment-password-equals-test",
|
||||||
expression="return request.context['password'] == request.context['password_repeat']",
|
expression="return request.context['password'] == request.context['password_repeat']",
|
||||||
)
|
)
|
||||||
PolicyBinding.objects.create(
|
PolicyBinding.objects.create(
|
||||||
|
@ -105,7 +120,7 @@ class TestFlowTransfer(TransactionTestCase):
|
||||||
|
|
||||||
flow = Flow.objects.create(
|
flow = Flow.objects.create(
|
||||||
name="default-enrollment-flow",
|
name="default-enrollment-flow",
|
||||||
slug="default-enrollment-flow",
|
slug="default-enrollment-flow-test",
|
||||||
designation=FlowDesignation.ENROLLMENT,
|
designation=FlowDesignation.ENROLLMENT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -114,6 +129,10 @@ class TestFlowTransfer(TransactionTestCase):
|
||||||
exporter = FlowExporter(flow)
|
exporter = FlowExporter(flow)
|
||||||
export = exporter.export()
|
export = exporter.export()
|
||||||
export_json = dumps(export, cls=DataclassEncoder)
|
export_json = dumps(export, cls=DataclassEncoder)
|
||||||
|
|
||||||
|
transaction.savepoint_rollback(sid)
|
||||||
|
|
||||||
importer = FlowImporter(export_json)
|
importer = FlowImporter(export_json)
|
||||||
|
|
||||||
self.assertTrue(importer.validate())
|
self.assertTrue(importer.validate())
|
||||||
self.assertTrue(importer.apply())
|
self.assertTrue(importer.apply())
|
||||||
|
|
|
@ -11,10 +11,10 @@ from passbook.lib.sentry import SentryIgnoredException
|
||||||
def get_attrs(obj: SerializerModel) -> Dict[str, Any]:
|
def get_attrs(obj: SerializerModel) -> Dict[str, Any]:
|
||||||
"""Get object's attributes via their serializer, and covert it to a normal dict"""
|
"""Get object's attributes via their serializer, and covert it to a normal dict"""
|
||||||
data = dict(obj.serializer(obj).data)
|
data = dict(obj.serializer(obj).data)
|
||||||
if "policies" in data:
|
to_remove = ("policies", "stages", "pk")
|
||||||
data.pop("policies")
|
for to_remove_name in to_remove:
|
||||||
if "stages" in data:
|
if to_remove_name in data:
|
||||||
data.pop("stages")
|
data.pop(to_remove_name)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,17 +22,26 @@ def get_attrs(obj: SerializerModel) -> Dict[str, Any]:
|
||||||
class FlowBundleEntry:
|
class FlowBundleEntry:
|
||||||
"""Single entry of a bundle"""
|
"""Single entry of a bundle"""
|
||||||
|
|
||||||
identifier: str
|
identifiers: Dict[str, Any]
|
||||||
model: str
|
model: str
|
||||||
attrs: Dict[str, Any]
|
attrs: Dict[str, Any]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model(model: SerializerModel) -> "FlowBundleEntry":
|
def from_model(
|
||||||
|
model: SerializerModel, *extra_identifier_names: str
|
||||||
|
) -> "FlowBundleEntry":
|
||||||
"""Convert a SerializerModel instance to a Bundle Entry"""
|
"""Convert a SerializerModel instance to a Bundle Entry"""
|
||||||
|
identifiers = {
|
||||||
|
"pk": model.pk,
|
||||||
|
}
|
||||||
|
all_attrs = get_attrs(model)
|
||||||
|
|
||||||
|
for extra_identifier_name in extra_identifier_names:
|
||||||
|
identifiers[extra_identifier_name] = all_attrs.pop(extra_identifier_name)
|
||||||
return FlowBundleEntry(
|
return FlowBundleEntry(
|
||||||
identifier=model.pk,
|
identifiers=identifiers,
|
||||||
model=f"{model._meta.app_label}.{model._meta.model_name}",
|
model=f"{model._meta.app_label}.{model._meta.model_name}",
|
||||||
attrs=get_attrs(model),
|
attrs=all_attrs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -28,13 +28,13 @@ class FlowExporter:
|
||||||
for stage in stages:
|
for stage in stages:
|
||||||
if isinstance(stage, PromptStage):
|
if isinstance(stage, PromptStage):
|
||||||
pass
|
pass
|
||||||
yield FlowBundleEntry.from_model(stage)
|
yield FlowBundleEntry.from_model(stage, "name")
|
||||||
|
|
||||||
def walk_stage_bindings(self) -> Iterator[FlowBundleEntry]:
|
def walk_stage_bindings(self) -> Iterator[FlowBundleEntry]:
|
||||||
"""Convert all bindings attached to self.flow into FlowBundleEntry objects"""
|
"""Convert all bindings attached to self.flow into FlowBundleEntry objects"""
|
||||||
bindings = FlowStageBinding.objects.filter(target=self.flow).select_related()
|
bindings = FlowStageBinding.objects.filter(target=self.flow).select_related()
|
||||||
for binding in bindings:
|
for binding in bindings:
|
||||||
yield FlowBundleEntry.from_model(binding)
|
yield FlowBundleEntry.from_model(binding, "target", "stage", "order")
|
||||||
|
|
||||||
def walk_policies(self) -> Iterator[FlowBundleEntry]:
|
def walk_policies(self) -> Iterator[FlowBundleEntry]:
|
||||||
"""Walk over all policies and their respective bindings"""
|
"""Walk over all policies and their respective bindings"""
|
||||||
|
@ -64,7 +64,7 @@ class FlowExporter:
|
||||||
def export(self) -> FlowBundle:
|
def export(self) -> FlowBundle:
|
||||||
"""Create a list of all objects including the flow"""
|
"""Create a list of all objects including the flow"""
|
||||||
bundle = FlowBundle()
|
bundle = FlowBundle()
|
||||||
bundle.entries.append(FlowBundleEntry.from_model(self.flow))
|
bundle.entries.append(FlowBundleEntry.from_model(self.flow, "slug"))
|
||||||
if self.with_stage_prompts:
|
if self.with_stage_prompts:
|
||||||
bundle.entries.extend(self.walk_stage_prompts())
|
bundle.entries.extend(self.walk_stage_prompts())
|
||||||
bundle.entries.extend(self.walk_stages())
|
bundle.entries.extend(self.walk_stages())
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
"""Flow importer"""
|
"""Flow importer"""
|
||||||
from json import loads
|
from json import loads
|
||||||
from typing import Type
|
from typing import Any, Dict
|
||||||
|
|
||||||
from dacite import from_dict
|
from dacite import from_dict
|
||||||
from dacite.exceptions import DaciteError
|
from dacite.exceptions import DaciteError
|
||||||
from django.apps import apps
|
from django.apps import apps
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
|
from django.db.models.query_utils import Q
|
||||||
|
from rest_framework.exceptions import ValidationError
|
||||||
from rest_framework.serializers import BaseSerializer, Serializer
|
from rest_framework.serializers import BaseSerializer, Serializer
|
||||||
from structlog import BoundLogger, get_logger
|
from structlog import BoundLogger, get_logger
|
||||||
|
|
||||||
|
@ -17,7 +19,7 @@ from passbook.flows.transfer.common import (
|
||||||
FlowBundleEntry,
|
FlowBundleEntry,
|
||||||
)
|
)
|
||||||
from passbook.lib.models import SerializerModel
|
from passbook.lib.models import SerializerModel
|
||||||
from passbook.policies.models import Policy, PolicyBinding, PolicyBindingModel
|
from passbook.policies.models import Policy, PolicyBinding
|
||||||
from passbook.stages.prompt.models import Prompt
|
from passbook.stages.prompt.models import Prompt
|
||||||
|
|
||||||
ALLOWED_MODELS = (Flow, FlowStageBinding, Stage, Policy, PolicyBinding, Prompt)
|
ALLOWED_MODELS = (Flow, FlowStageBinding, Stage, Policy, PolicyBinding, Prompt)
|
||||||
|
@ -28,49 +30,42 @@ class FlowImporter:
|
||||||
|
|
||||||
__import: FlowBundle
|
__import: FlowBundle
|
||||||
|
|
||||||
|
__pk_map: Dict[Any, Model]
|
||||||
|
|
||||||
logger: BoundLogger
|
logger: BoundLogger
|
||||||
|
|
||||||
def __init__(self, json_input: str):
|
def __init__(self, json_input: str):
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
|
self.__pk_map = {}
|
||||||
import_dict = loads(json_input)
|
import_dict = loads(json_input)
|
||||||
try:
|
try:
|
||||||
self.__import = from_dict(FlowBundle, import_dict)
|
self.__import = from_dict(FlowBundle, import_dict)
|
||||||
except DaciteError as exc:
|
except DaciteError as exc:
|
||||||
raise EntryInvalidError from exc
|
raise EntryInvalidError from exc
|
||||||
|
|
||||||
def validate(self) -> bool:
|
def __update_pks_for_attrs(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Validate loaded flow export, ensure all models are allowed
|
"""Replace any value if it is a known primary key of an other object"""
|
||||||
and serializers have no errors"""
|
for key, value in attrs.items():
|
||||||
if self.__import.version != 1:
|
if isinstance(value, (list, dict)):
|
||||||
self.logger.warning("Invalid bundle version")
|
|
||||||
return False
|
|
||||||
for entry in self.__import.entries:
|
|
||||||
try:
|
|
||||||
self._validate_single(entry)
|
|
||||||
except EntryInvalidError as exc:
|
|
||||||
self.logger.warning(exc)
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __get_pk_filed(self, model_class: Type[Model]) -> str:
|
|
||||||
fields = model_class._meta.get_fields()
|
|
||||||
pks = []
|
|
||||||
for field in fields:
|
|
||||||
# Ignore base PK from pbm as that isn't the same pk we exported
|
|
||||||
if field.model in [PolicyBindingModel]:
|
|
||||||
continue
|
continue
|
||||||
# Ignore primary keys with _ptr suffix as those are surrogate and not what we exported
|
if value in self.__pk_map:
|
||||||
if field.name.endswith("_ptr"):
|
attrs[key] = self.__pk_map[value]
|
||||||
continue
|
|
||||||
if hasattr(field, "primary_key"):
|
|
||||||
if field.primary_key:
|
|
||||||
pks.append(field.name)
|
|
||||||
if len(pks) > 1:
|
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
"Found more than one fields with primary_key=True, using pk", pks=pks
|
"updating reference in entry", key=key, new_value=attrs[key]
|
||||||
)
|
)
|
||||||
return "pk"
|
return attrs
|
||||||
return pks[0]
|
|
||||||
|
def __query_from_identifier(self, attrs: Dict[str, Any]) -> Q:
|
||||||
|
"""Generate an or'd query from all identifiers in an entry"""
|
||||||
|
# Since identifiers can also be pk-references to other objects (see FlowStageBinding)
|
||||||
|
# we have to ensure those references are also replaced
|
||||||
|
main_query = Q(pk=attrs["pk"])
|
||||||
|
sub_query = Q()
|
||||||
|
for identifier, value in attrs.items():
|
||||||
|
if identifier == "pk":
|
||||||
|
continue
|
||||||
|
sub_query &= Q(**{identifier: value})
|
||||||
|
return main_query | sub_query
|
||||||
|
|
||||||
def _validate_single(self, entry: FlowBundleEntry) -> BaseSerializer:
|
def _validate_single(self, entry: FlowBundleEntry) -> BaseSerializer:
|
||||||
"""Validate a single entry"""
|
"""Validate a single entry"""
|
||||||
|
@ -82,43 +77,53 @@ class FlowImporter:
|
||||||
# If we try to validate without referencing a possible instance
|
# If we try to validate without referencing a possible instance
|
||||||
# we'll get a duplicate error, hence we load the model here and return
|
# we'll get a duplicate error, hence we load the model here and return
|
||||||
# the full serializer for later usage
|
# the full serializer for later usage
|
||||||
existing_models = model.objects.filter(pk=entry.identifier)
|
# Because a model might have multiple unique columns, we chain all identifiers together
|
||||||
serializer_kwargs = {"data": entry.attrs}
|
# to create an OR query.
|
||||||
if existing_models.exists():
|
updated_identifiers = self.__update_pks_for_attrs(entry.identifiers)
|
||||||
self.logger.debug(
|
existing_models = model.objects.filter(
|
||||||
"initialise serializer with instance", instance=existing_models.first()
|
self.__query_from_identifier(updated_identifiers)
|
||||||
)
|
)
|
||||||
serializer_kwargs["instance"] = existing_models.first()
|
|
||||||
|
serializer_kwargs = {}
|
||||||
|
if existing_models.exists():
|
||||||
|
model_instance = existing_models.first()
|
||||||
|
self.logger.debug(
|
||||||
|
"initialise serializer with instance",
|
||||||
|
model=model,
|
||||||
|
instance=model_instance,
|
||||||
|
pk=model_instance.pk,
|
||||||
|
)
|
||||||
|
serializer_kwargs["instance"] = model_instance
|
||||||
else:
|
else:
|
||||||
self.logger.debug("initialise new instance", pk=entry.identifier)
|
self.logger.debug(
|
||||||
|
"initialise new instance", model=model, **updated_identifiers
|
||||||
|
)
|
||||||
|
full_data = self.__update_pks_for_attrs(entry.attrs)
|
||||||
|
full_data.update(updated_identifiers)
|
||||||
|
serializer_kwargs["data"] = full_data
|
||||||
|
|
||||||
serializer: Serializer = model().serializer(**serializer_kwargs)
|
serializer: Serializer = model().serializer(**serializer_kwargs)
|
||||||
is_valid = serializer.is_valid()
|
try:
|
||||||
if not is_valid:
|
serializer.is_valid(raise_exception=True)
|
||||||
raise EntryInvalidError(f"Serializer errors {serializer.errors}")
|
except ValidationError as exc:
|
||||||
if not existing_models.exists():
|
raise EntryInvalidError(f"Serializer errors {serializer.errors}") from exc
|
||||||
# only insert the PK if we're creating a new model, otherwise we get
|
|
||||||
# an integrity error
|
|
||||||
model_pk = self.__get_pk_filed(model)
|
|
||||||
serializer.validated_data[model_pk] = entry.identifier
|
|
||||||
return serializer
|
return serializer
|
||||||
|
|
||||||
def apply(self) -> bool:
|
def apply(self) -> bool:
|
||||||
"""Apply (create/update) flow json, in database transaction"""
|
"""Apply (create/update) flow json, in database transaction"""
|
||||||
transaction.set_autocommit(False)
|
sid = transaction.savepoint()
|
||||||
successful = self._apply_models()
|
successful = self._apply_models()
|
||||||
if not successful:
|
if not successful:
|
||||||
self.logger.debug("Reverting changes due to error")
|
self.logger.debug("Reverting changes due to error")
|
||||||
transaction.rollback()
|
transaction.savepoint_rollback(sid)
|
||||||
transaction.set_autocommit(True)
|
|
||||||
return False
|
return False
|
||||||
self.logger.debug("Committing changes")
|
self.logger.debug("Committing changes")
|
||||||
transaction.commit()
|
transaction.savepoint_commit(sid)
|
||||||
transaction.set_autocommit(True)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _apply_models(self) -> bool:
|
def _apply_models(self) -> bool:
|
||||||
"""Apply (create/update) flow json"""
|
"""Apply (create/update) flow json"""
|
||||||
|
self.__pk_map = {}
|
||||||
for entry in self.__import.entries:
|
for entry in self.__import.entries:
|
||||||
model_app_label, model_name = entry.model.split(".")
|
model_app_label, model_name = entry.model.split(".")
|
||||||
model: SerializerModel = apps.get_model(model_app_label, model_name)
|
model: SerializerModel = apps.get_model(model_app_label, model_name)
|
||||||
|
@ -130,5 +135,20 @@ class FlowImporter:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
model = serializer.save()
|
model = serializer.save()
|
||||||
|
self.__pk_map[entry.identifiers["pk"]] = model.pk
|
||||||
self.logger.debug("updated model", model=model, pk=model.pk)
|
self.logger.debug("updated model", model=model, pk=model.pk)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def validate(self) -> bool:
|
||||||
|
"""Validate loaded flow export, ensure all models are allowed
|
||||||
|
and serializers have no errors"""
|
||||||
|
self.logger.debug("Starting flow import validaton")
|
||||||
|
if self.__import.version != 1:
|
||||||
|
self.logger.warning("Invalid bundle version")
|
||||||
|
return False
|
||||||
|
sid = transaction.savepoint()
|
||||||
|
successful = self._apply_models()
|
||||||
|
if not successful:
|
||||||
|
self.logger.debug("Flow validation failed")
|
||||||
|
transaction.savepoint_rollback(sid)
|
||||||
|
return successful
|
||||||
|
|
|
@ -54,5 +54,5 @@ def _send_update(outpost_model: Model):
|
||||||
for outpost in outpost_model.outpost_set.all():
|
for outpost in outpost_model.outpost_set.all():
|
||||||
channel_layer = get_channel_layer()
|
channel_layer = get_channel_layer()
|
||||||
for channel in outpost.channels:
|
for channel in outpost.channels:
|
||||||
print(f"sending update to channel {channel}")
|
LOGGER.debug("sending update", channel=channel)
|
||||||
async_to_sync(channel_layer.send)(channel, {"type": "event.update"})
|
async_to_sync(channel_layer.send)(channel, {"type": "event.update"})
|
||||||
|
|
|
@ -376,6 +376,7 @@ _LOGGING_HANDLER_MAP = {
|
||||||
"docker": "WARNING",
|
"docker": "WARNING",
|
||||||
"urllib3": "WARNING",
|
"urllib3": "WARNING",
|
||||||
"websockets": "WARNING",
|
"websockets": "WARNING",
|
||||||
|
"daphne": "WARNING",
|
||||||
}
|
}
|
||||||
for handler_name, level in _LOGGING_HANDLER_MAP.items():
|
for handler_name, level in _LOGGING_HANDLER_MAP.items():
|
||||||
# pyright: reportGeneralTypeIssues=false
|
# pyright: reportGeneralTypeIssues=false
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""Prompt Stage API Views"""
|
"""Prompt Stage API Views"""
|
||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import CharField, ModelSerializer
|
||||||
|
from rest_framework.validators import UniqueValidator
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from passbook.stages.prompt.models import Prompt, PromptStage
|
from passbook.stages.prompt.models import Prompt, PromptStage
|
||||||
|
@ -8,6 +9,8 @@ from passbook.stages.prompt.models import Prompt, PromptStage
|
||||||
class PromptStageSerializer(ModelSerializer):
|
class PromptStageSerializer(ModelSerializer):
|
||||||
"""PromptStage Serializer"""
|
"""PromptStage Serializer"""
|
||||||
|
|
||||||
|
name = CharField(validators=[UniqueValidator(queryset=PromptStage.objects.all())])
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
model = PromptStage
|
model = PromptStage
|
||||||
|
|
Reference in New Issue