branching to devel

This commit is contained in:
andrea 2020-10-27 11:12:26 +01:00
parent 526cf80b66
commit 5906f85f33
3 changed files with 8 additions and 4 deletions

View File

@ -330,6 +330,9 @@ class RecurrentEmbedder:
def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
print('### Gated Recurrent Unit View Generator (G)')
# could be better to init model here at first .fit() call!
if self.model is None:
print('TODO: Init model!')
if not self.is_trained:
# Batchify input
self.multilingual_index.train_val_split(val_prop=0.2, max_val=2000, seed=self.seed)

View File

@ -27,7 +27,6 @@ if __name__ == '__main__':
op.gruViewGenerator, op.gruMUSE, op.gruWCE, op.agg, op.allprob)
print(f'Method: gFun{method_name}\nDataset: {dataset_name}')
print('-'*50)
exit()
# set zscore range - is slice(0, 0) mean will be equal to 0 and std to 1, thus normalization will have no effect
standardize_range = slice(0, 0)
@ -36,7 +35,7 @@ if __name__ == '__main__':
# load dataset
data = MultilingualDataset.load(dataset)
# data.set_view(languages=['fr', 'it']) # TODO: DEBUG SETTING
data.set_view(languages=['nl', 'it']) # TODO: DEBUG SETTING
data.show_dimensions()
lXtr, lytr = data.training()
lXte, lyte = data.test()
@ -87,6 +86,7 @@ if __name__ == '__main__':
document embeddings are then casted into vectors of posterior probabilities via a set of SVM.
NB: --allprob won't have any effect on this View Gen since output is already encoded as post prob
"""
op.gru_path = '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle' # TODO DEBUG
rnn_embedder = RecurrentEmbedder(pretrained=op.gruMUSE, supervised=op.gruWCE, multilingual_dataset=data,
options=op, model_path=op.gru_path)
doc_embedder.append(rnn_embedder)
@ -95,6 +95,7 @@ if __name__ == '__main__':
"""
View generator (-B): generates document embedding via mBERT model.
"""
op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0' # TODO DEBUG
mbert = MBertEmbedder(path_to_model=op.bert_path,
nC=data.num_categories())
if op.allprob:

View File

@ -45,7 +45,7 @@ def average_results(l_eval, show=True):
def evaluate_method(polylingual_method, lX, ly, predictor=None, soft=False, return_time=False):
tinit=time.time()
tinit = time.time()
print('prediction for test')
assert set(lX.keys()) == set(ly.keys()), 'inconsistent dictionaries in evaluate'
n_jobs = polylingual_method.n_jobs if hasattr(polylingual_method, 'n_jobs') else -1