diff --git a/ereuse_devicehub/resources/lot/views.py b/ereuse_devicehub/resources/lot/views.py index b99e622f..117d3209 100644 --- a/ereuse_devicehub/resources/lot/views.py +++ b/ereuse_devicehub/resources/lot/views.py @@ -1,20 +1,22 @@ import uuid -from sqlalchemy.util import OrderedSet from collections import deque from enum import Enum from typing import Dict, List, Set, Union import marshmallow as ma -from flask import Response, jsonify, request, g -from marshmallow import Schema as MarshmallowSchema, fields as f +from flask import Response, g, jsonify, request +from marshmallow import Schema as MarshmallowSchema +from marshmallow import fields as f from sqlalchemy import or_ +from sqlalchemy.util import OrderedSet from teal.marshmallow import EnumField from teal.resource import View from ereuse_devicehub.db import db +from ereuse_devicehub.inventory.models import Transfer from ereuse_devicehub.query import things_response -from ereuse_devicehub.resources.device.models import Device, Computer -from ereuse_devicehub.resources.action.models import Trade, Confirm, Revoke +from ereuse_devicehub.resources.action.models import Confirm, Revoke, Trade +from ereuse_devicehub.resources.device.models import Computer, Device from ereuse_devicehub.resources.lot.models import Lot, Path @@ -27,6 +29,7 @@ class LotView(View): """Allowed arguments for the ``find`` method (GET collection) endpoint """ + format = EnumField(LotFormat, missing=None) search = f.Str(missing=None) type = f.Str(missing=None) @@ -42,12 +45,26 @@ class LotView(View): return ret def patch(self, id): - patch_schema = self.resource_def.SCHEMA(only=( - 'name', 'description', 'transfer_state', 'receiver_address', 'amount', 'devices', - 'owner_address'), partial=True) + patch_schema = self.resource_def.SCHEMA( + only=( + 'name', + 'description', + 'transfer_state', + 'receiver_address', + 'amount', + 'devices', + 'owner_address', + ), + partial=True, + ) l = request.get_json(schema=patch_schema) lot = Lot.query.filter_by(id=id).one() - device_fields = ['transfer_state', 'receiver_address', 'amount', 'owner_address'] + device_fields = [ + 'transfer_state', + 'receiver_address', + 'amount', + 'owner_address', + ] computers = [x for x in lot.all_devices if isinstance(x, Computer)] for key, value in l.items(): setattr(lot, key, value) @@ -84,7 +101,7 @@ class LotView(View): ret = { 'items': {l['id']: l for l in lots}, 'tree': self.ui_tree(), - 'url': request.path + 'url': request.path, } else: query = Lot.query @@ -95,15 +112,28 @@ class LotView(View): lots = query.paginate(per_page=6 if args['search'] else query.count()) return things_response( self.schema.dump(lots.items, many=True, nested=2), - lots.page, lots.per_page, lots.total, lots.prev_num, lots.next_num + lots.page, + lots.per_page, + lots.total, + lots.prev_num, + lots.next_num, ) return jsonify(ret) def visibility_filter(self, query): - query = query.outerjoin(Trade) \ - .filter(or_(Trade.user_from == g.user, - Trade.user_to == g.user, - Lot.owner_id == g.user.id)) + query = ( + query.outerjoin(Trade) + .outerjoin(Transfer) + .filter( + or_( + Trade.user_from == g.user, + Trade.user_to == g.user, + Lot.owner_id == g.user.id, + Transfer.user_from == g.user, + Transfer.user_to == g.user, + ) + ) + ) return query def type_filter(self, query, args): @@ -111,13 +141,23 @@ class LotView(View): # temporary if lot_type == "temporary": - return query.filter(Lot.trade == None) + return query.filter(Lot.trade == None).filter(Lot.transfer == None) if lot_type == "incoming": - return query.filter(Lot.trade and Trade.user_to == g.user) + return query.filter( + or_( + Lot.trade and Trade.user_to == g.user, + Lot.transfer and Transfer.user_to == g.user, + ) + ).all() if lot_type == "outgoing": - return query.filter(Lot.trade and Trade.user_from == g.user) + return query.filter( + or_( + Lot.trade and Trade.user_from == g.user, + Lot.transfer and Transfer.user_from == g.user, + ) + ).all() return query @@ -152,10 +192,7 @@ class LotView(View): # does lot_id exist already in node? node = next(part for part in nodes if lot_id == part['id']) except StopIteration: - node = { - 'id': lot_id, - 'nodes': [] - } + node = {'id': lot_id, 'nodes': []} nodes.append(node) if path: cls._p(node['nodes'], path) @@ -175,15 +212,17 @@ class LotView(View): class LotBaseChildrenView(View): """Base class for adding / removing children devices and - lots from a lot. - """ + lots from a lot. + """ def __init__(self, definition: 'Resource', **kw) -> None: super().__init__(definition, **kw) self.list_args = self.ListArgs() def get_ids(self) -> Set[uuid.UUID]: - args = self.QUERY_PARSER.parse(self.list_args, request, locations=('querystring',)) + args = self.QUERY_PARSER.parse( + self.list_args, request, locations=('querystring',) + ) return set(args['id']) def get_lot(self, id: uuid.UUID) -> Lot: @@ -247,8 +286,9 @@ class LotDeviceView(LotBaseChildrenView): if not ids: return - devices = set(Device.query.filter(Device.id.in_(ids)).filter( - Device.owner == g.user)) + devices = set( + Device.query.filter(Device.id.in_(ids)).filter(Device.owner == g.user) + ) lot.devices.update(devices) @@ -271,8 +311,9 @@ class LotDeviceView(LotBaseChildrenView): txt = 'This is not your lot' raise ma.ValidationError(txt) - devices = set(Device.query.filter(Device.id.in_(ids)).filter( - Device.owner_id == g.user.id)) + devices = set( + Device.query.filter(Device.id.in_(ids)).filter(Device.owner_id == g.user.id) + ) lot.devices.difference_update(devices) @@ -311,9 +352,7 @@ def delete_from_trade(lot: Lot, devices: List): phantom = lot.trade.user_from phantom_revoke = Revoke( - action=lot.trade, - user=phantom, - devices=set(without_confirms) + action=lot.trade, user=phantom, devices=set(without_confirms) ) db.session.add(phantom_revoke)