implemented zero-shot experiment code for VanillaFunGen and WordClassGen

parent c65c91fc27
commit 7affa1fab4

main.py (31 lines changed)
@@ -11,17 +11,19 @@ from src.view_generators import *
 def main(args):
     assert args.post_embedder or args.muse_embedder or args.wce_embedder or args.gru_embedder or args.bert_embedder, \
         'empty set of document embeddings is not allowed!'
+    assert not (args.zero_shot and (args.zscl_langs is None)), \
+        '--zscl_langs cannot be empty when setting --zero_shot to True'
 
     print('Running generalized funnelling...')
 
     data = MultilingualDataset.load(args.dataset)
-    data.set_view(languages=['it', 'da', 'nl'])
+    data.set_view(languages=['nl'])
     data.show_dimensions()
     lX, ly = data.training()
     lXte, lyte = data.test()
 
-    zero_shot = True
-    zscl_train_langs = ['it']  # Todo: testing zero shot
+    zero_shot = args.zero_shot
+    zscl_train_langs = args.zscl_langs
 
     # Init multilingualIndex - mandatory when deploying Neural View Generators...
     if args.gru_embedder or args.bert_embedder:
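
Note: the new assert rejects --zero_shot runs that omit --zscl_langs, so the training-language list can no longer be left implicit. The same guard could be expressed at the argparse level (a hypothetical variant, not part of this commit), which reports a usage error instead of a traceback:

    # hypothetical alternative: fail at parse time rather than with an assert
    args = parser.parse_args()
    if args.zero_shot and args.zscl_langs is None:
        parser.error('--zscl_langs cannot be empty when setting --zero_shot')
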
@@ -37,24 +39,24 @@ def main(args):
 
     if args.muse_embedder:
         museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs,
-                               zero_shot=zero_shot, train_langs=zscl_train_langs)  # Todo: testing zero shot
+                               zero_shot=zero_shot, train_langs=zscl_train_langs)
         embedder_list.append(museEmbedder)
 
     if args.wce_embedder:
         wceEmbedder = WordClassGen(n_jobs=args.n_jobs,
-                                   zero_shot=zero_shot, train_langs=zscl_train_langs)  # Todo: testing zero shot
+                                   zero_shot=zero_shot, train_langs=zscl_train_langs)
         embedder_list.append(wceEmbedder)
 
     if args.gru_embedder:
         rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce,
                                    batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn,
-                                   zero_shot=zero_shot, train_langs=zscl_train_langs,  # Todo: testing zero shot
+                                   zero_shot=zero_shot, train_langs=zscl_train_langs,
                                    gpus=args.gpus, n_jobs=args.n_jobs)
         embedder_list.append(rnnEmbedder)
 
     if args.bert_embedder:
         bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
-                               zero_shot=zero_shot, train_langs=zscl_train_langs,  # Todo: testing zero shot
+                               zero_shot=zero_shot, train_langs=zscl_train_langs,
                                patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
         embedder_list.append(bertEmbedder)
 
@@ -76,7 +78,7 @@ def main(args):
     # Testing ----------------------------------------
     print('\n[Testing Generalized Funnelling]')
     time_te = time.time()
-    # TODO: Zero shot scenario -> setting first tier learners zero_shot param to False
+    if args.zero_shot:
         gfun.set_zero_shot(val=False)
     ly_ = gfun.predict(lXte)
     l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
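
Note: before prediction, the first-tier learners have their zero_shot flag switched off, so inference runs over every test language, including those withheld from training. Roughly (a sketch, assuming gfun was fit earlier in main as usual):

    # train the first tier on the subset given by --zscl_langs, then test on all languages
    gfun.fit(lX, ly)                   # first tier only sees the training languages
    if args.zero_shot:
        gfun.set_zero_shot(val=False)  # re-enable all languages for inference
    ly_ = gfun.predict(lXte)           # predictions now cover the unseen languages too
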
@@ -85,7 +87,7 @@ def main(args):
 
     # Logging ---------------------------------------
     print('\n[Results]')
-    results = CSVlog(args.csv_dir)
+    results = CSVlog(f'csv_logs/gfun/{args.csv_dir}')
     metrics = []
     for lang in lXte.keys():
         macrof1, microf1, macrok, microk = l_eval[lang]
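
Note: the result file is now a bare filename that gets prefixed with a fixed directory at logging time (the -o default changes to match in the argparse hunk below). A hypothetical variant of the same composition, not in this commit, would build the path explicitly:

    import os
    csv_path = os.path.join('csv_logs', 'gfun', args.csv_dir)  # e.g. csv_logs/gfun/gfun_results.csv
    results = CSVlog(csv_path)
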
@@ -120,8 +122,8 @@ if __name__ == '__main__':
     parser.add_argument('dataset', help='Path to the dataset')
 
     parser.add_argument('-o', '--output', dest='csv_dir', metavar='',
-                        help='Result file (default csv_logs/gfun/gfun_results.csv)', type=str,
-                        default='csv_logs/gfun/gfun_results.csv')
+                        help='Result file saved in csv_logs/gfun/dir, default is gfun_results.csv)', type=str,
+                        default='gfun_results.csv')
 
     parser.add_argument('-x', '--post_embedder', dest='post_embedder', action='store_true',
                         help='deploy posterior probabilities embedder to compute document embeddings',
@@ -194,5 +196,12 @@ if __name__ == '__main__':
     parser.add_argument('--gpus', metavar='', help='specifies how many GPUs to use per node',
                         default=None)
 
+    parser.add_argument('--zero_shot', dest='zero_shot', action='store_true',
+                        help='run zero-shot experiments',
+                        default=False)
+
+    parser.add_argument('--zscl_langs', dest='zscl_langs', metavar='', nargs='*',
+                        help='set the languages to be used in training in zero shot experiments')
+
     args = parser.parse_args()
     main(args)
run.sh (11 lines changed)
@@ -2,7 +2,16 @@
 
 echo Running Zero-shot experiments [output at csv_logs/gfun/zero_shot_gfun.csv]
 
-python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -o csv_logs/gfun/zero_shot_gfun.csv --gpus 0
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 6
 
 #for i in {0..10..1}
 #do
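
Note: the new run.sh sweeps the zero-shot training set over incrementally larger language prefixes, from Danish alone up to all nine languages, appending every run to the same CSV. A small Python sketch that would generate the same nine commands:

    langs = ['da', 'de', 'en', 'es', 'fr', 'it', 'nl', 'pt', 'sv']
    dataset = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
    for i in range(1, len(langs) + 1):
        train = ' '.join(langs[:i])  # grow the training set one language at a time
        print(f'python main.py {dataset} -m -w -o csv_logs/gfun/zero_shot_gfun.csv '
              f'--zero_shot --zscl_langs {train} --n_jobs 6')
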
@@ -128,6 +128,9 @@ class Funnelling:
 
     def set_zero_shot(self, val: bool):
         for embedder in self.first_tier.embedders:
+            if isinstance(embedder, VanillaFunGen):
+                embedder.set_zero_shot(val)
+            else:
                 embedder.embedder.set_zero_shot(val)
         return
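
Note: VanillaFunGen exposes set_zero_shot on the object itself, whereas the other view generators carry the flag on a wrapped .embedder attribute, hence the isinstance dispatch. A duck-typed alternative (hypothetical, assuming only VanillaFunGen defines the method on the outer object) would avoid hard-coding the type:

    def set_zero_shot(self, val: bool):
        for embedder in self.first_tier.embedders:
            # fall back to the wrapped embedder when the outer object lacks the method
            target = embedder if hasattr(embedder, 'set_zero_shot') else embedder.embedder
            target.set_zero_shot(val)
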
@@ -56,7 +56,7 @@ class VanillaFunGen(ViewGen):
     View Generator (x): original funnelling architecture proposed by Moreo, Esuli and
     Sebastiani in DOI: https://doi.org/10.1145/3326065
     """
-    def __init__(self, base_learner, first_tier_parameters=None, n_jobs=-1):
+    def __init__(self, base_learner, first_tier_parameters=None, zero_shot=False, train_langs: list = None, n_jobs=-1):
        """
        Init Posterior Probabilities embedder (i.e., VanillaFunGen)
        :param base_learner: naive monolingual learners to be deployed as first-tier learners. Should be able to
@@ -71,9 +71,19 @@ class VanillaFunGen(ViewGen):
         self.doc_projector = NaivePolylingualClassifier(base_learner=self.learners,
                                                         parameters=self.first_tier_parameters, n_jobs=self.n_jobs)
         self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
+        # Zero shot parameters
+        self.zero_shot = zero_shot
+        if train_langs is None:
+            train_langs = ['it']
+        self.train_langs = train_langs
 
     def fit(self, lX, lY):
         print('# Fitting VanillaFunGen (X)...')
+        if self.zero_shot:
+            self.langs = sorted(self.train_langs)
+            lX = self.zero_shot_experiments(lX)
+            lX = self.vectorizer.fit_transform(lX)
+        else:
             lX = self.vectorizer.fit_transform(lX)
         self.doc_projector.fit(lX, lY)
         return self
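
Note: the two fit branches differ only in the masking step; the TF-IDF vectorizer is fit either way. The same logic could be condensed (a hypothetical refactor, not in this commit):

    def fit(self, lX, lY):
        print('# Fitting VanillaFunGen (X)...')
        if self.zero_shot:
            self.langs = sorted(self.train_langs)
            lX = self.zero_shot_experiments(lX)  # restrict to the training languages
        lX = self.vectorizer.fit_transform(lX)
        self.doc_projector.fit(lX, lY)
        return self
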
@@ -93,9 +103,19 @@ class VanillaFunGen(ViewGen):
     def fit_transform(self, lX, ly):
         return self.fit(lX, ly).transform(lX)
 
+    def zero_shot_experiments(self, lX):
+        print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
+        _lX = {}
+        for lang in self.langs:
+            if lang in self.train_langs:
+                _lX[lang] = lX[lang]
+            else:
+                _lX[lang] = None
+        lX = _lX
+        return lX
+
     def set_zero_shot(self, val: bool):
         self.zero_shot = val
-        print('# TODO: PosteriorsGen has not been configured for zero-shot experiments')
         return
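
Note: since fit sets self.langs = sorted(self.train_langs) before calling zero_shot_experiments, the loop effectively subsets lX to the training languages; the None branch only fires when self.langs is wider than train_langs. A toy illustration of both cases (data invented for the example):

    train_langs = ['it']
    lX = {'da': ['dok'], 'it': ['documento'], 'nl': ['doc']}

    # as called from fit: langs == train_langs, so the result is a plain subset
    langs = sorted(train_langs)
    masked = {lang: (lX[lang] if lang in train_langs else None) for lang in langs}
    assert masked == {'it': ['documento']}

    # with a wider language list, languages outside the training set map to None
    langs = sorted(lX.keys())
    masked = {lang: (lX[lang] if lang in train_langs else None) for lang in langs}
    assert masked == {'da': None, 'it': ['documento'], 'nl': None}
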
@@ -205,8 +225,14 @@ class WordClassGen(ViewGen):
         :return: self.
         """
         print('# Fitting WordClassGen (W)...')
+        if self.zero_shot:
+            self.langs = sorted(self.train_langs)
+            lX = self.zero_shot_experiments(lX)
+            lX = self.vectorizer.fit_transform(lX)
+        else:
             lX = self.vectorizer.fit_transform(lX)
         self.langs = sorted(lX.keys())
 
         wce = Parallel(n_jobs=self.n_jobs)(
             delayed(wce_matrix)(lX[lang], ly[lang]) for lang in self.langs)
         self.lWce = {l: wce[i] for i, l in enumerate(self.langs)}
@@ -220,15 +246,10 @@ class WordClassGen(ViewGen):
         :param lX: dict {lang: indexed documents}
         :return: document projection to the common latent space.
         """
-        # Testing zero-shot experiments
-        if self.zero_shot:
-            lX = self.zero_shot_experiments(lX)
-            lX = {l: self.vectorizer.vectorizer[l].transform(lX[l]) for l in self.langs if lX[l] is not None}
-        else:
         lX = self.vectorizer.transform(lX)
         XdotWce = Parallel(n_jobs=self.n_jobs)(
-            delayed(XdotM)(lX[lang], self.lWce[lang], sif=True) for lang in sorted(lX.keys()))
-        lWce = {l: XdotWce[i] for i, l in enumerate(sorted(lX.keys()))}
+            delayed(XdotM)(lX[lang], self.lWce[lang], sif=True) for lang in sorted(lX.keys()) if lang in self.lWce.keys())
+        lWce = {l: XdotWce[i] for i, l in enumerate(sorted(lX.keys())) if l in self.lWce.keys()}
         lWce = _normalize(lWce, l2=True)
         return lWce
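
Note: a caveat worth flagging here: XdotWce is a filtered list, but the dict comprehension enumerates the unfiltered sorted(lX.keys()) and filters afterwards, so the index i can run past the end of the list whenever lX contains languages missing from self.lWce (e.g. at test time after set_zero_shot(False)). Filtering before enumerating keeps the indices aligned (a suggested fix, not part of this commit):

    kept = [lang for lang in sorted(lX.keys()) if lang in self.lWce]
    XdotWce = Parallel(n_jobs=self.n_jobs)(
        delayed(XdotM)(lX[lang], self.lWce[lang], sif=True) for lang in kept)
    lWce = {l: XdotWce[i] for i, l in enumerate(kept)}  # i now indexes kept, not lX
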
@@ -339,7 +360,7 @@ class RecurrentGen(ViewGen):
         print('# Fitting RecurrentGen (G)...')
         create_if_not_exist(self.logger.save_dir)
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs,
-                                                  zero_shot=self.zero_shot, zscl_langs=self.train_langs)  # Todo: zero shot settings
+                                                  zero_shot=self.zero_shot, zscl_langs=self.train_langs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
                           callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False)
@@ -350,7 +371,7 @@ class RecurrentGen(ViewGen):
         # self.model.linear2 = vanilla_torch_model.linear2
         # self.model.rnn = vanilla_torch_model.rnn
 
-        if self.zero_shot:  # Todo: zero shot experiment setting
+        if self.zero_shot:
             print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
 
         trainer.fit(self.model, datamodule=recurrentDataModule)
@@ -451,7 +472,7 @@ class BertGen(ViewGen):
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
                                         zero_shot=self.zero_shot, zscl_langs=self.train_langs)
 
-        if self.zero_shot:  # Todo: zero shot experiment setting
+        if self.zero_shot:
             print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
 
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,