diff --git a/ereuse_devicehub/devicehub.py b/ereuse_devicehub/devicehub.py index 915fcde1..ccc80960 100644 --- a/ereuse_devicehub/devicehub.py +++ b/ereuse_devicehub/devicehub.py @@ -1,6 +1,7 @@ from typing import Type from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import event from teal.config import Config as ConfigClass from teal.teal import Teal @@ -8,6 +9,7 @@ from ereuse_devicehub.auth import Auth from ereuse_devicehub.client import Client from ereuse_devicehub.db import db from ereuse_devicehub.dummy.dummy import Dummy +from ereuse_devicehub.resources.device.search import DeviceSearch class Devicehub(Teal): @@ -32,3 +34,13 @@ class Devicehub(Teal): host_matching, subdomain_matching, template_folder, instance_path, instance_relative_config, root_path, Auth) self.dummy = Dummy(self) + self.before_request(self.register_db_events_listeners) + + def register_db_events_listeners(self): + """Registers the SQLAlchemy event listeners.""" + # todo can I make it with a global Session only? + event.listen(db.session, 'before_commit', DeviceSearch.update_modified_devices) + + def _init_db(self): + super()._init_db() + DeviceSearch.set_all_devices_tokens_if_empty(self.db.session) diff --git a/ereuse_devicehub/dummy/dummy.py b/ereuse_devicehub/dummy/dummy.py index 4d871864..88024ae0 100644 --- a/ereuse_devicehub/dummy/dummy.py +++ b/ereuse_devicehub/dummy/dummy.py @@ -101,6 +101,11 @@ class Dummy: inventory, _ = user.get(res=Inventory) assert len(inventory['devices']) assert len(inventory['lots']) + + i, _ = user.get(res=Inventory, query=[('search', 'intel')]) + assert len(i['devices']) == 10 + i, _ = user.get(res=Inventory, query=[('search', 'pc')]) + assert len(i['devices']) == 11 print('⭐ Done.') def user_client(self, email: str, password: str): diff --git a/ereuse_devicehub/resources/device/search.py b/ereuse_devicehub/resources/device/search.py new file mode 100644 index 00000000..3970c7c8 --- /dev/null +++ b/ereuse_devicehub/resources/device/search.py @@ -0,0 +1,122 @@ +import inflection +from sqlalchemy.dialects import postgresql +from sqlalchemy.dialects.postgresql import TSVECTOR +from sqlalchemy.orm import aliased + +from ereuse_devicehub.db import db +from ereuse_devicehub.resources import search +from ereuse_devicehub.resources.agent.models import Organization +from ereuse_devicehub.resources.device.models import Component, Computer, Device +from ereuse_devicehub.resources.event.models import Event, EventWithMultipleDevices, \ + EventWithOneDevice +from ereuse_devicehub.resources.tag.model import Tag + + +class DeviceSearch(db.Model): + """Temporary table that stores full-text device documents. + + It provides methods to auto-run + """ + device_id = db.Column(db.BigInteger, + db.ForeignKey(Device.id, ondelete='CASCADE'), + primary_key=True) + device = db.relationship(Device, primaryjoin=Device.id == device_id) + + properties = db.Column(TSVECTOR, + nullable=False, + index=db.Index('properties gist', + postgresql_using='gist', + postgresql_concurrently=True)) + tags = db.Column(TSVECTOR, index=db.Index('tags gist', + postgresql_using='gist', + postgresql_concurrently=True)) + + __table_args__ = { + 'prefixes': ['UNLOGGED'] # Only for temporal tables, can cause table to empty on turn on + } + + @classmethod + def update_modified_devices(cls, session: db.Session): + """Updates the documents of the devices that are part of a modified + event in the passed-in session. + + This method is registered as a SQLAlchemy + listener in the Devicehub class. + """ + devices_to_update = set() + for event in (e for e in session.new if isinstance(e, Event)): + if isinstance(event, EventWithMultipleDevices): + devices_to_update |= event.devices + elif isinstance(event, EventWithOneDevice): + devices_to_update.add(event.device) + if event.parent: + devices_to_update.add(event.parent) + devices_to_update |= event.components + + # this flush is controversial: + # see https://groups.google.com/forum/#!topic/sqlalchemy/hBzfypgPfYo + # todo probably should replace it with what the solution says + session.flush() + for device in (d for d in devices_to_update if not isinstance(d, Component)): + cls.set_device_tokens(session, device) + + @classmethod + def set_all_devices_tokens_if_empty(cls, session: db.Session): + """Generates the search docs if the table is empty. + + This can happen if Postgres' shut down unexpectedly, as + it deletes unlogged tables as ours. + """ + if not DeviceSearch.query.first(): + for device in Device.query: + if not isinstance(device, Component): + cls.set_device_tokens(session, device) + + @classmethod + def set_device_tokens(cls, session: db.Session, device: Device): + """(Re)Generates the device search tokens.""" + assert not isinstance(device, Component) + + tokens = [ + (inflection.humanize(device.type), search.Weight.B), + (Device.model, search.Weight.B), + (Device.manufacturer, search.Weight.C), + (Device.serial_number, search.Weight.A) + ] + if isinstance(device, Computer): + Comp = aliased(Component) + tokens.extend(( + (db.func.string_agg(Comp.model, ' '), search.Weight.C), + (db.func.string_agg(Comp.manufacturer, ' '), search.Weight.D), + (db.func.string_agg(Comp.serial_number, ' '), search.Weight.B), + (db.func.string_agg(Comp.type, ' '), search.Weight.B), + ('Computer', search.Weight.C), + ('PC', search.Weight.C), + (inflection.humanize(device.chassis.name), search.Weight.B), + )) + + properties = session \ + .query(search.Search.vectorize(*tokens)) \ + .filter(Device.id == device.id) + + if isinstance(device, Computer): + # Join to components + properties = properties \ + .outerjoin(Comp, Computer.components) \ + .group_by(Device.id) + + tags = session.query( + search.Search.vectorize( + (db.func.string_agg(Tag.id, ' '), search.Weight.A), + (db.func.string_agg(Organization.name, ' '), search.Weight.B) + ) + ).filter(Tag.device_id == device.id).join(Tag.org) + + # Note that commit flushes later + # todo see how to get rid of the one_or_none() by embedding those as subqueries + # I don't like this but I want the 'on_conflict_on_update' thingie + device_document = dict(properties=properties.one_or_none(), tags=tags.one_or_none()) + insert = postgresql.insert(DeviceSearch.__table__) \ + .values(device_id=device.id, **device_document) \ + .on_conflict_do_update(constraint='device_search_pkey', set_=device_document) + session.execute(insert) diff --git a/ereuse_devicehub/resources/inventory.py b/ereuse_devicehub/resources/inventory.py index 83fc3d92..de94e281 100644 --- a/ereuse_devicehub/resources/inventory.py +++ b/ereuse_devicehub/resources/inventory.py @@ -4,10 +4,12 @@ from marshmallow import Schema as MarshmallowSchema from marshmallow.fields import Float, Integer, Nested, Str from marshmallow.validate import Range from sqlalchemy import Column -from teal.query import Between, FullTextSearch, ILike, Join, Or, Query, Sort, SortField +from teal.query import Between, ILike, Join, Or, Query, Sort, SortField from teal.resource import Resource, View +from ereuse_devicehub.resources import search from ereuse_devicehub.resources.device.models import Device +from ereuse_devicehub.resources.device.search import DeviceSearch from ereuse_devicehub.resources.event.models import Rate from ereuse_devicehub.resources.lot.models import Lot from ereuse_devicehub.resources.schemas import Thing @@ -54,7 +56,7 @@ class Sorting(Sort): class InventoryView(View): class FindArgs(MarshmallowSchema): - search = FullTextSearch() # todo Develop this. See more at docs/inventory. + search = Str() filter = Nested(Filters, missing=[]) sort = Nested(Sorting, missing=[Device.created.desc()]) page = Integer(validate=Range(min=1), missing=1) @@ -92,10 +94,18 @@ class InventoryView(View): def find(self, args: dict): """See :meth:`.get` above.""" - devices = Device.query \ - .filter(*args['filter']) \ - .order_by(*args['sort']) \ - .paginate(page=args['page'], per_page=30) # type: Pagination + search_p = args.get('search', None) + query = Device.query + if search_p: + properties = DeviceSearch.properties + tags = DeviceSearch.tags + query = query.join(DeviceSearch).filter( + search.Search.match(properties, search_p) | search.Search.match(tags, search_p) + ).order_by( + search.Search.rank(properties, search_p) + search.Search.rank(tags, search_p) + ) + query = query.filter(*args['filter']).order_by(*args['sort']) + devices = query.paginate(page=args['page'], per_page=30) # type: Pagination inventory = { 'devices': app.resources[Device.t].schema.dump(devices.items, many=True, nested=1), 'lots': app.resources[Lot.t].schema.dump(Lot.roots(), many=True, nested=1), diff --git a/ereuse_devicehub/resources/search.py b/ereuse_devicehub/resources/search.py new file mode 100644 index 00000000..50b272b1 --- /dev/null +++ b/ereuse_devicehub/resources/search.py @@ -0,0 +1,55 @@ +"""Full text search module. + +Implements full text search by using Postgre's capabilities and +creating temporary tables containing keywords as ts_vectors. +""" +from enum import Enum +from typing import Tuple + +from ereuse_devicehub.db import db + + +class Weight(Enum): + """TS Rank weight as an Enum.""" + A = 'A' + B = 'B' + C = 'C' + D = 'D' + + +class Search: + """Methods for building queries with Postgre's Full text search. + + Based on `Rachid Belaid's post `_ and + `Code for America's post `. + """ + LANG = 'english' + + @staticmethod + def match(column: db.Column, search: str, lang=LANG): + """Query that matches a TSVECTOR column with search words.""" + return column.op('@@')(db.func.plainto_tsquery(lang, search)) + + @staticmethod + def rank(column: db.Column, search: str, lang=LANG): + """Query that ranks a TSVECTOR column with search words.""" + return db.func.ts_rank(column, db.func.plainto_tsquery(lang, search)) + + @staticmethod + def _vectorize(col: db.Column, weight: Weight = Weight.D, lang=LANG): + return db.func.setweight(db.func.to_tsvector(lang, db.func.coalesce(col, '')), weight.name) + + @classmethod + def vectorize(cls, *cols_with_weights: Tuple[db.Column, Weight], lang=LANG): + """Produces a query that takes one ore more columns and their + respective weights, and generates one big TSVECTOR. + + This method takes care of `null` column values. + """ + first, rest = cols_with_weights[0], cols_with_weights[1:] + tokens = cls._vectorize(*first, lang=lang) + for unit in rest: + tokens = tokens.concat(cls._vectorize(*unit, lang=lang)) + return tokens diff --git a/setup.py b/setup.py index 8ba88e15..4edb67a6 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ setup( long_description=long_description, long_description_content_type='text/markdown', install_requires=[ - 'teal>=0.2.0a16', # teal always first + 'teal>=0.2.0a17', # teal always first 'click', 'click-spinner', 'ereuse-rate==0.0.2', diff --git a/tests/test_device.py b/tests/test_device.py index b48e480b..fdf45a00 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -16,13 +16,17 @@ from ereuse_devicehub.resources.device.exceptions import NeedsId from ereuse_devicehub.resources.device.models import Component, ComputerMonitor, Desktop, Device, \ GraphicCard, Laptop, Motherboard, NetworkAdapter from ereuse_devicehub.resources.device.schemas import Device as DeviceS +from ereuse_devicehub.resources.device.search import DeviceSearch from ereuse_devicehub.resources.device.sync import MismatchBetweenTags, MismatchBetweenTagsAndHid, \ Sync from ereuse_devicehub.resources.enums import ComputerChassis, DisplayTech +from ereuse_devicehub.resources.event import models as m from ereuse_devicehub.resources.event.models import Remove, Test +from ereuse_devicehub.resources.inventory import Inventory from ereuse_devicehub.resources.tag.model import Tag from ereuse_devicehub.resources.user import User from tests import conftest +from tests.conftest import file @pytest.mark.usefixtures(conftest.app_context.__name__) @@ -418,8 +422,9 @@ def test_get_devices(app: Devicehub, user: UserClient): db.session.commit() devices, _ = user.get(res=Device) assert tuple(d['id'] for d in devices) == (1, 2, 3, 4, 5) - assert tuple(d['type'] for d in devices) == ('Desktop', 'Desktop', 'Laptop', - 'NetworkAdapter', 'GraphicCard') + assert tuple(d['type'] for d in devices) == ( + 'Desktop', 'Desktop', 'Laptop', 'NetworkAdapter', 'GraphicCard' + ) @pytest.mark.usefixtures(conftest.app_context.__name__) @@ -448,3 +453,17 @@ def test_mobile_imei(): @pytest.mark.xfail(reason='Make test') def test_computer_with_display(): pass + + +def test_device_search_all_devices_token_if_empty(app: Devicehub, user: UserClient): + """Ensures DeviceSearch can regenerate itself when the table is empty.""" + user.post(file('basic.snapshot'), res=m.Snapshot) + with app.app_context(): + app.db.session.execute('TRUNCATE TABLE {}'.format(DeviceSearch.__table__.name)) + app.db.session.commit() + i, _ = user.get(res=Inventory, query=[('search', 'Desktop')]) + assert not len(i['devices']) + with app.app_context(): + DeviceSearch.set_all_devices_tokens_if_empty(app.db.session) + i, _ = user.get(res=Inventory, query=[('search', 'Desktop')]) + assert not len(i['devices']) diff --git a/tests/test_inventory.py b/tests/test_inventory.py index d9559bf5..6eb4c40b 100644 --- a/tests/test_inventory.py +++ b/tests/test_inventory.py @@ -9,6 +9,7 @@ from ereuse_devicehub.resources.enums import ComputerChassis from ereuse_devicehub.resources.event.models import Snapshot from ereuse_devicehub.resources.inventory import Filters, Inventory, Sorting from tests import conftest +from tests.conftest import file @pytest.mark.usefixtures(conftest.app_context.__name__) @@ -53,7 +54,7 @@ def test_inventory_sort(): @pytest.fixture() def inventory_query_dummy(app: Devicehub): with app.app_context(): - db.session.add_all(( # The order matters ;-) + devices = ( # The order matters ;-) Desktop(serial_number='s1', model='ml1', manufacturer='mr1', @@ -67,7 +68,10 @@ def inventory_query_dummy(app: Devicehub): manufacturer='mr2', chassis=ComputerChassis.Microtower), SolidStateDrive(serial_number='s4', model='ml4', manufacturer='mr4') - )) + ) + devices[-1].parent = devices[0] # s4 in s1 + db.session.add_all(devices) + db.session.commit() @@ -107,3 +111,36 @@ def test_inventory_query(user: UserClient): @pytest.mark.xfail(reason='Functionality not yet developed.') def test_inventory_lots_query(user: UserClient): pass + + +def test_inventory_query_search(user: UserClient): + # todo improve + user.post(file('basic.snapshot'), res=Snapshot) + user.post(file('computer-monitor.snapshot'), res=Snapshot) + user.post(file('real-eee-1001pxd.snapshot.11'), res=Snapshot) + i, _ = user.get(res=Inventory, query=[('search', 'desktop')]) + assert i['devices'][0]['id'] == 1 + i, _ = user.get(res=Inventory, query=[('search', 'intel')]) + assert len(i['devices']) == 1 + + +@pytest.mark.xfail(reason='No dictionary yet that knows asustek = asus') +def test_inventory_query_search_synonyms_asus(user: UserClient): + user.post(file('real-eee-1001pxd.snapshot.11'), res=Snapshot) + i, _ = user.get(res=Inventory, query=[('search', 'asustek')]) + assert len(i['devices']) == 1 + i, _ = user.get(res=Inventory, query=[('search', 'asus')]) + assert len(i['devices']) == 1 + + +@pytest.mark.xfail(reason='No dictionary yet that knows hp = hewlett packard') +def test_inventory_query_search_synonyms_intel(user: UserClient): + s = file('real-hp.snapshot.11') + s['device']['model'] = 'foo' # The model had the word 'HP' in it + user.post(s, res=Snapshot) + i, _ = user.get(res=Inventory, query=[('search', 'hewlett packard')]) + assert len(i['devices']) == 1 + i, _ = user.get(res=Inventory, query=[('search', 'hewlett')]) + assert len(i['devices']) == 1 + i, _ = user.get(res=Inventory, query=[('search', 'hp')]) + assert len(i['devices']) == 1