minor fixes
This commit is contained in:
parent
94bfe6a036
commit
20dca61e22
|
|
@ -822,10 +822,12 @@ def clip_gradient(model, clip_value=1e-1):
|
||||||
|
|
||||||
|
|
||||||
def init_logfile_nn(method_name, opt):
|
def init_logfile_nn(method_name, opt):
|
||||||
|
import os
|
||||||
logfile = CSVLog(opt.logfile_gru, ['dataset', 'method', 'epoch', 'measure', 'value', 'run', 'timelapse'])
|
logfile = CSVLog(opt.logfile_gru, ['dataset', 'method', 'epoch', 'measure', 'value', 'run', 'timelapse'])
|
||||||
logfile.set_default('dataset', opt.dataset)
|
logfile.set_default('dataset', opt.dataset)
|
||||||
logfile.set_default('run', opt.seed)
|
logfile.set_default('run', opt.seed)
|
||||||
logfile.set_default('method', method_name)
|
logfile.set_default('method', get_method_name(os.path.basename(opt.dataset), opt.posteriors, opt.supervised, opt.pretrained, opt.mbert,
|
||||||
|
opt.gruViewGenerator, opt.gruMUSE, opt.gruWCE, opt.agg, opt.allprob))
|
||||||
assert opt.force or not logfile.already_calculated(), f'results for dataset {opt.dataset} method {method_name} ' \
|
assert opt.force or not logfile.already_calculated(), f'results for dataset {opt.dataset} method {method_name} ' \
|
||||||
f'and run {opt.seed} already calculated'
|
f'and run {opt.seed} already calculated'
|
||||||
return logfile
|
return logfile
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,8 @@ if __name__ == '__main__':
|
||||||
assert not (op.set_c != 1. and op.optimc), 'Parameter C cannot be defined along with optim_c option'
|
assert not (op.set_c != 1. and op.optimc), 'Parameter C cannot be defined along with optim_c option'
|
||||||
assert op.posteriors or op.supervised or op.pretrained or op.mbert or op.gruViewGenerator, \
|
assert op.posteriors or op.supervised or op.pretrained or op.mbert or op.gruViewGenerator, \
|
||||||
'empty set of document embeddings is not allowed'
|
'empty set of document embeddings is not allowed'
|
||||||
assert (op.gruWCE or op.gruMUSE) and op.gruViewGenerator, 'Initializing Gated Recurrent embedding layer without ' \
|
if op.gruViewGenerator:
|
||||||
|
assert op.gruWCE or op.gruMUSE, 'Initializing Gated Recurrent embedding layer without ' \
|
||||||
'explicit initialization of GRU View Generator'
|
'explicit initialization of GRU View Generator'
|
||||||
|
|
||||||
l2 = op.l2
|
l2 = op.l2
|
||||||
|
|
@ -35,7 +36,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# load dataset
|
# load dataset
|
||||||
data = MultilingualDataset.load(dataset)
|
data = MultilingualDataset.load(dataset)
|
||||||
data.set_view(languages=['nl', 'it']) # TODO: DEBUG SETTING
|
# data.set_view(languages=['nl', 'it']) # TODO: DEBUG SETTING
|
||||||
data.show_dimensions()
|
data.show_dimensions()
|
||||||
lXtr, lytr = data.training()
|
lXtr, lytr = data.training()
|
||||||
lXte, lyte = data.test()
|
lXte, lyte = data.test()
|
||||||
|
|
@ -86,7 +87,6 @@ if __name__ == '__main__':
|
||||||
document embeddings are then casted into vectors of posterior probabilities via a set of SVM.
|
document embeddings are then casted into vectors of posterior probabilities via a set of SVM.
|
||||||
NB: --allprob won't have any effect on this View Gen since output is already encoded as post prob
|
NB: --allprob won't have any effect on this View Gen since output is already encoded as post prob
|
||||||
"""
|
"""
|
||||||
op.gru_path = '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle' # TODO DEBUG
|
|
||||||
rnn_embedder = RecurrentEmbedder(pretrained=op.gruMUSE, supervised=op.gruWCE, multilingual_dataset=data,
|
rnn_embedder = RecurrentEmbedder(pretrained=op.gruMUSE, supervised=op.gruWCE, multilingual_dataset=data,
|
||||||
options=op, model_path=op.gru_path)
|
options=op, model_path=op.gru_path)
|
||||||
doc_embedder.append(rnn_embedder)
|
doc_embedder.append(rnn_embedder)
|
||||||
|
|
@ -95,7 +95,6 @@ if __name__ == '__main__':
|
||||||
"""
|
"""
|
||||||
View generator (-B): generates document embedding via mBERT model.
|
View generator (-B): generates document embedding via mBERT model.
|
||||||
"""
|
"""
|
||||||
op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0' # TODO DEBUG
|
|
||||||
mbert = MBertEmbedder(path_to_model=op.bert_path,
|
mbert = MBertEmbedder(path_to_model=op.bert_path,
|
||||||
nC=data.num_categories())
|
nC=data.num_categories())
|
||||||
if op.allprob:
|
if op.allprob:
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ parser.add_option('-G', dest='gruViewGenerator', action='store_true',
|
||||||
|
|
||||||
parser.add_option("--l2", dest="l2", action='store_true',
|
parser.add_option("--l2", dest="l2", action='store_true',
|
||||||
help="Activates l2 normalization as a post-processing for the document embedding views",
|
help="Activates l2 normalization as a post-processing for the document embedding views",
|
||||||
default=False)
|
default=True)
|
||||||
|
|
||||||
parser.add_option("--allprob", dest="allprob", action='store_true',
|
parser.add_option("--allprob", dest="allprob", action='store_true',
|
||||||
help="All views are generated as posterior probabilities. This affects the supervised and pretrained"
|
help="All views are generated as posterior probabilities. This affects the supervised and pretrained"
|
||||||
|
|
@ -51,10 +51,10 @@ parser.add_option("-p", "--pca", dest="max_labels_S", type=int,
|
||||||
default=300)
|
default=300)
|
||||||
|
|
||||||
parser.add_option("-r", "--remove-pc", dest="sif", action='store_true',
|
parser.add_option("-r", "--remove-pc", dest="sif", action='store_true',
|
||||||
help="Remove common component when computing dot product of word embedding matrices", default=False)
|
help="Remove common component when computing dot product of word embedding matrices", default=True)
|
||||||
|
|
||||||
parser.add_option("-z", "--zscore", dest="zscore", action='store_true',
|
parser.add_option("-z", "--zscore", dest="zscore", action='store_true',
|
||||||
help="Z-score normalize matrices (WCE and MUSE)", default=False)
|
help="Z-score normalize matrices (WCE and MUSE)", default=True)
|
||||||
|
|
||||||
parser.add_option("-a", "--agg", dest="agg", action='store_true',
|
parser.add_option("-a", "--agg", dest="agg", action='store_true',
|
||||||
help="Set aggregation function of the common Z-space to average (Default: concatenation)",
|
help="Set aggregation function of the common Z-space to average (Default: concatenation)",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue