added FAISS Searcher
This commit is contained in:
parent
2e1796c18a
commit
ad4a95e000
|
@ -0,0 +1 @@
|
|||
/venv/
|
|
@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
nano \
|
||||
unzip
|
||||
|
||||
RUN pip install numpy tornado flask-restful pillow numpy matplotlib tqdm scikit-learn h5py requests
|
||||
RUN pip install numpy tornado flask-restful pillow numpy matplotlib tqdm scikit-learn h5py requests faiss-cpu==1.7.2
|
||||
ADD . /workspace
|
||||
|
||||
WORKDIR /workspace
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import h5py
|
||||
import numpy as np
|
||||
import beniculturaliSettings as settings
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import beniculturaliSettings as settings
|
|||
from BeniCulturaliRescorer import BeniCulturaliRescorer
|
||||
from BeniCulturaliSearchEngine import BeniCulturaliSearchEngine
|
||||
import FeatureExtractor as fe
|
||||
import ORBExtractor as lf
|
||||
#import ORBExtractor as lf
|
||||
|
||||
|
||||
class BeniCulturaliSearcher:
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
import numpy as np
|
||||
import beniculturaliSettings as settings
|
||||
import faiss
|
||||
|
||||
|
||||
class FAISSSearchEngine:
    """Exact inner-product nearest-neighbour search over image descriptors.

    Three structures are kept row-aligned at all times: row ``i`` of
    ``self.descs``, entry ``i`` of ``self.ids`` and vector ``i`` of the FAISS
    index ``self.index`` all refer to the same item.  Every mutation
    (``add`` / ``remove``) therefore updates the index as well, otherwise
    searches would return stale or mismatched ids.
    """

    # Descriptor dimensionality used only when it cannot be read from the
    # loaded descriptor matrix.
    DEFAULT_DIM = 2048

    def __init__(self):
        """Load descriptors and ids from the configured files and index them."""
        self.descs = np.load(settings.DATASET)
        # ndmin=1 keeps a one-line ids file from collapsing into a bare string
        # once converted with tolist().
        self.ids = np.loadtxt(settings.DATASET_IDS, dtype=str, ndmin=1).tolist()
        self._rebuild_index()

    def _dim(self):
        """Return the descriptor dimensionality of the current dataset."""
        if self.descs.ndim == 2:
            return self.descs.shape[1]
        return self.DEFAULT_DIM

    @staticmethod
    def _as_rows(vectors, dim):
        """Return ``vectors`` as a C-contiguous float32 array of shape (n, dim).

        This is the layout FAISS expects for both ``add`` and ``search``.
        """
        return np.ascontiguousarray(np.reshape(vectors, (-1, dim)), dtype=np.float32)

    def _rebuild_index(self):
        """Recreate the flat inner-product index from ``self.descs``."""
        dim = self._dim()
        self.index = faiss.index_factory(dim, 'Flat', faiss.METRIC_INNER_PRODUCT)
        if len(self.descs):
            self.index.add(self._as_rows(self.descs, dim))

    def get_id(self, idx):
        """Return the external id stored at position ``idx``."""
        return self.ids[idx]

    def add(self, desc, id):
        """Append one descriptor with its id, index it, and persist the dataset.

        Args:
            desc: descriptor vector of the new item.
            id: external identifier of the new item.
        """
        # Validate/convert first so a malformed descriptor leaves the engine
        # untouched instead of half-updated.
        rows = self._as_rows(desc, self._dim())
        stacked = np.vstack((self.descs, desc))

        self.index.add(rows)
        self.descs = stacked
        self.ids.append(id)
        self.save()

    def remove(self, id):
        """Remove the item with the given id from ids, descriptors and index.

        Raises:
            ValueError: if ``id`` is not present.
        """
        idx = self.ids.index(id)
        del self.ids[idx]
        self.descs = np.delete(self.descs, idx, axis=0)
        # Rebuild so index positions stay aligned with the shortened id list.
        self._rebuild_index()

    def search_by_id(self, query_id, k=10):
        """Search using the stored descriptor of ``query_id`` as the query.

        Raises:
            ValueError: if ``query_id`` is not present.
        """
        query_idx = self.ids.index(query_id)
        return self.search_by_img(self.descs[query_idx], k)

    def search_by_img(self, query, k=10):
        """Return up to ``k`` ``(id, score)`` pairs for the query descriptor.

        Scores are inner products rounded to three decimals, in the order
        returned by the index.
        """
        # Never ask for more neighbours than there are items.
        k = min(k, len(self.ids))
        if k <= 0:
            return []

        queries = self._as_rows(query, self._dim())
        scores, indexes = self.index.search(queries, k)
        # FAISS pads missing neighbours with label -1; indexing self.ids with
        # it would silently return the last id, so those slots are dropped.
        return [
            (self.ids[i], round(float(score), 3))
            for i, score in zip(indexes[0], scores[0])
            if i >= 0
        ]

    def save(self, is_backup=False):
        """Write descriptors and ids to the configured files.

        Args:
            is_backup: when true, '.bak' is appended to both target paths.
                NOTE: numpy.save adds a '.npy' suffix to names that lack it,
                so the descriptor backup ends up with a '.bak.npy' name.
        """
        descs_file = settings.DATASET
        ids_file = settings.DATASET_IDS

        if is_backup:
            descs_file += '.bak'
            ids_file += '.bak'

        np.save(descs_file, self.descs)
        np.savetxt(ids_file, self.ids, fmt='%s')
|
|
@ -0,0 +1,68 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
import pickle as pickle
|
||||
|
||||
import LFUtilities
|
||||
import beniculturaliSettings as settings
|
||||
from BeniCulturaliRescorer import BeniCulturaliRescorer
|
||||
from FAISSSearchEngine import FAISSSearchEngine
|
||||
import FeatureExtractor as fe
|
||||
import ORBExtractor as lf
|
||||
|
||||
|
||||
class Searcher:
    """Facade that extracts image features and queries the FAISS search engine.

    NOTE: local-feature rescoring (BeniCulturaliRescorer / ORB) is currently
    disabled; the ``rescorer`` flag only widens the candidate pool that is
    requested from the search engine.
    """

    # Number of candidates requested from the engine when rescoring is asked for.
    K_REORDERING = 15

    def __init__(self):
        """Create the underlying FAISS search engine."""
        self.search_engine = FAISSSearchEngine()

    def get_id(self, idx):
        """Return the id stored at position ``idx`` in the search engine."""
        return self.search_engine.get_id(idx)

    def add(self, img_file, id):
        """Extract the descriptor of ``img_file`` and add it under ``id``.

        A backup of the current dataset is written before the change and the
        updated dataset is saved afterwards.
        """
        self.save(True)

        desc = fe.extract(img_file)
        self.search_engine.add(desc, id)

        self.save()
        print('added ' + id)

    def remove(self, id):
        """Remove the item ``id``, backing up the dataset first and saving after."""
        self.save(True)
        self.search_engine.remove(id)
        self.save()
        print('removed ' + id)

    def _candidate_count(self, k, rescorer):
        """Return how many candidates to request from the search engine."""
        return self.K_REORDERING if rescorer else k

    def search_by_id(self, query_id, k=10, rescorer=False):
        """Return at most ``k`` results for an item already in the dataset."""
        res = self.search_engine.search_by_id(
            query_id, self._candidate_count(k, rescorer)
        )
        # With rescoring disabled the widened candidate list must still be cut
        # back to the k results the caller asked for.
        return res[:k]

    def search_by_img(self, query_img, k=10, rescorer=False):
        """Return at most ``k`` results for an arbitrary query image."""
        query_desc = fe.extract(query_img)
        res = self.search_engine.search_by_img(
            query_desc, self._candidate_count(k, rescorer)
        )
        # Same truncation as in search_by_id.
        return res[:k]

    def save(self, is_backup=False):
        """Persist the search engine state; ``is_backup`` selects the backup files."""
        self.search_engine.save(is_backup)
|
|
@ -1,5 +1,3 @@
|
|||
from re import split
|
||||
|
||||
from flask import Flask, request, redirect, url_for, flash, render_template, send_from_directory, abort
|
||||
from random import randint
|
||||
import cv2
|
||||
|
@ -9,13 +7,12 @@ import json
|
|||
|
||||
import urllib
|
||||
|
||||
from BeniCulturaliSearcher import BeniCulturaliSearcher
|
||||
#from BeniCulturaliSearcher import BeniCulturaliSearcher
|
||||
from Searcher import Searcher
|
||||
from BeniCulturaliSearchEngine import BeniCulturaliSearchEngine
|
||||
import beniculturaliSettings as settings
|
||||
import uuid
|
||||
import requests
|
||||
import os, os.path
|
||||
from PIL import Image
|
||||
import tornado.wsgi
|
||||
import tornado.httpserver
|
||||
import argparse
|
||||
|
@ -29,7 +26,7 @@ def api_root():
|
|||
print('index_with_randoms.html')
|
||||
random_ids = []
|
||||
for i in range(0, 15):
|
||||
random_ids.append(searcher.get_id(randint(0, 3000)))
|
||||
random_ids.append(searcher.get_id(randint(0, 30)))
|
||||
return render_template('index_with_randoms.html', random_ids=random_ids)
|
||||
|
||||
|
||||
|
@ -188,7 +185,7 @@ def start_from_terminal(app):
|
|||
args = parser.parse_args()
|
||||
settings.load_setting(args.conf)
|
||||
global searcher
|
||||
searcher = BeniCulturaliSearcher()
|
||||
searcher = Searcher()
|
||||
|
||||
#if args.debug:
|
||||
# app.run(debug=True, host='0.0.0.0', port=settings.port)
|
||||
|
|
Loading…
Reference in New Issue