branching to devel

andrea 2020-10-27 11:12:26 +01:00
parent 526cf80b66
commit 5906f85f33
3 changed files with 8 additions and 4 deletions

View File

@@ -330,6 +330,9 @@ class RecurrentEmbedder:
     def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
         print('### Gated Recurrent Unit View Generator (G)')
+        # could be better to init model here at first .fit() call!
+        if self.model is None:
+            print('TODO: Init model!')
         if not self.is_trained:
             # Batchify input
             self.multilingual_index.train_val_split(val_prop=0.2, max_val=2000, seed=self.seed)
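
The new TODO lines only flag the idea; a minimal sketch of the lazy initialization they hint at (building the model at the first .fit() call, once the data is available) could look like the following. The class and method names here (LazyViewGenerator, _build_model) are illustrative and not part of the repository.

# Minimal sketch (not the repository's code) of the lazy initialization the TODO hints at:
# the underlying model is built only at the first .fit() call, once the data is available.
class LazyViewGenerator:
    def __init__(self, hidden_size=512):
        self.hidden_size = hidden_size
        self.model = None          # deferred: constructed at the first .fit() call
        self.is_trained = False

    def _build_model(self, lX):
        # placeholder: a real implementation would size the network from the data in lX
        vocab_size = 1 + max(idx for X in lX.values() for doc in X for idx in doc)
        return {'vocab_size': vocab_size, 'hidden_size': self.hidden_size}

    def fit(self, lX, ly):
        if self.model is None:
            self.model = self._build_model(lX)
        self.is_trained = True
        return self

gen = LazyViewGenerator()
gen.fit({'en': [[1, 4, 2]], 'it': [[3, 0, 5]]}, {'en': [0], 'it': [1]})
print(gen.model)   # {'vocab_size': 6, 'hidden_size': 512}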

View File

@@ -27,8 +27,7 @@ if __name__ == '__main__':
                           op.gruViewGenerator, op.gruMUSE, op.gruWCE, op.agg, op.allprob)
     print(f'Method: gFun{method_name}\nDataset: {dataset_name}')
     print('-'*50)
-    exit()
     # set zscore range - if slice(0, 0), mean will be equal to 0 and std to 1, thus normalization will have no effect
     standardize_range = slice(0, 0)
     if op.zscore:
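
For context, the comment above says that an empty standardize_range disables normalization. A small sketch of that behaviour (assumed, not the project's actual zscores implementation):

import numpy as np

# Assumed behaviour only: statistics are computed on the rows selected by standardize_range,
# and an empty slice such as slice(0, 0) falls back to mean=0 / std=1, leaving X unchanged.
def zscores(X, standardize_range=slice(0, 0)):
    rows = X[standardize_range]
    if rows.size == 0:
        mean, std = 0.0, 1.0
    else:
        mean, std = rows.mean(axis=0), rows.std(axis=0)
    return (X - mean) / std

X = np.random.rand(10, 5)
assert np.allclose(zscores(X, slice(0, 0)), X)   # empty range -> no-op normalization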
@@ -36,7 +35,7 @@ if __name__ == '__main__':
     # load dataset
     data = MultilingualDataset.load(dataset)
-    # data.set_view(languages=['fr', 'it'])    # TODO: DEBUG SETTING
+    data.set_view(languages=['nl', 'it'])      # TODO: DEBUG SETTING
     data.show_dimensions()
     lXtr, lytr = data.training()
     lXte, lyte = data.test()
@@ -87,6 +86,7 @@ if __name__ == '__main__':
         document embeddings are then casted into vectors of posterior probabilities via a set of SVM.
         NB: --allprob won't have any effect on this View Gen since output is already encoded as post prob
         """
+        op.gru_path = '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'   # TODO DEBUG
         rnn_embedder = RecurrentEmbedder(pretrained=op.gruMUSE, supervised=op.gruWCE, multilingual_dataset=data,
                                          options=op, model_path=op.gru_path)
         doc_embedder.append(rnn_embedder)
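
The docstring above says the GRU view generator's document embeddings are cast into posterior probabilities by a set of SVMs. A minimal sketch of that step, assuming per-language embedding matrices and single-label targets; the repository's actual classifiers and their probability calibration may differ:

import numpy as np
from sklearn.svm import SVC

# One probability-calibrated SVM per language turns embeddings into posterior probabilities.
def embeddings_to_posteriors(lX_embed, ly):
    posteriors = {}
    for lang in lX_embed:
        clf = SVC(kernel='rbf', probability=True)     # Platt-scaled posterior probabilities
        clf.fit(lX_embed[lang], ly[lang])
        posteriors[lang] = clf.predict_proba(lX_embed[lang])
    return posteriors

lX_embed = {'en': np.random.rand(20, 16), 'it': np.random.rand(20, 16)}
ly = {lang: np.tile([0, 1], 10) for lang in lX_embed}
print(embeddings_to_posteriors(lX_embed, ly)['en'].shape)   # (20, 2)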
@@ -95,6 +95,7 @@ if __name__ == '__main__':
         """
         View generator (-B): generates document embedding via mBERT model.
         """
+        op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0'    # TODO DEBUG
         mbert = MBertEmbedder(path_to_model=op.bert_path,
                               nC=data.num_categories())
         if op.allprob:
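
The -B view generator's docstring says it produces document embeddings with mBERT. A sketch of that idea with Hugging Face transformers, using the public 'bert-base-multilingual-cased' checkpoint and the [CLS] vector as document embedding; MBertEmbedder's actual fine-tuning, pooling, and loading from op.bert_path may differ.

import torch
from transformers import AutoTokenizer, AutoModel

# Encode a small multilingual batch and take the [CLS] hidden state as the document embedding.
tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
model = AutoModel.from_pretrained('bert-base-multilingual-cased')

docs = ['Questo è un documento di prova.', 'This is a test document.']
batch = tokenizer(docs, padding=True, truncation=True, max_length=128, return_tensors='pt')
with torch.no_grad():
    hidden = model(**batch).last_hidden_state    # (batch, seq_len, 768)
doc_embeddings = hidden[:, 0, :]                 # [CLS] vector per document
print(doc_embeddings.shape)                      # torch.Size([2, 768])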

View File

@@ -45,7 +45,7 @@ def average_results(l_eval, show=True):
 def evaluate_method(polylingual_method, lX, ly, predictor=None, soft=False, return_time=False):
-    tinit=time.time()
+    tinit = time.time()
     print('prediction for test')
     assert set(lX.keys()) == set(ly.keys()), 'inconsistent dictionaries in evaluate'
     n_jobs = polylingual_method.n_jobs if hasattr(polylingual_method, 'n_jobs') else -1