flows/importer: fix validate writing to database not being reverted

This commit is contained in:
Jens Langhammer 2020-09-09 17:21:43 +02:00
parent 18f42a0edf
commit 50a5959f6c
1 changed files with 25 additions and 15 deletions

View File

@ -1,4 +1,5 @@
"""Flow importer"""
from contextlib import contextmanager
from copy import deepcopy
from json import loads
from typing import Any, Dict
@ -9,6 +10,7 @@ from django.apps import apps
from django.db import transaction
from django.db.models import Model
from django.db.models.query_utils import Q
from django.db.utils import IntegrityError
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer, Serializer
from structlog import BoundLogger, get_logger
@ -26,6 +28,15 @@ from passbook.stages.prompt.models import Prompt
ALLOWED_MODELS = (Flow, FlowStageBinding, Stage, Policy, PolicyBinding, Prompt)
@contextmanager
def transaction_rollback():
"""Enters an atomic transaction and always triggers a rollback at the end of the block."""
atomic = transaction.atomic()
atomic.__enter__()
yield
atomic.__exit__(IntegrityError, None, None)
class FlowImporter:
"""Import Flow from json"""
@ -46,11 +57,10 @@ class FlowImporter:
def __update_pks_for_attrs(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
"""Replace any value if it is a known primary key of an other object"""
def updater(value) -> Any:
if value in self.__pk_map:
self.logger.debug(
"updating reference in entry", value=value
)
self.logger.debug("updating reference in entry", value=value)
return self.__pk_map[value]
return value
@ -118,14 +128,15 @@ class FlowImporter:
def apply(self) -> bool:
"""Apply (create/update) flow json, in database transaction"""
sid = transaction.savepoint()
successful = self._apply_models()
if not successful:
self.logger.debug("Reverting changes due to error")
transaction.savepoint_rollback(sid)
try:
with transaction.atomic():
if not self._apply_models():
self.logger.debug("Reverting changes due to error")
raise IntegrityError
except IntegrityError:
return False
self.logger.debug("Committing changes")
transaction.savepoint_commit(sid)
else:
self.logger.debug("Committing changes")
return True
def _apply_models(self) -> bool:
@ -154,9 +165,8 @@ class FlowImporter:
if self.__import.version != 1:
self.logger.warning("Invalid bundle version")
return False
sid = transaction.savepoint()
successful = self._apply_models()
if not successful:
self.logger.debug("Flow validation failed")
transaction.savepoint_rollback(sid)
with transaction_rollback():
successful = self._apply_models()
if not successful:
self.logger.debug("Flow validation failed")
return successful