branching to devel

Parent: 526cf80b66
Commit: 5906f85f33
@@ -330,6 +330,9 @@ class RecurrentEmbedder:

    def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
        print('### Gated Recurrent Unit View Generator (G)')
        # could be better to init model here at first .fit() call!
        if self.model is None:
            print('TODO: Init model!')
        if not self.is_trained:
            # Batchify input
            self.multilingual_index.train_val_split(val_prop=0.2, max_val=2000, seed=self.seed)

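The train/validation split above is handled by the repository's MultilingualIndex, which is not part of this diff. As a rough sketch of the same idea (hold out a capped validation portion per language before batchifying), assuming lX and ly are dictionaries keyed by language code, one could write:

from sklearn.model_selection import train_test_split

def train_val_split(lX, ly, val_prop=0.2, max_val=2000, seed=1):
    # Hold out a validation split per language, capping its size at max_val.
    # Sketch only: the actual MultilingualIndex.train_val_split is defined elsewhere.
    lXtr, lXva, lytr, lyva = {}, {}, {}, {}
    for lang in lX:
        n_val = min(int(len(lX[lang]) * val_prop), max_val)
        lXtr[lang], lXva[lang], lytr[lang], lyva[lang] = train_test_split(
            lX[lang], ly[lang], test_size=n_val, random_state=seed)
    return lXtr, lXva, lytr, lyva
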
@@ -27,8 +27,7 @@ if __name__ == '__main__':

        op.gruViewGenerator, op.gruMUSE, op.gruWCE, op.agg, op.allprob)
    print(f'Method: gFun{method_name}\nDataset: {dataset_name}')
    print('-'*50)
    exit()

    # set zscore range - if slice(0, 0), mean will be equal to 0 and std to 1, thus normalization will have no effect
    standardize_range = slice(0, 0)
    if op.zscore:

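The comment about slice(0, 0) describes a degenerate standardization range: when the slice selects no rows, the statistics default to mean 0 and std 1, so the z-score transform is a no-op. A minimal sketch of that behaviour with a hypothetical helper (not the repository's implementation):

import numpy as np

def zscore(X, standardize_range=slice(None)):
    # Standardize X with statistics computed on the rows selected by standardize_range.
    # An empty range such as slice(0, 0) falls back to mean=0 and std=1 (identity).
    stats = X[standardize_range]
    if stats.shape[0] == 0:
        mean, std = 0.0, 1.0
    else:
        mean = stats.mean(axis=0)
        std = np.where(stats.std(axis=0) == 0.0, 1.0, stats.std(axis=0))
    return (X - mean) / std
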
@@ -36,7 +35,7 @@ if __name__ == '__main__':

    # load dataset
    data = MultilingualDataset.load(dataset)
    # data.set_view(languages=['fr', 'it'])  # TODO: DEBUG SETTING
    data.set_view(languages=['nl', 'it'])  # TODO: DEBUG SETTING
    data.show_dimensions()
    lXtr, lytr = data.training()
    lXte, lyte = data.test()

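For orientation, training() and test() appear to return per-language dictionaries of documents and labels. A toy illustration of how such a structure would be consumed (the exact layout is an assumption, not taken from MultilingualDataset):

lXtr = {'nl': ['voorbeeld document', 'nog een document'],
        'it': ['documento di esempio', 'un altro documento']}
lytr = {'nl': [[1, 0], [0, 1]],
        'it': [[1, 1], [0, 1]]}

for lang in lXtr:
    print(f'{lang}: {len(lXtr[lang])} documents, {len(lytr[lang][0])} categories')
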
@@ -87,6 +86,7 @@ if __name__ == '__main__':

    document embeddings are then cast into vectors of posterior probabilities via a set of SVMs.
    NB: --allprob won't have any effect on this View Gen since output is already encoded as post prob
    """
    op.gru_path = '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'  # TODO DEBUG
    rnn_embedder = RecurrentEmbedder(pretrained=op.gruMUSE, supervised=op.gruWCE, multilingual_dataset=data,
                                     options=op, model_path=op.gru_path)
    doc_embedder.append(rnn_embedder)

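The docstring notes that the GRU document embeddings are cast into posterior probabilities by a set of SVMs. A minimal sketch of that step with scikit-learn, using one calibrated linear SVM per category in a one-vs-rest setup (the actual classifier and its parameters are not shown in this diff):

import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC

# Toy stand-ins for GRU document embeddings and multi-label targets.
rng = np.random.default_rng(0)
X_embed = rng.standard_normal((200, 64))        # 200 documents, 64-dim embeddings
Y = (rng.random((200, 5)) > 0.7).astype(int)    # 5 categories, multi-label

# One calibrated SVM per category; predict_proba yields posterior probabilities.
clf = OneVsRestClassifier(CalibratedClassifierCV(LinearSVC()))
clf.fit(X_embed, Y)
posteriors = clf.predict_proba(X_embed)         # shape: (200, 5)
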
@@ -95,6 +95,7 @@ if __name__ == '__main__':

    """
    View generator (-B): generates document embeddings via the mBERT model.
    """
    op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0'  # TODO DEBUG
    mbert = MBertEmbedder(path_to_model=op.bert_path,
                          nC=data.num_categories())
    if op.allprob:

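MBertEmbedder itself is defined elsewhere in the repository. Purely as an illustration of the underlying idea, encoding documents with a multilingual BERT and taking the [CLS] vector as the document embedding, a sketch with Hugging Face transformers could look like this (the checkpoint name and the pooling choice are assumptions):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
model = AutoModel.from_pretrained('bert-base-multilingual-cased')

docs = ['Een Nederlandstalig document.', 'Un documento in italiano.']
batch = tokenizer(docs, padding=True, truncation=True, max_length=512,
                  return_tensors='pt')

with torch.no_grad():
    out = model(**batch)

# Use the [CLS] token of the last hidden layer as the document embedding.
doc_embeddings = out.last_hidden_state[:, 0, :]   # shape: (2, 768)
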
@@ -45,7 +45,7 @@ def average_results(l_eval, show=True):

 def evaluate_method(polylingual_method, lX, ly, predictor=None, soft=False, return_time=False):
-    tinit=time.time()
+    tinit = time.time()
     print('prediction for test')
     assert set(lX.keys()) == set(ly.keys()), 'inconsistent dictionaries in evaluate'
     n_jobs = polylingual_method.n_jobs if hasattr(polylingual_method, 'n_jobs') else -1

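evaluate_method times the prediction step and asserts that the test documents and labels cover the same set of languages. A compact sketch of that pattern (the macro-F1 metric and the model interface are assumptions; the repository's evaluation helpers may compute other measures):

import time
from sklearn.metrics import f1_score

def evaluate(model, lX, ly):
    # Predict per language, returning macro-F1 per language and total prediction time.
    assert set(lX.keys()) == set(ly.keys()), 'inconsistent dictionaries in evaluate'
    tinit = time.time()
    scores = {lang: f1_score(ly[lang], model.predict(lX[lang]), average='macro')
              for lang in lX}
    return scores, time.time() - tinit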