diff --git a/ereuse_devicehub/devicehub.py b/ereuse_devicehub/devicehub.py
index 915fcde1..ccc80960 100644
--- a/ereuse_devicehub/devicehub.py
+++ b/ereuse_devicehub/devicehub.py
@@ -1,6 +1,7 @@
from typing import Type
from flask_sqlalchemy import SQLAlchemy
+from sqlalchemy import event
from teal.config import Config as ConfigClass
from teal.teal import Teal
@@ -8,6 +9,7 @@ from ereuse_devicehub.auth import Auth
from ereuse_devicehub.client import Client
from ereuse_devicehub.db import db
from ereuse_devicehub.dummy.dummy import Dummy
+from ereuse_devicehub.resources.device.search import DeviceSearch
class Devicehub(Teal):
@@ -32,3 +34,13 @@ class Devicehub(Teal):
host_matching, subdomain_matching, template_folder, instance_path,
instance_relative_config, root_path, Auth)
self.dummy = Dummy(self)
+ self.before_request(self.register_db_events_listeners)
+
+ def register_db_events_listeners(self):
+ """Registers the SQLAlchemy event listeners."""
+ # todo can I make it with a global Session only?
+ event.listen(db.session, 'before_commit', DeviceSearch.update_modified_devices)
+
+ def _init_db(self):
+ super()._init_db()
+ DeviceSearch.set_all_devices_tokens_if_empty(self.db.session)
diff --git a/ereuse_devicehub/dummy/dummy.py b/ereuse_devicehub/dummy/dummy.py
index 4d871864..88024ae0 100644
--- a/ereuse_devicehub/dummy/dummy.py
+++ b/ereuse_devicehub/dummy/dummy.py
@@ -101,6 +101,11 @@ class Dummy:
inventory, _ = user.get(res=Inventory)
assert len(inventory['devices'])
assert len(inventory['lots'])
+
+ i, _ = user.get(res=Inventory, query=[('search', 'intel')])
+ assert len(i['devices']) == 10
+ i, _ = user.get(res=Inventory, query=[('search', 'pc')])
+ assert len(i['devices']) == 11
print('⭐ Done.')
def user_client(self, email: str, password: str):
diff --git a/ereuse_devicehub/resources/device/search.py b/ereuse_devicehub/resources/device/search.py
new file mode 100644
index 00000000..3970c7c8
--- /dev/null
+++ b/ereuse_devicehub/resources/device/search.py
@@ -0,0 +1,122 @@
+import inflection
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.dialects.postgresql import TSVECTOR
+from sqlalchemy.orm import aliased
+
+from ereuse_devicehub.db import db
+from ereuse_devicehub.resources import search
+from ereuse_devicehub.resources.agent.models import Organization
+from ereuse_devicehub.resources.device.models import Component, Computer, Device
+from ereuse_devicehub.resources.event.models import Event, EventWithMultipleDevices, \
+ EventWithOneDevice
+from ereuse_devicehub.resources.tag.model import Tag
+
+
+class DeviceSearch(db.Model):
+ """Temporary table that stores full-text device documents.
+
+ It provides methods to auto-run
+ """
+ device_id = db.Column(db.BigInteger,
+ db.ForeignKey(Device.id, ondelete='CASCADE'),
+ primary_key=True)
+ device = db.relationship(Device, primaryjoin=Device.id == device_id)
+
+ properties = db.Column(TSVECTOR,
+ nullable=False,
+ index=db.Index('properties gist',
+ postgresql_using='gist',
+ postgresql_concurrently=True))
+ tags = db.Column(TSVECTOR, index=db.Index('tags gist',
+ postgresql_using='gist',
+ postgresql_concurrently=True))
+
+ __table_args__ = {
+ 'prefixes': ['UNLOGGED'] # Only for temporal tables, can cause table to empty on turn on
+ }
+
+ @classmethod
+ def update_modified_devices(cls, session: db.Session):
+ """Updates the documents of the devices that are part of a modified
+ event in the passed-in session.
+
+ This method is registered as a SQLAlchemy
+ listener in the Devicehub class.
+ """
+ devices_to_update = set()
+ for event in (e for e in session.new if isinstance(e, Event)):
+ if isinstance(event, EventWithMultipleDevices):
+ devices_to_update |= event.devices
+ elif isinstance(event, EventWithOneDevice):
+ devices_to_update.add(event.device)
+ if event.parent:
+ devices_to_update.add(event.parent)
+ devices_to_update |= event.components
+
+ # this flush is controversial:
+ # see https://groups.google.com/forum/#!topic/sqlalchemy/hBzfypgPfYo
+ # todo probably should replace it with what the solution says
+ session.flush()
+ for device in (d for d in devices_to_update if not isinstance(d, Component)):
+ cls.set_device_tokens(session, device)
+
+ @classmethod
+ def set_all_devices_tokens_if_empty(cls, session: db.Session):
+ """Generates the search docs if the table is empty.
+
+ This can happen if Postgres' shut down unexpectedly, as
+ it deletes unlogged tables as ours.
+ """
+ if not DeviceSearch.query.first():
+ for device in Device.query:
+ if not isinstance(device, Component):
+ cls.set_device_tokens(session, device)
+
+ @classmethod
+ def set_device_tokens(cls, session: db.Session, device: Device):
+ """(Re)Generates the device search tokens."""
+ assert not isinstance(device, Component)
+
+ tokens = [
+ (inflection.humanize(device.type), search.Weight.B),
+ (Device.model, search.Weight.B),
+ (Device.manufacturer, search.Weight.C),
+ (Device.serial_number, search.Weight.A)
+ ]
+ if isinstance(device, Computer):
+ Comp = aliased(Component)
+ tokens.extend((
+ (db.func.string_agg(Comp.model, ' '), search.Weight.C),
+ (db.func.string_agg(Comp.manufacturer, ' '), search.Weight.D),
+ (db.func.string_agg(Comp.serial_number, ' '), search.Weight.B),
+ (db.func.string_agg(Comp.type, ' '), search.Weight.B),
+ ('Computer', search.Weight.C),
+ ('PC', search.Weight.C),
+ (inflection.humanize(device.chassis.name), search.Weight.B),
+ ))
+
+ properties = session \
+ .query(search.Search.vectorize(*tokens)) \
+ .filter(Device.id == device.id)
+
+ if isinstance(device, Computer):
+ # Join to components
+ properties = properties \
+ .outerjoin(Comp, Computer.components) \
+ .group_by(Device.id)
+
+ tags = session.query(
+ search.Search.vectorize(
+ (db.func.string_agg(Tag.id, ' '), search.Weight.A),
+ (db.func.string_agg(Organization.name, ' '), search.Weight.B)
+ )
+ ).filter(Tag.device_id == device.id).join(Tag.org)
+
+ # Note that commit flushes later
+ # todo see how to get rid of the one_or_none() by embedding those as subqueries
+ # I don't like this but I want the 'on_conflict_on_update' thingie
+ device_document = dict(properties=properties.one_or_none(), tags=tags.one_or_none())
+ insert = postgresql.insert(DeviceSearch.__table__) \
+ .values(device_id=device.id, **device_document) \
+ .on_conflict_do_update(constraint='device_search_pkey', set_=device_document)
+ session.execute(insert)
diff --git a/ereuse_devicehub/resources/inventory.py b/ereuse_devicehub/resources/inventory.py
index 83fc3d92..de94e281 100644
--- a/ereuse_devicehub/resources/inventory.py
+++ b/ereuse_devicehub/resources/inventory.py
@@ -4,10 +4,12 @@ from marshmallow import Schema as MarshmallowSchema
from marshmallow.fields import Float, Integer, Nested, Str
from marshmallow.validate import Range
from sqlalchemy import Column
-from teal.query import Between, FullTextSearch, ILike, Join, Or, Query, Sort, SortField
+from teal.query import Between, ILike, Join, Or, Query, Sort, SortField
from teal.resource import Resource, View
+from ereuse_devicehub.resources import search
from ereuse_devicehub.resources.device.models import Device
+from ereuse_devicehub.resources.device.search import DeviceSearch
from ereuse_devicehub.resources.event.models import Rate
from ereuse_devicehub.resources.lot.models import Lot
from ereuse_devicehub.resources.schemas import Thing
@@ -54,7 +56,7 @@ class Sorting(Sort):
class InventoryView(View):
class FindArgs(MarshmallowSchema):
- search = FullTextSearch() # todo Develop this. See more at docs/inventory.
+ search = Str()
filter = Nested(Filters, missing=[])
sort = Nested(Sorting, missing=[Device.created.desc()])
page = Integer(validate=Range(min=1), missing=1)
@@ -92,10 +94,18 @@ class InventoryView(View):
def find(self, args: dict):
"""See :meth:`.get` above."""
- devices = Device.query \
- .filter(*args['filter']) \
- .order_by(*args['sort']) \
- .paginate(page=args['page'], per_page=30) # type: Pagination
+ search_p = args.get('search', None)
+ query = Device.query
+ if search_p:
+ properties = DeviceSearch.properties
+ tags = DeviceSearch.tags
+ query = query.join(DeviceSearch).filter(
+ search.Search.match(properties, search_p) | search.Search.match(tags, search_p)
+ ).order_by(
+ search.Search.rank(properties, search_p) + search.Search.rank(tags, search_p)
+ )
+ query = query.filter(*args['filter']).order_by(*args['sort'])
+ devices = query.paginate(page=args['page'], per_page=30) # type: Pagination
inventory = {
'devices': app.resources[Device.t].schema.dump(devices.items, many=True, nested=1),
'lots': app.resources[Lot.t].schema.dump(Lot.roots(), many=True, nested=1),
diff --git a/ereuse_devicehub/resources/search.py b/ereuse_devicehub/resources/search.py
new file mode 100644
index 00000000..50b272b1
--- /dev/null
+++ b/ereuse_devicehub/resources/search.py
@@ -0,0 +1,55 @@
+"""Full text search module.
+
+Implements full text search by using Postgre's capabilities and
+creating temporary tables containing keywords as ts_vectors.
+"""
+from enum import Enum
+from typing import Tuple
+
+from ereuse_devicehub.db import db
+
+
+class Weight(Enum):
+ """TS Rank weight as an Enum."""
+ A = 'A'
+ B = 'B'
+ C = 'C'
+ D = 'D'
+
+
+class Search:
+ """Methods for building queries with Postgre's Full text search.
+
+ Based on `Rachid Belaid's post `_ and
+ `Code for America's post `.
+ """
+ LANG = 'english'
+
+ @staticmethod
+ def match(column: db.Column, search: str, lang=LANG):
+ """Query that matches a TSVECTOR column with search words."""
+ return column.op('@@')(db.func.plainto_tsquery(lang, search))
+
+ @staticmethod
+ def rank(column: db.Column, search: str, lang=LANG):
+ """Query that ranks a TSVECTOR column with search words."""
+ return db.func.ts_rank(column, db.func.plainto_tsquery(lang, search))
+
+ @staticmethod
+ def _vectorize(col: db.Column, weight: Weight = Weight.D, lang=LANG):
+ return db.func.setweight(db.func.to_tsvector(lang, db.func.coalesce(col, '')), weight.name)
+
+ @classmethod
+ def vectorize(cls, *cols_with_weights: Tuple[db.Column, Weight], lang=LANG):
+ """Produces a query that takes one ore more columns and their
+ respective weights, and generates one big TSVECTOR.
+
+ This method takes care of `null` column values.
+ """
+ first, rest = cols_with_weights[0], cols_with_weights[1:]
+ tokens = cls._vectorize(*first, lang=lang)
+ for unit in rest:
+ tokens = tokens.concat(cls._vectorize(*unit, lang=lang))
+ return tokens
diff --git a/setup.py b/setup.py
index 8ba88e15..4edb67a6 100644
--- a/setup.py
+++ b/setup.py
@@ -34,7 +34,7 @@ setup(
long_description=long_description,
long_description_content_type='text/markdown',
install_requires=[
- 'teal>=0.2.0a16', # teal always first
+ 'teal>=0.2.0a17', # teal always first
'click',
'click-spinner',
'ereuse-rate==0.0.2',
diff --git a/tests/test_device.py b/tests/test_device.py
index b48e480b..fdf45a00 100644
--- a/tests/test_device.py
+++ b/tests/test_device.py
@@ -16,13 +16,17 @@ from ereuse_devicehub.resources.device.exceptions import NeedsId
from ereuse_devicehub.resources.device.models import Component, ComputerMonitor, Desktop, Device, \
GraphicCard, Laptop, Motherboard, NetworkAdapter
from ereuse_devicehub.resources.device.schemas import Device as DeviceS
+from ereuse_devicehub.resources.device.search import DeviceSearch
from ereuse_devicehub.resources.device.sync import MismatchBetweenTags, MismatchBetweenTagsAndHid, \
Sync
from ereuse_devicehub.resources.enums import ComputerChassis, DisplayTech
+from ereuse_devicehub.resources.event import models as m
from ereuse_devicehub.resources.event.models import Remove, Test
+from ereuse_devicehub.resources.inventory import Inventory
from ereuse_devicehub.resources.tag.model import Tag
from ereuse_devicehub.resources.user import User
from tests import conftest
+from tests.conftest import file
@pytest.mark.usefixtures(conftest.app_context.__name__)
@@ -418,8 +422,9 @@ def test_get_devices(app: Devicehub, user: UserClient):
db.session.commit()
devices, _ = user.get(res=Device)
assert tuple(d['id'] for d in devices) == (1, 2, 3, 4, 5)
- assert tuple(d['type'] for d in devices) == ('Desktop', 'Desktop', 'Laptop',
- 'NetworkAdapter', 'GraphicCard')
+ assert tuple(d['type'] for d in devices) == (
+ 'Desktop', 'Desktop', 'Laptop', 'NetworkAdapter', 'GraphicCard'
+ )
@pytest.mark.usefixtures(conftest.app_context.__name__)
@@ -448,3 +453,17 @@ def test_mobile_imei():
@pytest.mark.xfail(reason='Make test')
def test_computer_with_display():
pass
+
+
+def test_device_search_all_devices_token_if_empty(app: Devicehub, user: UserClient):
+ """Ensures DeviceSearch can regenerate itself when the table is empty."""
+ user.post(file('basic.snapshot'), res=m.Snapshot)
+ with app.app_context():
+ app.db.session.execute('TRUNCATE TABLE {}'.format(DeviceSearch.__table__.name))
+ app.db.session.commit()
+ i, _ = user.get(res=Inventory, query=[('search', 'Desktop')])
+ assert not len(i['devices'])
+ with app.app_context():
+ DeviceSearch.set_all_devices_tokens_if_empty(app.db.session)
+ i, _ = user.get(res=Inventory, query=[('search', 'Desktop')])
+ assert not len(i['devices'])
diff --git a/tests/test_inventory.py b/tests/test_inventory.py
index d9559bf5..6eb4c40b 100644
--- a/tests/test_inventory.py
+++ b/tests/test_inventory.py
@@ -9,6 +9,7 @@ from ereuse_devicehub.resources.enums import ComputerChassis
from ereuse_devicehub.resources.event.models import Snapshot
from ereuse_devicehub.resources.inventory import Filters, Inventory, Sorting
from tests import conftest
+from tests.conftest import file
@pytest.mark.usefixtures(conftest.app_context.__name__)
@@ -53,7 +54,7 @@ def test_inventory_sort():
@pytest.fixture()
def inventory_query_dummy(app: Devicehub):
with app.app_context():
- db.session.add_all(( # The order matters ;-)
+ devices = ( # The order matters ;-)
Desktop(serial_number='s1',
model='ml1',
manufacturer='mr1',
@@ -67,7 +68,10 @@ def inventory_query_dummy(app: Devicehub):
manufacturer='mr2',
chassis=ComputerChassis.Microtower),
SolidStateDrive(serial_number='s4', model='ml4', manufacturer='mr4')
- ))
+ )
+ devices[-1].parent = devices[0] # s4 in s1
+ db.session.add_all(devices)
+
db.session.commit()
@@ -107,3 +111,36 @@ def test_inventory_query(user: UserClient):
@pytest.mark.xfail(reason='Functionality not yet developed.')
def test_inventory_lots_query(user: UserClient):
pass
+
+
+def test_inventory_query_search(user: UserClient):
+ # todo improve
+ user.post(file('basic.snapshot'), res=Snapshot)
+ user.post(file('computer-monitor.snapshot'), res=Snapshot)
+ user.post(file('real-eee-1001pxd.snapshot.11'), res=Snapshot)
+ i, _ = user.get(res=Inventory, query=[('search', 'desktop')])
+ assert i['devices'][0]['id'] == 1
+ i, _ = user.get(res=Inventory, query=[('search', 'intel')])
+ assert len(i['devices']) == 1
+
+
+@pytest.mark.xfail(reason='No dictionary yet that knows asustek = asus')
+def test_inventory_query_search_synonyms_asus(user: UserClient):
+ user.post(file('real-eee-1001pxd.snapshot.11'), res=Snapshot)
+ i, _ = user.get(res=Inventory, query=[('search', 'asustek')])
+ assert len(i['devices']) == 1
+ i, _ = user.get(res=Inventory, query=[('search', 'asus')])
+ assert len(i['devices']) == 1
+
+
+@pytest.mark.xfail(reason='No dictionary yet that knows hp = hewlett packard')
+def test_inventory_query_search_synonyms_intel(user: UserClient):
+ s = file('real-hp.snapshot.11')
+ s['device']['model'] = 'foo' # The model had the word 'HP' in it
+ user.post(s, res=Snapshot)
+ i, _ = user.get(res=Inventory, query=[('search', 'hewlett packard')])
+ assert len(i['devices']) == 1
+ i, _ = user.get(res=Inventory, query=[('search', 'hewlett')])
+ assert len(i['devices']) == 1
+ i, _ = user.get(res=Inventory, query=[('search', 'hp')])
+ assert len(i['devices']) == 1