Setting up zero-shot experiments (implemented for RecurrentGen and BertGen, but not tested)

andrea 2021-02-02 16:12:08 +01:00
parent ee98c5f610
commit 6361a4eba0
1 changed file with 14 additions and 4 deletions


@@ -383,7 +383,6 @@ class RecurrentGen(ViewGen)
def set_zero_shot(self, val: bool):
self.zero_shot = val
print('# TODO: RecurrentGen has not been configured for zero-shot experiments')
return
@@ -393,7 +392,8 @@ class BertGen(ViewGen)
At inference time, the model returns the network's internal state at the last original layer (i.e., the 12th). Document
embeddings are the state associated with the "start" token. Training metrics are logged via TensorBoard.
"""
def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None):
def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None,
zero_shot=False, train_langs: list = None):
"""
Init Bert model
:param multilingualIndex: MultilingualIndex, a dictionary of training and test documents
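
The docstring above describes document embeddings as the last-layer state of the "start" token. For reference, a minimal sketch of that extraction with Hugging Face transformers follows; the checkpoint name and the embed_documents helper are illustrative assumptions, not the repository's own BertModel wrapper or data pipeline.

# Sketch only: document embeddings taken as the state of the first ("start"/[CLS])
# token at the last (12th) encoder layer. Checkpoint and helper name are illustrative.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
encoder = AutoModel.from_pretrained('bert-base-multilingual-cased')

def embed_documents(docs, max_len=512):
    batch = tokenizer(docs, padding=True, truncation=True,
                      max_length=max_len, return_tensors='pt')
    with torch.no_grad():
        out = encoder(**batch)
    return out.last_hidden_state[:, 0, :]  # [batch, hidden]: start-token states
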
@@ -418,6 +418,12 @@
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
patience=self.patience, verbose=False, mode='max')
# Zero shot parameters
self.zero_shot = zero_shot
if train_langs is None:
train_langs = ['it']
self.train_langs = train_langs
def _init_model(self):
output_size = self.multilingualIndex.get_target_dim()
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
@@ -434,7 +440,12 @@
print('# Fitting BertGen (M)...')
create_if_not_exist(self.logger.save_dir)
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
zero_shot=self.zero_shot, zscl_langs=self.train_langs)
if self.zero_shot:  # TODO: zero-shot experiment setting
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
trainer.fit(self.model, datamodule=bertDataModule)
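
The zero_shot and zscl_langs arguments forwarded to BertDataModule above imply that, in the zero-shot setting, only the listed source languages contribute training documents, while the remaining languages are seen only at test time. A hedged sketch of that selection step, written as a standalone helper rather than BertDataModule's real internals (which this diff does not show):

# Sketch of the language filtering that zero_shot / zscl_langs suggest.
def select_training_languages(docs_by_lang: dict, zero_shot: bool, zscl_langs: list) -> dict:
    # When zero_shot is on, keep only the chosen source languages for training;
    # every other language then appears only at evaluation time ("zero-shot").
    if not zero_shot:
        return docs_by_lang
    return {lang: docs for lang, docs in docs_by_lang.items() if lang in zscl_langs}
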
@@ -459,5 +470,4 @@ class BertGen(ViewGen)
def set_zero_shot(self, val: bool):
self.zero_shot = val
print('# TODO: BertGen has not been configured for zero-shot experiments')
return
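
Taken together, the new constructor arguments make a zero-shot run configurable at instantiation time. A usage sketch, assuming an already-built MultilingualIndex (index construction and the rest of the pipeline are omitted, and the language list is only an example):

# Usage sketch: multilingual_index is assumed to be an existing MultilingualIndex.
bert_gen = BertGen(multilingual_index, batch_size=64, nepochs=50, gpus=1,
                   zero_shot=True, train_langs=['it', 'en'])
# The flag can also be toggled after construction:
bert_gen.set_zero_shot(True)
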