flows/importer: fix validate writing to database not being reverted

This commit is contained in:
Jens Langhammer 2020-09-09 17:21:43 +02:00
parent 18f42a0edf
commit 50a5959f6c
1 changed files with 25 additions and 15 deletions

View File

@ -1,4 +1,5 @@
"""Flow importer""" """Flow importer"""
from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from json import loads from json import loads
from typing import Any, Dict from typing import Any, Dict
@ -9,6 +10,7 @@ from django.apps import apps
from django.db import transaction from django.db import transaction
from django.db.models import Model from django.db.models import Model
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.db.utils import IntegrityError
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer, Serializer from rest_framework.serializers import BaseSerializer, Serializer
from structlog import BoundLogger, get_logger from structlog import BoundLogger, get_logger
@ -26,6 +28,15 @@ from passbook.stages.prompt.models import Prompt
ALLOWED_MODELS = (Flow, FlowStageBinding, Stage, Policy, PolicyBinding, Prompt) ALLOWED_MODELS = (Flow, FlowStageBinding, Stage, Policy, PolicyBinding, Prompt)
@contextmanager
def transaction_rollback():
"""Enters an atomic transaction and always triggers a rollback at the end of the block."""
atomic = transaction.atomic()
atomic.__enter__()
yield
atomic.__exit__(IntegrityError, None, None)
class FlowImporter: class FlowImporter:
"""Import Flow from json""" """Import Flow from json"""
@ -46,11 +57,10 @@ class FlowImporter:
def __update_pks_for_attrs(self, attrs: Dict[str, Any]) -> Dict[str, Any]: def __update_pks_for_attrs(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
"""Replace any value if it is a known primary key of an other object""" """Replace any value if it is a known primary key of an other object"""
def updater(value) -> Any: def updater(value) -> Any:
if value in self.__pk_map: if value in self.__pk_map:
self.logger.debug( self.logger.debug("updating reference in entry", value=value)
"updating reference in entry", value=value
)
return self.__pk_map[value] return self.__pk_map[value]
return value return value
@ -118,14 +128,15 @@ class FlowImporter:
def apply(self) -> bool: def apply(self) -> bool:
"""Apply (create/update) flow json, in database transaction""" """Apply (create/update) flow json, in database transaction"""
sid = transaction.savepoint() try:
successful = self._apply_models() with transaction.atomic():
if not successful: if not self._apply_models():
self.logger.debug("Reverting changes due to error") self.logger.debug("Reverting changes due to error")
transaction.savepoint_rollback(sid) raise IntegrityError
except IntegrityError:
return False return False
self.logger.debug("Committing changes") else:
transaction.savepoint_commit(sid) self.logger.debug("Committing changes")
return True return True
def _apply_models(self) -> bool: def _apply_models(self) -> bool:
@ -154,9 +165,8 @@ class FlowImporter:
if self.__import.version != 1: if self.__import.version != 1:
self.logger.warning("Invalid bundle version") self.logger.warning("Invalid bundle version")
return False return False
sid = transaction.savepoint() with transaction_rollback():
successful = self._apply_models() successful = self._apply_models()
if not successful: if not successful:
self.logger.debug("Flow validation failed") self.logger.debug("Flow validation failed")
transaction.savepoint_rollback(sid)
return successful return successful