From 50a5959f6cc2567ebd5eeac1e4a47df53e652a5e Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Wed, 9 Sep 2020 17:21:43 +0200 Subject: [PATCH] flows/importer: fix validate writing to database not being reverted --- passbook/flows/transfer/importer.py | 40 ++++++++++++++++++----------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/passbook/flows/transfer/importer.py b/passbook/flows/transfer/importer.py index 9c03023d9..44911d65d 100644 --- a/passbook/flows/transfer/importer.py +++ b/passbook/flows/transfer/importer.py @@ -1,4 +1,5 @@ """Flow importer""" +from contextlib import contextmanager from copy import deepcopy from json import loads from typing import Any, Dict @@ -9,6 +10,7 @@ from django.apps import apps from django.db import transaction from django.db.models import Model from django.db.models.query_utils import Q +from django.db.utils import IntegrityError from rest_framework.exceptions import ValidationError from rest_framework.serializers import BaseSerializer, Serializer from structlog import BoundLogger, get_logger @@ -26,6 +28,15 @@ from passbook.stages.prompt.models import Prompt ALLOWED_MODELS = (Flow, FlowStageBinding, Stage, Policy, PolicyBinding, Prompt) +@contextmanager +def transaction_rollback(): + """Enters an atomic transaction and always triggers a rollback at the end of the block.""" + atomic = transaction.atomic() + atomic.__enter__() + yield + atomic.__exit__(IntegrityError, None, None) + + class FlowImporter: """Import Flow from json""" @@ -46,11 +57,10 @@ class FlowImporter: def __update_pks_for_attrs(self, attrs: Dict[str, Any]) -> Dict[str, Any]: """Replace any value if it is a known primary key of an other object""" + def updater(value) -> Any: if value in self.__pk_map: - self.logger.debug( - "updating reference in entry", value=value - ) + self.logger.debug("updating reference in entry", value=value) return self.__pk_map[value] return value @@ -118,14 +128,15 @@ class FlowImporter: def apply(self) -> bool: """Apply (create/update) flow json, in database transaction""" - sid = transaction.savepoint() - successful = self._apply_models() - if not successful: - self.logger.debug("Reverting changes due to error") - transaction.savepoint_rollback(sid) + try: + with transaction.atomic(): + if not self._apply_models(): + self.logger.debug("Reverting changes due to error") + raise IntegrityError + except IntegrityError: return False - self.logger.debug("Committing changes") - transaction.savepoint_commit(sid) + else: + self.logger.debug("Committing changes") return True def _apply_models(self) -> bool: @@ -154,9 +165,8 @@ class FlowImporter: if self.__import.version != 1: self.logger.warning("Invalid bundle version") return False - sid = transaction.savepoint() - successful = self._apply_models() - if not successful: - self.logger.debug("Flow validation failed") - transaction.savepoint_rollback(sid) + with transaction_rollback(): + successful = self._apply_models() + if not successful: + self.logger.debug("Flow validation failed") return successful