diff --git a/src/view_generators.py b/src/view_generators.py index f1e3eb6..a690e8f 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -383,7 +383,6 @@ class RecurrentGen(ViewGen): def set_zero_shot(self, val: bool): self.zero_shot = val - print('# TODO: RecurrentGen has not been configured for zero-shot experiments') return @@ -393,7 +392,8 @@ class BertGen(ViewGen): At inference time, the model returns the network internal state at the last original layer (i.e. 12th). Document embeddings are the state associated with the "start" token. Training metrics are logged via TensorBoard. """ - def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None): + def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None, + zero_shot=False, train_langs: list = None): """ Init Bert model :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents @@ -418,6 +418,12 @@ class BertGen(ViewGen): self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=self.patience, verbose=False, mode='max') + # Zero shot parameters + self.zero_shot = zero_shot + if train_langs is None: + train_langs = ['it'] + self.train_langs = train_langs + def _init_model(self): output_size = self.multilingualIndex.get_target_dim() return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus) @@ -434,7 +440,12 @@ class BertGen(ViewGen): print('# Fitting BertGen (M)...') create_if_not_exist(self.logger.save_dir) self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1) - bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512) + bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512, + zero_shot=self.zero_shot, zscl_langs=self.train_langs) + + if self.zero_shot: # Todo: zero shot experiment setting + print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') + trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus, logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False) trainer.fit(self.model, datamodule=bertDataModule) @@ -459,5 +470,4 @@ class BertGen(ViewGen): def set_zero_shot(self, val: bool): self.zero_shot = val - print('# TODO: BertGen has not been configured for zero-shot experiments') return