branching to devel

parent 526cf80b66
commit 5906f85f33
@@ -330,6 +330,9 @@ class RecurrentEmbedder:
 
     def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
         print('### Gated Recurrent Unit View Generator (G)')
+        # could be better to init model here at first .fit() call!
+        if self.model is None:
+            print('TODO: Init model!')
         if not self.is_trained:
             # Batchify input
             self.multilingual_index.train_val_split(val_prop=0.2, max_val=2000, seed=self.seed)
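The three added lines are a placeholder: the comment suggests deferring model construction to the first .fit() call, when the data dimensions are actually known, but for now only a TODO is printed. A minimal sketch of that lazy-initialization idea, using a toy class and PyTorch (LazyGRUEmbedder and its attributes are made up for illustration; they are not part of this commit):

import torch.nn as nn

class LazyGRUEmbedder:
    # Toy illustration of lazy model construction; not the repository's RecurrentEmbedder.
    def __init__(self, hidden_size=512):
        self.hidden_size = hidden_size
        self.model = None  # built lazily, on the first call to fit()

    def fit(self, X, y):
        if self.model is None:
            # infer the input dimensionality from the data seen at the first fit() call
            input_size = X.shape[-1]
            self.model = nn.GRU(input_size, self.hidden_size, batch_first=True)
        # ... training loop would go here ...
        return self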
@@ -27,8 +27,7 @@ if __name__ == '__main__':
                          op.gruViewGenerator, op.gruMUSE, op.gruWCE, op.agg, op.allprob)
     print(f'Method: gFun{method_name}\nDataset: {dataset_name}')
     print('-'*50)
-    exit()
 
     # set zscore range - is slice(0, 0) mean will be equal to 0 and std to 1, thus normalization will have no effect
     standardize_range = slice(0, 0)
     if op.zscore:
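The slice(0, 0) default documented in the comment selects an empty column range, so the statistics used for z-scoring stay at mean 0 and std 1 and the transform leaves the data untouched (presumably the range is widened when op.zscore is set). One way to read that behaviour, as a small numpy sketch (standardize_in_range is a made-up helper, not the project's transformer):

import numpy as np

def standardize_in_range(X, standardize_range=slice(0, 0)):
    # z-score only the columns selected by standardize_range; with slice(0, 0)
    # no column is selected, so mean stays 0, std stays 1, and X comes back unchanged
    mean = np.zeros(X.shape[1])
    std = np.ones(X.shape[1])
    cols = np.arange(X.shape[1])[standardize_range]
    if cols.size:
        mean[cols] = X[:, cols].mean(axis=0)
        std[cols] = X[:, cols].std(axis=0)
    return (X - mean) / std

X = np.random.rand(5, 4)
assert np.allclose(standardize_in_range(X, slice(0, 0)), X)  # no-op, as the comment states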
@@ -36,7 +35,7 @@ if __name__ == '__main__':
 
     # load dataset
     data = MultilingualDataset.load(dataset)
-    # data.set_view(languages=['fr', 'it']) # TODO: DEBUG SETTING
+    data.set_view(languages=['nl', 'it']) # TODO: DEBUG SETTING
     data.show_dimensions()
     lXtr, lytr = data.training()
     lXte, lyte = data.test()
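The switch from ['fr', 'it'] to ['nl', 'it'] only changes which languages the debug view keeps; the rest of the script then works on per-language dictionaries returned by data.training() and data.test() (the assert in evaluate_method further down checks that lX and ly share the same language keys). A hypothetical mock of that layout, with made-up documents, just to show the shape the downstream code expects:

# Hypothetical per-language layout; the real MultilingualDataset is not shown in this commit.
lXtr = {'nl': ['dit is een document', 'nog een document'],
        'it': ['questo è un documento', 'un altro documento']}
lytr = {'nl': [[0, 1], [1, 0]],
        'it': [[1, 0], [0, 1]]}
assert set(lXtr.keys()) == set(lytr.keys())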
@@ -87,6 +86,7 @@ if __name__ == '__main__':
         document embeddings are then casted into vectors of posterior probabilities via a set of SVM.
         NB: --allprob won't have any effect on this View Gen since output is already encoded as post prob
         """
+        op.gru_path = '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle' # TODO DEBUG
         rnn_embedder = RecurrentEmbedder(pretrained=op.gruMUSE, supervised=op.gruWCE, multilingual_dataset=data,
                                          options=op, model_path=op.gru_path)
         doc_embedder.append(rnn_embedder)
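The docstring describes casting the GRU document embeddings into posterior probabilities via a set of SVMs. As a point of reference, the standard scikit-learn recipe for turning linear-SVM decision scores into posteriors is a calibrated classifier; this is a generic sketch, not necessarily the project's own posterior-probability code:

import numpy as np
from sklearn.svm import LinearSVC
from sklearn.calibration import CalibratedClassifierCV

# toy data standing in for document embeddings and binary labels
X = np.random.randn(200, 64)
y = (X[:, 0] > 0).astype(int)

# LinearSVC has no native predict_proba; calibration maps its decision scores to posteriors
clf = CalibratedClassifierCV(LinearSVC(), method='sigmoid', cv=5)
clf.fit(X, y)
posteriors = clf.predict_proba(X)  # shape (n_docs, 2), rows sum to 1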
@@ -95,6 +95,7 @@ if __name__ == '__main__':
         """
         View generator (-B): generates document embedding via mBERT model.
         """
+        op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0' # TODO DEBUG
         mbert = MBertEmbedder(path_to_model=op.bert_path,
                               nC=data.num_categories())
         if op.allprob:
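The mBERT view generator yields one embedding per document. For orientation, a minimal way to obtain such embeddings with the Hugging Face transformers library, taking the [CLS] vector of bert-base-multilingual-cased (a generic sketch, not the project's MBertEmbedder, which also receives the number of target categories nC):

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
model = AutoModel.from_pretrained('bert-base-multilingual-cased')
model.eval()

docs = ['Questo è un documento.', 'Dit is een document.']
batch = tokenizer(docs, padding=True, truncation=True, max_length=512, return_tensors='pt')
with torch.no_grad():
    out = model(**batch)
doc_embeddings = out.last_hidden_state[:, 0, :]  # [CLS] vectors, shape (n_docs, 768)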
@@ -45,7 +45,7 @@ def average_results(l_eval, show=True):
 
 
 def evaluate_method(polylingual_method, lX, ly, predictor=None, soft=False, return_time=False):
-    tinit=time.time()
+    tinit = time.time()
     print('prediction for test')
     assert set(lX.keys()) == set(ly.keys()), 'inconsistent dictionaries in evaluate'
     n_jobs = polylingual_method.n_jobs if hasattr(polylingual_method, 'n_jobs') else -1
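The only change here is cosmetic (tinit=time.time() gains spaces around the assignment), but the surrounding lines show the evaluation pattern: record a start time, predict over the per-language dictionaries, and optionally report the elapsed time when return_time is requested. A hedged sketch of that pattern in isolation (timed_predict and the per-language predict call are assumptions, not the repository's evaluate_method):

import time

def timed_predict(method, lX, return_time=False):
    # record wall-clock time around the (possibly parallelised) prediction step
    tinit = time.time()
    ly_pred = {lang: method.predict(X) for lang, X in lX.items()}  # assumed per-language API
    elapsed = time.time() - tinit
    return (ly_pred, elapsed) if return_time else ly_pred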