setting up zero-shot experiments (implemented for Recurrent and Bert but not tested)
This commit is contained in:
parent
ee98c5f610
commit
6361a4eba0
|
|
@ -383,7 +383,6 @@ class RecurrentGen(ViewGen):
|
||||||
|
|
||||||
def set_zero_shot(self, val: bool):
|
def set_zero_shot(self, val: bool):
|
||||||
self.zero_shot = val
|
self.zero_shot = val
|
||||||
print('# TODO: RecurrentGen has not been configured for zero-shot experiments')
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -393,7 +392,8 @@ class BertGen(ViewGen):
|
||||||
At inference time, the model returns the network internal state at the last original layer (i.e. 12th). Document
|
At inference time, the model returns the network internal state at the last original layer (i.e. 12th). Document
|
||||||
embeddings are the state associated with the "start" token. Training metrics are logged via TensorBoard.
|
embeddings are the state associated with the "start" token. Training metrics are logged via TensorBoard.
|
||||||
"""
|
"""
|
||||||
def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None):
|
def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None,
|
||||||
|
zero_shot=False, train_langs: list = None):
|
||||||
"""
|
"""
|
||||||
Init Bert model
|
Init Bert model
|
||||||
:param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
|
:param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
|
||||||
|
|
@ -418,6 +418,12 @@ class BertGen(ViewGen):
|
||||||
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
||||||
patience=self.patience, verbose=False, mode='max')
|
patience=self.patience, verbose=False, mode='max')
|
||||||
|
|
||||||
|
# Zero shot parameters
|
||||||
|
self.zero_shot = zero_shot
|
||||||
|
if train_langs is None:
|
||||||
|
train_langs = ['it']
|
||||||
|
self.train_langs = train_langs
|
||||||
|
|
||||||
def _init_model(self):
|
def _init_model(self):
|
||||||
output_size = self.multilingualIndex.get_target_dim()
|
output_size = self.multilingualIndex.get_target_dim()
|
||||||
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
|
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
|
||||||
|
|
@ -434,7 +440,12 @@ class BertGen(ViewGen):
|
||||||
print('# Fitting BertGen (M)...')
|
print('# Fitting BertGen (M)...')
|
||||||
create_if_not_exist(self.logger.save_dir)
|
create_if_not_exist(self.logger.save_dir)
|
||||||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||||
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
|
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
|
||||||
|
zero_shot=self.zero_shot, zscl_langs=self.train_langs)
|
||||||
|
|
||||||
|
if self.zero_shot: # Todo: zero shot experiment setting
|
||||||
|
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
||||||
|
|
||||||
trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
|
trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
|
||||||
logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
|
logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
|
||||||
trainer.fit(self.model, datamodule=bertDataModule)
|
trainer.fit(self.model, datamodule=bertDataModule)
|
||||||
|
|
@ -459,5 +470,4 @@ class BertGen(ViewGen):
|
||||||
|
|
||||||
def set_zero_shot(self, val: bool):
|
def set_zero_shot(self, val: bool):
|
||||||
self.zero_shot = val
|
self.zero_shot = val
|
||||||
print('# TODO: BertGen has not been configured for zero-shot experiments')
|
|
||||||
return
|
return
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue