diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py index b1f01cd..e6527cd 100644 --- a/gfun/vgfs/commons.py +++ b/gfun/vgfs/commons.py @@ -243,9 +243,12 @@ class Trainer: print( f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}" ) - self.model = self.earlystopping.load_model(self.model).to( - self.device - ) + restored_model = self.earlystopping.load_model(self.model) + + # swapping model on gpu + del self.model + self.model = restored_model.to(self.device) + break if self.scheduler is not None: