Setting up zero-shot experiments (implemented for RecurrentGen and BertGen, but not tested)
parent ee98c5f610
commit 6361a4eba0
@@ -383,7 +383,6 @@ class RecurrentGen(ViewGen):
     def set_zero_shot(self, val: bool):
         self.zero_shot = val
         print('# TODO: RecurrentGen has not been configured for zero-shot experiments')
         return
@@ -393,7 +392,8 @@ class BertGen(ViewGen):
     At inference time, the model returns the network internal state at the last original layer (i.e. 12th). Document
     embeddings are the state associated with the "start" token. Training metrics are logged via TensorBoard.
     """
-    def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None):
+    def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None,
+                 zero_shot=False, train_langs: list = None):
         """
         Init Bert model
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
@@ -418,6 +418,12 @@ class BertGen(ViewGen):
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
                                                  patience=self.patience, verbose=False, mode='max')
+
+        # Zero shot parameters
+        self.zero_shot = zero_shot
+        if train_langs is None:
+            train_langs = ['it']
+        self.train_langs = train_langs

     def _init_model(self):
         output_size = self.multilingualIndex.get_target_dim()
         return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
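Usage sketch (not part of the diff): with the new keyword arguments, a zero-shot run is configured at construction time. The MultilingualIndex instance and the language codes below are placeholders; `zero_shot` and `train_langs` are the only arguments added by this commit.

# Hypothetical setup: `multilingual_index` is an already-built MultilingualIndex
# covering e.g. {'it', 'en', 'fr'}.
bert_gen = BertGen(multilingual_index,
                   batch_size=128,
                   nepochs=50,
                   gpus=1,
                   zero_shot=True,            # activate the zero-shot setting
                   train_langs=['it', 'en'])  # source languages; defaults to ['it'] when omitted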
@@ -434,7 +440,12 @@ class BertGen(ViewGen):
         print('# Fitting BertGen (M)...')
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
-        bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
+        bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
+                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs)
+
+        if self.zero_shot:  # Todo: zero shot experiment setting
+            print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
+
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
                           logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
         trainer.fit(self.model, datamodule=bertDataModule)
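The hunk above only forwards `zero_shot` and `zscl_langs` to BertDataModule; how the data module consumes them is not shown in this commit. A hedged sketch of one way it could honor the flags, assuming the training documents and labels are exposed as dictionaries keyed by language code (the function and variable names below are hypothetical, not the repository's API):

# Hedged sketch, not the repository's BertDataModule: restrict the training
# split to the zero-shot source languages, leaving validation/test data untouched.
def _filter_training_languages(l_train_docs: dict, l_train_labels: dict,
                               zero_shot: bool, zscl_langs: list):
    """Keep only the languages listed in `zscl_langs` when zero_shot is on."""
    if not zero_shot:
        return l_train_docs, l_train_labels
    keep = set(zscl_langs)
    docs = {lang: d for lang, d in l_train_docs.items() if lang in keep}
    labels = {lang: y for lang, y in l_train_labels.items() if lang in keep}
    return docs, labels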
@@ -459,5 +470,4 @@ class BertGen(ViewGen):

     def set_zero_shot(self, val: bool):
         self.zero_shot = val
         print('# TODO: BertGen has not been configured for zero-shot experiments')
         return
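Note (outside the diff): both view generators expose the same setter, so the flag can also be flipped after construction; as the printed messages state, neither zero-shot path has been tested yet. A minimal call on hypothetical instances:

recurrent_gen.set_zero_shot(True)  # sets self.zero_shot; still prints the TODO notice
bert_gen.set_zero_shot(True)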