diff --git a/src/learning/transformers.py b/src/learning/transformers.py
index c669389..0032460 100644
--- a/src/learning/transformers.py
+++ b/src/learning/transformers.py
@@ -330,6 +330,9 @@ class RecurrentEmbedder:
 
     def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
         print('### Gated Recurrent Unit View Generator (G)')
+        # could be better to init model here at first .fit() call!
+        if self.model is None:
+            print('TODO: Init model!')
         if not self.is_trained:
             # Batchify input
             self.multilingual_index.train_val_split(val_prop=0.2, max_val=2000, seed=self.seed)
diff --git a/src/main_gFun.py b/src/main_gFun.py
index 04ae86a..c671ecd 100644
--- a/src/main_gFun.py
+++ b/src/main_gFun.py
@@ -27,8 +27,7 @@ if __name__ == '__main__':
                           op.gruViewGenerator, op.gruMUSE, op.gruWCE, op.agg, op.allprob)
     print(f'Method: gFun{method_name}\nDataset: {dataset_name}')
     print('-'*50)
-    exit()
-
+
     # set zscore range - is slice(0, 0) mean will be equal to 0 and std to 1, thus normalization will have no effect
     standardize_range = slice(0, 0)
     if op.zscore:
@@ -36,7 +35,7 @@ if __name__ == '__main__':
 
     # load dataset
     data = MultilingualDataset.load(dataset)
-    # data.set_view(languages=['fr', 'it'])  # TODO: DEBUG SETTING
+    data.set_view(languages=['nl', 'it'])  # TODO: DEBUG SETTING
     data.show_dimensions()
     lXtr, lytr = data.training()
     lXte, lyte = data.test()
@@ -87,6 +86,7 @@ if __name__ == '__main__':
        document embeddings are then casted into vectors of posterior probabilities via a set of SVM. 
        NB: --allprob won't have any effect on this View Gen since output is already encoded as post prob
        """
+        op.gru_path = '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'  # TODO DEBUG
         rnn_embedder = RecurrentEmbedder(pretrained=op.gruMUSE, supervised=op.gruWCE, multilingual_dataset=data,
                                          options=op, model_path=op.gru_path)
         doc_embedder.append(rnn_embedder)
@@ -95,6 +95,7 @@ if __name__ == '__main__':
         """
         View generator (-B): generates document embedding via mBERT model.
         """
+        op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0'  # TODO DEBUG
         mbert = MBertEmbedder(path_to_model=op.bert_path,
                               nC=data.num_categories())
         if op.allprob:
diff --git a/src/util/evaluation.py b/src/util/evaluation.py
index a4aac5c..41a2813 100644
--- a/src/util/evaluation.py
+++ b/src/util/evaluation.py
@@ -45,7 +45,7 @@ def average_results(l_eval, show=True):
 
 
 def evaluate_method(polylingual_method, lX, ly, predictor=None, soft=False, return_time=False):
-    tinit=time.time()
+    tinit = time.time()
     print('prediction for test')
     assert set(lX.keys()) == set(ly.keys()), 'inconsistent dictionaries in evaluate'
     n_jobs = polylingual_method.n_jobs if hasattr(polylingual_method, 'n_jobs') else -1