logging dirs
This commit is contained in:
parent
e52b153ad4
commit
2c70f37823
|
|
@ -237,7 +237,7 @@ class RecurrentGen(ViewGen):
|
|||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||
self.multilingualIndex.embedding_matrices(self.pretrained, supervised=self.wce)
|
||||
self.model = self._init_model()
|
||||
self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False)
|
||||
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
|
||||
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
||||
patience=self.patience, verbose=False, mode='max')
|
||||
|
||||
|
|
@ -274,7 +274,7 @@ class RecurrentGen(ViewGen):
|
|||
:return: self.
|
||||
"""
|
||||
print('# Fitting RecurrentGen (G)...')
|
||||
create_if_not_exist('../tb_logs')
|
||||
create_if_not_exist(self.logger.save_dir)
|
||||
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
|
||||
trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
|
||||
callbacks=[self.early_stop_callback], checkpoint_callback=False)
|
||||
|
|
@ -342,7 +342,7 @@ class BertGen(ViewGen):
|
|||
self.stored_path = stored_path
|
||||
self.model = self._init_model()
|
||||
self.patience = patience
|
||||
self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
|
||||
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False)
|
||||
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
||||
patience=self.patience, verbose=False, mode='max')
|
||||
|
||||
|
|
@ -360,7 +360,7 @@ class BertGen(ViewGen):
|
|||
:return: self.
|
||||
"""
|
||||
print('# Fitting BertGen (M)...')
|
||||
create_if_not_exist('../tb_logs')
|
||||
create_if_not_exist(self.logger.save_dir)
|
||||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
|
||||
trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
|
||||
|
|
|
|||
Loading…
Reference in New Issue