setting up zero-shot experiments (implemented for RecurrentGen and BertGen, but not yet tested)

This commit is contained in:
andrea 2021-02-02 16:12:08 +01:00
parent ee98c5f610
commit 6361a4eba0
1 changed file with 14 additions and 4 deletions


@@ -383,7 +383,6 @@ class RecurrentGen(ViewGen):
     def set_zero_shot(self, val: bool):
         self.zero_shot = val
-        print('# TODO: RecurrentGen has not been configured for zero-shot experiments')
         return
@@ -393,7 +392,8 @@ class BertGen(ViewGen):
     At inference time, the model returns the network internal state at the last original layer (i.e. 12th). Document
     embeddings are the state associated with the "start" token. Training metrics are logged via TensorBoard.
     """
-    def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None):
+    def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None,
+                 zero_shot=False, train_langs: list = None):
         """
         Init Bert model
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
@@ -418,6 +418,12 @@ class BertGen(ViewGen):
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
                                                  patience=self.patience, verbose=False, mode='max')
+        # Zero shot parameters
+        self.zero_shot = zero_shot
+        if train_langs is None:
+            train_langs = ['it']
+        self.train_langs = train_langs
+
     def _init_model(self):
         output_size = self.multilingualIndex.get_target_dim()
         return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
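A note on the train_langs: list = None default added above: resolving it to ['it'] inside __init__ avoids Python's shared-mutable-default pitfall, where a list literal in the signature is created once at definition time and shared across every instance. A minimal sketch of the difference (illustrative class names, not from the commit):

    class Bad:
        def __init__(self, langs=[]):             # one list object, shared by all instances
            self.langs = langs

    class Good:
        def __init__(self, langs: list = None):   # fresh default resolved per instance
            self.langs = ['it'] if langs is None else langs

    a, b = Bad(), Bad()
    a.langs.append('it')
    print(b.langs)        # ['it'] -- b was mutated through a's shared default
    print(Good().langs)   # ['it']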
@@ -434,7 +440,12 @@ class BertGen(ViewGen):
         print('# Fitting BertGen (M)...')
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
-        bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
+        bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
+                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs)
+        if self.zero_shot:  # Todo: zero shot experiment setting
+            print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
                           logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
         trainer.fit(self.model, datamodule=bertDataModule)
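The data-module side of the zero-shot switch is not part of this diff; only the call site BertDataModule(..., zero_shot=..., zscl_langs=...) is visible above. One plausible reading, sketched under the assumption that zscl_langs names the languages kept for training (the helper name and the {language: documents} layout are hypothetical):

    def filter_train_langs(lX: dict, zero_shot: bool, zscl_langs: list) -> dict:
        # Hypothetical helper: when zero-shot is on, keep only the declared
        # training languages; evaluation data is left untouched.
        if not zero_shot:
            return lX
        return {lang: docs for lang, docs in lX.items() if lang in zscl_langs}

With zscl_langs=['it'], training would see Italian documents only, while the remaining languages stay held out for zero-shot evaluation.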
@@ -459,5 +470,4 @@
     def set_zero_shot(self, val: bool):
         self.zero_shot = val
-        print('# TODO: BertGen has not been configured for zero-shot experiments')
         return
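Taken together, a zero-shot run with the new parameters might look like the sketch below. Only the zero_shot and train_langs keywords come from this commit; the index construction and the fit call are assumptions about the surrounding API:

    # assumed: a MultilingualIndex built elsewhere in the pipeline
    bert_gen = BertGen(multilingualIndex, batch_size=128, nepochs=50, gpus=1,
                       zero_shot=True, train_langs=['it'])
    bert_gen.fit()                 # fits on the declared training languages only
    bert_gen.set_zero_shot(False)  # the flag can also be toggled after construction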