added FAISS Searcher
This commit is contained in:
parent
2e1796c18a
commit
ad4a95e000
|
@ -0,0 +1 @@
|
||||||
|
/venv/
|
|
@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
nano \
|
nano \
|
||||||
unzip
|
unzip
|
||||||
|
|
||||||
RUN pip install numpy tornado flask-restful pillow numpy matplotlib tqdm scikit-learn h5py requests
|
RUN pip install numpy tornado flask-restful pillow numpy matplotlib tqdm scikit-learn h5py requests faiss-cpu==1.7.2
|
||||||
ADD . /workspace
|
ADD . /workspace
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import h5py
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import beniculturaliSettings as settings
|
import beniculturaliSettings as settings
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ import beniculturaliSettings as settings
|
||||||
from BeniCulturaliRescorer import BeniCulturaliRescorer
|
from BeniCulturaliRescorer import BeniCulturaliRescorer
|
||||||
from BeniCulturaliSearchEngine import BeniCulturaliSearchEngine
|
from BeniCulturaliSearchEngine import BeniCulturaliSearchEngine
|
||||||
import FeatureExtractor as fe
|
import FeatureExtractor as fe
|
||||||
import ORBExtractor as lf
|
#import ORBExtractor as lf
|
||||||
|
|
||||||
|
|
||||||
class BeniCulturaliSearcher:
|
class BeniCulturaliSearcher:
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
import numpy as np
|
||||||
|
import beniculturaliSettings as settings
|
||||||
|
import faiss
|
||||||
|
|
||||||
|
|
||||||
|
class FAISSSearchEngine:
    """Inner-product similarity search over stored descriptors using FAISS.

    ``self.descs`` (one row per item) and ``self.ids`` (one id per row) are
    the state written by :meth:`save`. ``self.index`` is an in-memory FAISS
    index built from ``self.descs``; it is updated on every add/remove so
    that the row positions returned by a search can be looked up in
    ``self.ids``.
    """

    # Dimensionality of the features.
    DIM = 2048

    def __init__(self):
        """Load descriptors and ids from the configured files and index them."""
        self.descs = np.load(settings.DATASET)
        # ndmin=1: an ids file with a single line must still yield a list,
        # not a bare string.
        self.ids = np.loadtxt(settings.DATASET_IDS, dtype=str, ndmin=1).tolist()
        self._build_index()

    def _as_rows(self, vectors):
        """Return *vectors* as a C-contiguous float32 array of shape (n, DIM).

        This is the matrix layout FAISS works with; ``self.descs`` itself is
        left in whatever dtype it was loaded with.
        """
        rows = np.reshape(vectors, (-1, self.DIM))
        return np.ascontiguousarray(rows, dtype=np.float32)

    def _build_index(self):
        """(Re)create ``self.index`` from the current ``self.descs``."""
        # Flat index with inner product similarity.
        self.index = faiss.index_factory(
            self.DIM, 'Flat', faiss.METRIC_INNER_PRODUCT)
        if len(self.descs):
            self.index.add(self._as_rows(self.descs))

    def get_id(self, idx):
        """Return the id stored at row position *idx*."""
        return self.ids[idx]

    def add(self, desc, id):
        """Append one descriptor with its id, index it, and save to disk.

        Raises:
            ValueError: if *desc* does not hold exactly ``DIM`` values.
        """
        # Validate the shape first so a bad descriptor cannot leave ids and
        # descriptors out of step.
        row = np.reshape(desc, (1, self.DIM))
        self.descs = np.vstack((self.descs, row))
        self.ids.append(id)
        # Keep the FAISS index in step with descs/ids.
        self.index.add(self._as_rows(row))
        self.save()

    def remove(self, id):
        """Remove the item with the given id from ids, descriptors and index.

        Does not save; the caller decides when to persist.

        Raises:
            ValueError: if *id* is not present.
        """
        idx = self.ids.index(id)
        del self.ids[idx]
        self.descs = np.delete(self.descs, idx, axis=0)
        # Rebuild so that index positions match the shifted rows of descs/ids.
        self._build_index()

    def search_by_id(self, query_id, k=10):
        """Search using the stored descriptor of *query_id* as the query.

        Raises:
            ValueError: if *query_id* is not present.
        """
        query_idx = self.ids.index(query_id)
        return self.search_by_img(self.descs[query_idx], k)

    def search_by_img(self, query, k=10):
        """Return up to *k* ``(id, score)`` pairs for a descriptor *query*.

        Scores are the values returned by the index, rounded to 3 decimals,
        in the order the index returns them. Only the first query row is
        used if *query* reshapes to several rows.
        """
        queries = self._as_rows(query)
        scores, indexes = self.index.search(queries, k)
        # FAISS pads with label -1 when fewer than k results exist; skip
        # those instead of letting -1 select the last id.
        return [
            (self.ids[i], round(float(score), 3))
            for i, score in zip(indexes[0], scores[0])
            if i >= 0
        ]

    def save(self, is_backup=False):
        """Write descriptors and ids to the configured files.

        With ``is_backup=True`` the target names get a ``.bak`` suffix.
        Note that ``np.save`` appends ``.npy`` to a file name that does not
        already end with it.
        """
        descs_file = settings.DATASET
        ids_file = settings.DATASET_IDS

        if is_backup:
            descs_file += '.bak'
            ids_file += '.bak'

        np.save(descs_file, self.descs)
        np.savetxt(ids_file, self.ids, fmt='%s')
|
|
@ -0,0 +1,68 @@
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import pickle as pickle
|
||||||
|
|
||||||
|
import LFUtilities
|
||||||
|
import beniculturaliSettings as settings
|
||||||
|
from BeniCulturaliRescorer import BeniCulturaliRescorer
|
||||||
|
from FAISSSearchEngine import FAISSSearchEngine
|
||||||
|
import FeatureExtractor as fe
|
||||||
|
import ORBExtractor as lf
|
||||||
|
|
||||||
|
|
||||||
|
class Searcher:
    """Image search front end on top of :class:`FAISSSearchEngine`.

    Extracts a descriptor with ``fe.extract`` for image inputs and delegates
    storage and search to the engine.

    NOTE: local-feature rescoring (``BeniCulturaliRescorer`` / ``lf.extract``)
    is currently disabled. The ``rescorer`` flags are kept for interface
    compatibility; results are always truncated to ``k``.
    """

    # Minimum number of candidates fetched when rescoring is requested.
    K_REORDERING = 15

    def __init__(self):
        """Create the underlying search engine."""
        self.search_engine = FAISSSearchEngine()

    def _candidate_count(self, k, rescorer):
        """Return how many candidates to request from the engine."""
        # Never fetch fewer than k, otherwise a large k is silently capped.
        return max(k, self.K_REORDERING) if rescorer else k

    def get_id(self, idx):
        """Return the id stored at position *idx* in the engine."""
        return self.search_engine.get_id(idx)

    def add(self, img_file, id):
        """Extract the descriptor of *img_file* and store it under *id*.

        A backup is saved before the change and the dataset is saved after.
        """
        self.save(True)

        desc = fe.extract(img_file)
        self.search_engine.add(desc, id)

        self.save()
        # str() so a non-string id cannot raise after the item is persisted.
        print('added ' + str(id))

    def remove(self, id):
        """Remove the item stored under *id*.

        A backup is saved before the change and the dataset is saved after.
        """
        self.save(True)
        self.search_engine.remove(id)
        self.save()
        print('removed ' + str(id))

    def search_by_id(self, query_id, k=10, rescorer=False):
        """Return at most *k* results for the stored item *query_id*."""
        kq = self._candidate_count(k, rescorer)
        res = self.search_engine.search_by_id(query_id, kq)
        # With rescoring disabled, fall back to the top-k candidates.
        return res[:k]

    def search_by_img(self, query_img, k=10, rescorer=False):
        """Return at most *k* results for the image *query_img*."""
        kq = self._candidate_count(k, rescorer)
        query_desc = fe.extract(query_img)
        res = self.search_engine.search_by_img(query_desc, kq)
        # With rescoring disabled, fall back to the top-k candidates.
        return res[:k]

    def save(self, is_backup=False):
        """Save the engine state; ``is_backup`` is passed through to it."""
        self.search_engine.save(is_backup)
|
|
@ -1,5 +1,3 @@
|
||||||
from re import split
|
|
||||||
|
|
||||||
from flask import Flask, request, redirect, url_for, flash, render_template, send_from_directory, abort
|
from flask import Flask, request, redirect, url_for, flash, render_template, send_from_directory, abort
|
||||||
from random import randint
|
from random import randint
|
||||||
import cv2
|
import cv2
|
||||||
|
@ -9,13 +7,12 @@ import json
|
||||||
|
|
||||||
import urllib
|
import urllib
|
||||||
|
|
||||||
from BeniCulturaliSearcher import BeniCulturaliSearcher
|
#from BeniCulturaliSearcher import BeniCulturaliSearcher
|
||||||
|
from Searcher import Searcher
|
||||||
from BeniCulturaliSearchEngine import BeniCulturaliSearchEngine
|
from BeniCulturaliSearchEngine import BeniCulturaliSearchEngine
|
||||||
import beniculturaliSettings as settings
|
import beniculturaliSettings as settings
|
||||||
import uuid
|
import uuid
|
||||||
import requests
|
|
||||||
import os, os.path
|
import os, os.path
|
||||||
from PIL import Image
|
|
||||||
import tornado.wsgi
|
import tornado.wsgi
|
||||||
import tornado.httpserver
|
import tornado.httpserver
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -29,7 +26,7 @@ def api_root():
|
||||||
print('index_with_randoms.html')
|
print('index_with_randoms.html')
|
||||||
random_ids = []
|
random_ids = []
|
||||||
for i in range(0, 15):
|
for i in range(0, 15):
|
||||||
random_ids.append(searcher.get_id(randint(0, 3000)))
|
random_ids.append(searcher.get_id(randint(0, 30)))
|
||||||
return render_template('index_with_randoms.html', random_ids=random_ids)
|
return render_template('index_with_randoms.html', random_ids=random_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@ -188,7 +185,7 @@ def start_from_terminal(app):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
settings.load_setting(args.conf)
|
settings.load_setting(args.conf)
|
||||||
global searcher
|
global searcher
|
||||||
searcher = BeniCulturaliSearcher()
|
searcher = Searcher()
|
||||||
|
|
||||||
#if args.debug:
|
#if args.debug:
|
||||||
# app.run(debug=True, host='0.0.0.0', port=settings.port)
|
# app.run(debug=True, host='0.0.0.0', port=settings.port)
|
||||||
|
|
Loading…
Reference in New Issue