From 6361a4eba080dae0298401c7ac1b07098aff6955 Mon Sep 17 00:00:00 2001
From: andrea
Date: Tue, 2 Feb 2021 16:12:08 +0100
Subject: [PATCH] setting up zero-shot experiments (implemented for Recurrent and Bert but not tested)

---
 src/view_generators.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/view_generators.py b/src/view_generators.py
index f1e3eb6..a690e8f 100644
--- a/src/view_generators.py
+++ b/src/view_generators.py
@@ -383,7 +383,6 @@ class RecurrentGen(ViewGen):
 
     def set_zero_shot(self, val: bool):
         self.zero_shot = val
-        print('# TODO: RecurrentGen has not been configured for zero-shot experiments')
         return
 
 
@@ -393,7 +392,8 @@ class BertGen(ViewGen):
     At inference time, the model returns the network internal state at the last original layer (i.e. 12th). Document
     embeddings are the state associated with the "start" token. Training metrics are logged via TensorBoard.
     """
-    def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None):
+    def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None,
+                 zero_shot=False, train_langs: list = None):
         """
         Init Bert model
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
@@ -418,6 +418,12 @@ class BertGen(ViewGen):
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=self.patience,
                                                  verbose=False, mode='max')
 
+        # Zero shot parameters
+        self.zero_shot = zero_shot
+        if train_langs is None:
+            train_langs = ['it']
+        self.train_langs = train_langs
+
     def _init_model(self):
         output_size = self.multilingualIndex.get_target_dim()
         return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
@@ -434,7 +440,12 @@ class BertGen(ViewGen):
         print('# Fitting BertGen (M)...')
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
-        bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
+        bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
+                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs)
+
+        if self.zero_shot:  # Todo: zero shot experiment setting
+            print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
+
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
                           logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
         trainer.fit(self.model, datamodule=bertDataModule)
@@ -459,5 +470,4 @@ class BertGen(ViewGen):
 
     def set_zero_shot(self, val: bool):
         self.zero_shot = val
-        print('# TODO: BertGen has not been configured for zero-shot experiments')
         return