From 2c70f378239e522a830fa4b74116b38d545dac7a Mon Sep 17 00:00:00 2001 From: andrea Date: Fri, 29 Jan 2021 10:49:47 +0100 Subject: [PATCH] logging dirs --- src/view_generators.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/view_generators.py b/src/view_generators.py index fab56c7..27da0fc 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -237,7 +237,7 @@ class RecurrentGen(ViewGen): self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1) self.multilingualIndex.embedding_matrices(self.pretrained, supervised=self.wce) self.model = self._init_model() - self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False) + self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False) self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=self.patience, verbose=False, mode='max') @@ -274,7 +274,7 @@ class RecurrentGen(ViewGen): :return: self. """ print('# Fitting RecurrentGen (G)...') - create_if_not_exist('../tb_logs') + create_if_not_exist(self.logger.save_dir) recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs) trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs, callbacks=[self.early_stop_callback], checkpoint_callback=False) @@ -342,7 +342,7 @@ class BertGen(ViewGen): self.stored_path = stored_path self.model = self._init_model() self.patience = patience - self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False) + self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False) self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=self.patience, verbose=False, mode='max') @@ -360,7 +360,7 @@ class BertGen(ViewGen): :return: self. """ print('# Fitting BertGen (M)...') - create_if_not_exist('../tb_logs') + create_if_not_exist(self.logger.save_dir) self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1) bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512) trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,