62 lines
1.7 KiB
Python
Executable File
62 lines
1.7 KiB
Python
Executable File
import h5py
|
|
import numpy as np
|
|
import WebAppSettings as settings
|
|
|
|
|
|
class GEMSearcher:
    """Brute-force dot-product search over an in-memory descriptor matrix.

    ``self.descs`` holds one descriptor per row and ``self.ids`` is the parallel
    list of ids: row ``i`` of ``self.descs`` belongs to ``self.ids[i]``.
    Scores are raw dot products; they equal cosine similarities only if the
    stored descriptors and the queries are already L2-normalised, which this
    class does not do itself.
    """

    def __init__(self):
        # Descriptor matrix and the ids aligned with its rows.
        self.descs = np.load(settings.DATASET_GEM)
        self.ids = self._load_ids(settings.DATASET_IDS)

    @staticmethod
    def _load_ids(path):
        """Read the id file written by :meth:`save` and return it as a list.

        Without ``ndmin=1``, ``np.loadtxt`` squeezes a one-id file to a 0-d
        array whose ``tolist()`` is a bare ``str``, not a list.
        """
        return np.loadtxt(path, dtype=str, ndmin=1).tolist()

    def get_id(self, idx):
        """Return the id stored for row ``idx`` of the descriptor matrix."""
        return self.ids[idx]

    def add(self, desc, id):
        """Append ``desc`` under ``id`` and persist the index via :meth:`save`.

        ``desc`` must be stackable under ``self.descs`` by ``np.vstack``
        (e.g. a 1-D row or a single-row 2-D array). Duplicate ids are not
        rejected.
        """
        # Stack first: if the shapes do not match, np.vstack raises before
        # self.ids is touched, so ids and rows never get out of step.
        new_descs = np.vstack((self.descs, desc))
        self.ids.append(id)
        self.descs = new_descs
        self.save()

    def remove(self, id):
        """Delete ``id`` and its descriptor row, then persist via :meth:`save`.

        Raises:
            ValueError: if ``id`` is not indexed (propagated from ``list.index``).
        """
        idx = self.ids.index(id)
        del self.ids[idx]
        self.descs = np.delete(self.descs, idx, axis=0)
        # Persist like add(); otherwise removals were lost on the next load.
        self.save()

    def search_by_id(self, query_id, k=10):
        """Search using the stored descriptor of ``query_id`` as the query.

        The query item itself is scored too and will normally appear in the
        results.

        Raises:
            ValueError: if ``query_id`` is not indexed.
        """
        query_idx = self.ids.index(query_id)
        return self.search_by_img(self.descs[query_idx], k)

    def search_by_img(self, query, k=10):
        """Return the ``k`` highest-scoring ``(id, score)`` pairs for ``query``.

        Args:
            query: a descriptor, either 1-D or as the first row of a 2-D array
                (e.g. shape ``(1, D)``); any further rows are ignored.
            k: maximum number of results to return.

        Returns:
            A list of ``(id, score)`` tuples in descending score order, each
            score being the dot product rounded to 3 decimals.
        """
        # atleast_2d(...)[0] yields the query vector for both 1-D and 2-D
        # inputs; plain query[0] on a 1-D vector would produce a scalar.
        query_vec = np.atleast_2d(query)[0]
        scores = np.dot(self.descs, query_vec)
        top = np.argsort(scores)[::-1][:k]
        return [(self.ids[i], round(float(scores[i]), 3)) for i in top]

    def save(self, is_backup=False):
        """Write descriptors and ids back to the files ``__init__`` loads.

        With ``is_backup=True`` both paths get a ``.bak`` suffix. Note that
        ``np.save`` appends ``.npy`` to a filename that does not already end
        with it.
        """
        # Save to the same path that __init__ reads; writing elsewhere meant
        # additions/removals were invisible after a restart.
        descs_file = settings.DATASET_GEM
        ids_file = settings.DATASET_IDS
        if is_backup:
            descs_file += '.bak'
            ids_file += '.bak'
        np.save(descs_file, self.descs)
        np.savetxt(ids_file, self.ids, fmt='%s')