diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py index 79796f5..7c8b494 100644 --- a/gfun/vgfs/commons.py +++ b/gfun/vgfs/commons.py @@ -246,7 +246,6 @@ class Trainer: # swapping model on gpu del self.model self.model = restored_model.to(self.device) - break if self.scheduler is not None: @@ -262,7 +261,14 @@ class Trainer: ) print(f"- last swipe on eval set") - self.train_epoch(eval_dataloader, epoch=-1) + self.train_epoch( + DataLoader( + eval_dataloader.dataset, + batch_size=train_dataloader.batch_size, + shuffle=True, + ), + epoch=-1, + ) self.earlystopping.save_model(self.model) return self.model @@ -341,6 +347,7 @@ class EarlyStopping: def __call__(self, validation, model, epoch): if validation >= self.best_score: + wandb.log({"patience": self.patience - self.counter}) if self.verbose: print( f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}" @@ -352,6 +359,7 @@ class EarlyStopping: self.save_model(model) elif validation < (self.best_score + self.min_delta): self.counter += 1 + wandb.log({"patience": self.patience - self.counter}) if self.verbose: print( f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"