Setting up zero-shot experiments (implemented for RecurrentGen and BertGen, not yet tested)
This commit is contained in:
parent 5821325c86
commit ee98c5f610

main.py
@@ -48,6 +48,7 @@ def main(args):
     if args.gru_embedder:
         rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce,
                                    batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn,
+                                   zero_shot=zero_shot, train_langs=zscl_train_langs,  # Todo: testing zero shot
                                    gpus=args.gpus, n_jobs=args.n_jobs)
         embedder_list.append(rnnEmbedder)
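Note that `zero_shot` and `zscl_train_langs` are not defined in this hunk; they are presumably derived from the command line elsewhere in main.py. A minimal sketch of how they might be wired up (the flag names here are hypothetical, not part of this commit):

    import argparse

    # Hypothetical argparse wiring; flag names are illustrative only.
    parser = argparse.ArgumentParser()
    parser.add_argument('--zero_shot', action='store_true',
                        help='train the view generators on a subset of languages only')
    parser.add_argument('--zscl_train_langs', nargs='+', default=['it'],
                        help='languages kept for training in the zero-shot setting')
    args = parser.parse_args()
    zero_shot, zscl_train_langs = args.zero_shot, args.zscl_train_langs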
@@ -92,7 +92,7 @@ class RecurrentDataModule(pl.LightningDataModule):
     Pytorch Lightning Datamodule to be deployed with RecurrentGen.
     https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
     """
-    def __init__(self, multilingualIndex, batchsize=64, n_jobs=-1):
+    def __init__(self, multilingualIndex, batchsize=64, n_jobs=-1, zero_shot=False, zscl_langs=None):
         """
         Init RecurrentDataModule.
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
@@ -103,6 +103,11 @@ class RecurrentDataModule(pl.LightningDataModule):
         self.multilingualIndex = multilingualIndex
         self.batchsize = batchsize
         self.n_jobs = n_jobs
+        # Zero shot arguments
+        if zscl_langs is None:
+            zscl_langs = []
+        self.zero_shot = zero_shot
+        self.train_langs = zscl_langs
         super().__init__()
 
     def prepare_data(self, *args, **kwargs):
@@ -110,7 +115,10 @@ class RecurrentDataModule(pl.LightningDataModule):
 
     def setup(self, stage=None):
         if stage == 'fit' or stage is None:
-            l_train_index, l_train_target = self.multilingualIndex.l_train()
+            if self.zero_shot:
+                l_train_index, l_train_target = self.multilingualIndex.l_train_zero_shot(langs=self.train_langs)
+            else:
+                l_train_index, l_train_target = self.multilingualIndex.l_train()
             # Debug settings: reducing number of samples
             # l_train_index = {l: train[:5] for l, train in l_train_index.items()}
             # l_train_target = {l: target[:5] for l, target in l_train_target.items()}
@@ -118,7 +126,10 @@ class RecurrentDataModule(pl.LightningDataModule):
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                      lPad_index=self.multilingualIndex.l_pad())
 
-            l_val_index, l_val_target = self.multilingualIndex.l_val()
+            if self.zero_shot:
+                l_val_index, l_val_target = self.multilingualIndex.l_val_zero_shot(langs=self.train_langs)
+            else:
+                l_val_index, l_val_target = self.multilingualIndex.l_val()
             # Debug settings: reducing number of samples
             # l_val_index = {l: train[:5] for l, train in l_val_index.items()}
             # l_val_target = {l: target[:5] for l, target in l_val_target.items()}
@@ -126,7 +137,10 @@ class RecurrentDataModule(pl.LightningDataModule):
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
         if stage == 'test' or stage is None:
-            l_test_index, l_test_target = self.multilingualIndex.l_test()
+            if self.zero_shot:
+                l_test_index, l_test_target = self.multilingualIndex.l_test_zero_shot(langs=self.train_langs)
+            else:
+                l_test_index, l_test_target = self.multilingualIndex.l_test()
             # Debug settings: reducing number of samples
             # l_test_index = {l: train[:5] for l, train in l_test_index.items()}
             # l_test_target = {l: target[:5] for l, target in l_test_target.items()}
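A minimal sketch of driving the datamodule in zero-shot mode (assuming a MultilingualIndex has already been built; the language codes are illustrative):

    dm = RecurrentDataModule(multilingualIndex, batchsize=64, n_jobs=-1,
                             zero_shot=True, zscl_langs=['it', 'fr'])
    dm.setup('fit')   # training and validation sets are built from 'it' and 'fr' only
    dm.setup('test')  # note: the test set is filtered to the same languages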
@@ -167,7 +181,7 @@ class BertDataModule(RecurrentDataModule):
     Pytorch Lightning Datamodule to be deployed with BertGen.
     https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
     """
-    def __init__(self, multilingualIndex, batchsize=64, max_len=512):
+    def __init__(self, multilingualIndex, batchsize=64, max_len=512, zero_shot=False, zscl_langs=None):
         """
         Init BertDataModule.
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
@@ -177,10 +191,18 @@ class BertDataModule(RecurrentDataModule):
         """
         super().__init__(multilingualIndex, batchsize)
         self.max_len = max_len
+        # Zero shot arguments
+        if zscl_langs is None:
+            zscl_langs = []
+        self.zero_shot = zero_shot
+        self.train_langs = zscl_langs
 
     def setup(self, stage=None):
         if stage == 'fit' or stage is None:
-            l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
+            if self.zero_shot:
+                l_train_raw, l_train_target = self.multilingualIndex.l_train_raw_zero_shot(langs=self.train_langs)  # todo: check this!
+            else:
+                l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
             # Debug settings: reducing number of samples
             # l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
             # l_train_target = {l: target[:5] for l, target in l_train_target.items()}
@@ -189,7 +211,10 @@ class BertDataModule(RecurrentDataModule):
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                      lPad_index=self.multilingualIndex.l_pad())
 
-            l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
+            if self.zero_shot:
+                l_val_raw, l_val_target = self.multilingualIndex.l_val_raw_zero_shot(langs=self.train_langs)  # todo: check this!
+            else:
+                l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
             # Debug settings: reducing number of samples
             # l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
             # l_val_target = {l: target[:5] for l, target in l_val_target.items()}
@@ -199,7 +224,10 @@ class BertDataModule(RecurrentDataModule):
                                                 lPad_index=self.multilingualIndex.l_pad())
 
         if stage == 'test' or stage is None:
-            l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
+            if self.zero_shot:
+                l_test_raw, l_test_target = self.multilingualIndex.l_test_raw_zero_shot(langs=self.train_langs)  # todo: check this!
+            else:
+                l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
             # Debug settings: reducing number of samples
             # l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
             # l_test_target = {l: target[:5] for l, target in l_test_target.items()}
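BertDataModule re-declares the zero-shot fields instead of forwarding them to its parent class. A sketch of the equivalent forwarding form (an alternative, not what this commit does):

    super().__init__(multilingualIndex, batchsize, zero_shot=zero_shot, zscl_langs=zscl_langs)

Either form yields the same attributes, since the parent's defaults are simply overwritten by the subclass assignments.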
@@ -149,33 +149,60 @@ class MultilingualIndex:
     def l_train_index(self):
         return {l: index.train_index for l, index in self.l_index.items()}
 
+    def l_train_index_zero_shot(self, langs):
+        return {l: index.train_index for l, index in self.l_index.items() if l in langs}
+
     def l_train_raw_index(self):
         return {l: index.train_raw for l, index in self.l_index.items()}
 
+    def l_train_raw_index_zero_shot(self, langs):
+        return {l: index.train_raw for l, index in self.l_index.items() if l in langs}
+
     def l_train_target(self):
         return {l: index.train_target for l, index in self.l_index.items()}
 
+    def l_train_target_zero_shot(self, langs):
+        return {l: index.train_target for l, index in self.l_index.items() if l in langs}
+
     def l_val_index(self):
         return {l: index.val_index for l, index in self.l_index.items()}
 
+    def l_val_index_zero_shot(self, langs):
+        return {l: index.val_index for l, index in self.l_index.items() if l in langs}
+
     def l_val_raw_index(self):
         return {l: index.val_raw for l, index in self.l_index.items()}
 
+    def l_val_raw_index_zero_shot(self, langs):
+        return {l: index.val_raw for l, index in self.l_index.items() if l in langs}
+
     def l_test_raw_index(self):
         return {l: index.test_raw for l, index in self.l_index.items()}
 
+    def l_test_raw_index_zero_shot(self, langs):
+        return {l: index.test_raw for l, index in self.l_index.items() if l in langs}
+
     def l_devel_raw_index(self):
         return {l: index.devel_raw for l, index in self.l_index.items()}
 
     def l_val_target(self):
         return {l: index.val_target for l, index in self.l_index.items()}
 
+    def l_val_target_zero_shot(self, langs):
+        return {l: index.val_target for l, index in self.l_index.items() if l in langs}
+
     def l_test_target(self):
         return {l: index.test_target for l, index in self.l_index.items()}
 
     def l_test_index(self):
         return {l: index.test_index for l, index in self.l_index.items()}
 
+    def l_test_target_zero_shot(self, langs):
+        return {l: index.test_target for l, index in self.l_index.items() if l in langs}
+
+    def l_test_index_zero_shot(self, langs):
+        return {l: index.test_index for l, index in self.l_index.items() if l in langs}
+
     def l_devel_index(self):
         return {l: index.devel_index for l, index in self.l_index.items()}
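All of these *_zero_shot accessors apply the same filter: keep only the per-language entries whose key is in `langs`. A sketch of a single generic helper on MultilingualIndex that could express all of them (hypothetical, not part of this commit):

    def _l_filtered(self, attr, langs):
        # Restrict a per-language view of the index to the given languages.
        return {l: getattr(index, attr) for l, index in self.l_index.items() if l in langs}

    # e.g. self.l_train_index_zero_shot(langs) == self._l_filtered('train_index', langs)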
@@ -191,15 +218,33 @@ class MultilingualIndex:
     def l_test(self):
         return self.l_test_index(), self.l_test_target()
 
+    def l_test_zero_shot(self, langs):
+        return self.l_test_index_zero_shot(langs), self.l_test_target_zero_shot(langs)
+
+    def l_train_zero_shot(self, langs):
+        return self.l_train_index_zero_shot(langs), self.l_train_target_zero_shot(langs)
+
+    def l_val_zero_shot(self, langs):
+        return self.l_val_index_zero_shot(langs), self.l_val_target_zero_shot(langs)
+
     def l_train_raw(self):
         return self.l_train_raw_index(), self.l_train_target()
 
+    def l_train_raw_zero_shot(self, langs):
+        return self.l_train_raw_index_zero_shot(langs), self.l_train_target_zero_shot(langs)
+
     def l_val_raw(self):
         return self.l_val_raw_index(), self.l_val_target()
 
+    def l_val_raw_zero_shot(self, langs):
+        return self.l_val_raw_index_zero_shot(langs), self.l_val_target_zero_shot(langs)
+
     def l_test_raw(self):
         return self.l_test_raw_index(), self.l_test_target()
 
+    def l_test_raw_zero_shot(self, langs):
+        return self.l_test_raw_index_zero_shot(langs), self.l_test_target_zero_shot(langs)
+
     def l_devel_raw(self):
         return self.l_devel_raw_index(), self.l_devel_target()
@@ -259,7 +259,7 @@ class RecurrentGen(ViewGen):
     the network internal state at the second feed-forward layer level. Training metrics are logged via TensorBoard.
     """
     def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, nepochs=50,
-                 gpus=0, n_jobs=-1, patience=20, stored_path=None):
+                 gpus=0, n_jobs=-1, patience=20, stored_path=None, zero_shot=False, train_langs: list = None):
         """
         Init RecurrentGen.
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
@@ -298,6 +298,12 @@ class RecurrentGen(ViewGen):
                                                  patience=self.patience, verbose=False, mode='max')
         self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
 
+        # Zero shot parameters
+        self.zero_shot = zero_shot
+        if train_langs is None:
+            train_langs = ['it']
+        self.train_langs = train_langs
+
     def _init_model(self):
         if self.stored_path:
             lpretrained = self.multilingualIndex.l_embeddings()
@@ -332,7 +338,8 @@ class RecurrentGen(ViewGen):
         """
         print('# Fitting RecurrentGen (G)...')
         create_if_not_exist(self.logger.save_dir)
-        recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
+        recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs,
+                                                  zero_shot=self.zero_shot, zscl_langs=self.train_langs)  # Todo: zero shot settings
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
                           callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False)
@@ -343,6 +350,9 @@ class RecurrentGen(ViewGen):
         # self.model.linear2 = vanilla_torch_model.linear2
         # self.model.rnn = vanilla_torch_model.rnn
 
+        if self.zero_shot:  # Todo: zero shot experiment setting
+            print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
+
         trainer.fit(self.model, datamodule=recurrentDataModule)
         trainer.test(self.model, datamodule=recurrentDataModule)
         return self
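Taken together, a zero-shot run of the recurrent view generator would look roughly like this (a sketch: multilingualIndex and lMuse are assumed to be built as in main.py, and the language codes are illustrative):

    rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False,
                               batch_size=512, nepochs=50, patience=20, gpus=1, n_jobs=-1,
                               zero_shot=True, train_langs=['it', 'fr'])
    # fit() prints the zero-shot banner, builds the datamodule with
    # zero_shot=True / zscl_langs=['it', 'fr'], and then trains and tests
    # on those languages only (see the setup() branches above).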