logging dirs

This commit is contained in:
andrea 2021-01-29 10:49:47 +01:00
parent e52b153ad4
commit 2c70f37823
1 changed files with 4 additions and 4 deletions

View File

@ -237,7 +237,7 @@ class RecurrentGen(ViewGen):
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
self.multilingualIndex.embedding_matrices(self.pretrained, supervised=self.wce)
self.model = self._init_model()
self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False)
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
patience=self.patience, verbose=False, mode='max')
@ -274,7 +274,7 @@ class RecurrentGen(ViewGen):
:return: self.
"""
print('# Fitting RecurrentGen (G)...')
create_if_not_exist('../tb_logs')
create_if_not_exist(self.logger.save_dir)
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
callbacks=[self.early_stop_callback], checkpoint_callback=False)
@ -342,7 +342,7 @@ class BertGen(ViewGen):
self.stored_path = stored_path
self.model = self._init_model()
self.patience = patience
self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False)
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
patience=self.patience, verbose=False, mode='max')
@ -360,7 +360,7 @@ class BertGen(ViewGen):
:return: self.
"""
print('# Fitting BertGen (M)...')
create_if_not_exist('../tb_logs')
create_if_not_exist(self.logger.save_dir)
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,