From 3f3e4982e45d4e41903a24b6a049bc54f515910a Mon Sep 17 00:00:00 2001
From: andreapdr
Date: Fri, 10 Feb 2023 11:37:32 +0100
Subject: [PATCH] Model checkpointing during training; restore the best model
 if early stopping is triggered

---
 gfun/generalizedFunnelling.py      |  3 +++
 gfun/vgfs/commons.py               | 34 +++++++++++++++++++++++-------
 gfun/vgfs/textualTransformerGen.py |  6 +++++-
 gfun/vgfs/transformerGen.py        |  5 +++++
 gfun/vgfs/visualTransformerGen.py  |  3 +++
 main.py                            | 10 ++++++---
 6 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py
index 091e9d4..2da77fb 100644
--- a/gfun/generalizedFunnelling.py
+++ b/gfun/generalizedFunnelling.py
@@ -34,6 +34,7 @@ class GeneralizedFunnelling:
         optimc,
         device,
         load_trained,
+        dataset_name,
     ):
         # Setting VFGs -----------
         self.posteriors_vgf = posterior
@@ -63,6 +64,7 @@ class GeneralizedFunnelling:
         self.metaclassifier = None
         self.aggfunc = "mean"
         self.load_trained = load_trained
+        self.dataset_name = dataset_name
         self._init()

     def _init(self):
@@ -109,6 +111,7 @@ class GeneralizedFunnelling:
                 evaluate_step=self.evaluate_step,
                 verbose=True,
                 patience=self.patience,
+                dataset_name=self.dataset_name,
             )
             self.first_tier_learners.append(transformer_vgf)

diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py
index 96a485d..4cdf04e 100644
--- a/gfun/vgfs/commons.py
+++ b/gfun/vgfs/commons.py
@@ -114,9 +114,11 @@ class Trainer:
         self.evaluate_steps = evaluate_step
         self.loss_fn = loss_fn.to(device)
         self.print_steps = print_steps
+        self.experiment_name = experiment_name
+        self.patience = patience
         self.earlystopping = EarlyStopping(
             patience=patience,
-            checkpoint_path="models/vgfs/transformers/",
+            checkpoint_path="models/vgfs/transformer/",
             verbose=True,
             experiment_name=experiment_name,
         )
@@ -129,12 +131,13 @@

     def train(self, train_dataloader, eval_dataloader, epochs=10):
         print(
-            f"""- Training params:
+            f"""- Training params for {self.experiment_name}:
             - epochs: {epochs}
             - learning rate: {self.optimizer.defaults['lr']}
             - train batch size: {train_dataloader.batch_size}
             - eval batch size: {eval_dataloader.batch_size}
-            - max len: {train_dataloader.dataset.X.shape[-1]}\n""",
+            - max len: {train_dataloader.dataset.X.shape[-1]}
+            - patience: {self.earlystopping.patience}\n"""
         )
         for epoch in range(epochs):
             self.train_epoch(train_dataloader, epoch)
@@ -142,7 +145,17 @@
                 metric_watcher = self.evaluate(eval_dataloader)
                 stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                 if stop:
+                    print(
+                        f"- restoring best model from epoch {self.earlystopping.best_epoch}"
+                    )
+                    self.model = self.earlystopping.load_model(self.model).to(
+                        self.device
+                    )
                     break
+        print("\n- last sweep on the eval set")
+        # TODO: maybe a lower lr?
+        self.train_epoch(eval_dataloader, epoch=epoch)
+        self.earlystopping.save_model(self.model)
         return self.model

     def train_epoch(self, dataloader, epoch):
@@ -182,13 +195,14 @@ class Trainer:


 class EarlyStopping:
+    # TODO: add checkpointing + restore model if early stopping + last sweep on validation set
     def __init__(
         self,
-        patience=5,
+        patience,
+        checkpoint_path,
+        experiment_name,
         min_delta=0,
         verbose=True,
-        checkpoint_path="checkpoint.pt",
-        experiment_name="experiment",
     ):
         self.patience = patience
         self.min_delta = min_delta
@@ -206,7 +220,8 @@ class EarlyStopping:
             )
             self.best_score = validation
             self.counter = 0
-            # self.save_model(model)
+            self.best_epoch = epoch
+            self.save_model(model)
         elif validation < (self.best_score + self.min_delta):
             self.counter += 1
             print(
@@ -219,6 +234,9 @@ class EarlyStopping:

     def save_model(self, model):
         _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
-        print(f"- saving model to {_checkpoint_dir}")
         os.makedirs(_checkpoint_dir, exist_ok=True)
         model.save_pretrained(_checkpoint_dir)
+
+    def load_model(self, model):
+        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
+        return model.from_pretrained(_checkpoint_dir)
diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index 7648fb9..77f8960 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -27,6 +27,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     def __init__(
         self,
         model_name,
+        dataset_name,
         epochs=10,
         lr=1e-5,
         batch_size=4,
@@ -42,6 +43,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     ):
         super().__init__(
             model_name,
+            dataset_name,
             epochs,
             lr,
             batch_size,
@@ -135,7 +137,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )

-        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
+        experiment_name = (
+            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
+        )
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",
diff --git a/gfun/vgfs/transformerGen.py b/gfun/vgfs/transformerGen.py
index 9c56451..2b38045 100644
--- a/gfun/vgfs/transformerGen.py
+++ b/gfun/vgfs/transformerGen.py
@@ -12,6 +12,7 @@ class TransformerGen:
     def __init__(
         self,
         model_name,
+        dataset_name,
         epochs=10,
         lr=1e-5,
         batch_size=4,
@@ -26,6 +27,7 @@ class TransformerGen:
     ):
         self.model_name = model_name
+        self.dataset_name = dataset_name
         self.device = device
         self.model = None
         self.lr = lr
@@ -44,6 +46,9 @@ class TransformerGen:
         self.patience = patience
         self.datasets = {}

+    def make_probabilistic(self):
+        raise NotImplementedError
+
     def build_dataloader(
         self,
         lX,
diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py
index 80f3682..7bd0497 100644
--- a/gfun/vgfs/visualTransformerGen.py
+++ b/gfun/vgfs/visualTransformerGen.py
@@ -16,9 +16,11 @@ transformers.logging.set_verbosity_error()


 class VisualTransformerGen(ViewGen, TransformerGen):
+    # TODO: probabilistic behaviour
     def __init__(
         self,
         model_name,
+        dataset_name,
         lr=1e-5,
         epochs=10,
         batch_size=32,
@@ -29,6 +31,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     ):
         super().__init__(
             model_name,
+            dataset_name,
             lr=lr,
             epochs=epochs,
             batch_size=batch_size,
diff --git a/main.py b/main.py
index 783e58a..81df9ef 100644
--- a/main.py
+++ b/main.py
@@ -13,6 +13,9 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 TODO:
     - add documentations sphinx
     - zero-shot setup
+    - set probabilistic behaviour in Transformer parent-class
+    - pooling / attention aggregation
+    - test split in MultiNews dataset
 """


@@ -38,7 +41,7 @@ def get_dataset(datasetname):
         dataset = (
             MultilingualDataset(dataset_name="rcv1-2")
             .load(RCV_DATAPATH)
-            .reduce_data(langs=["en", "it", "fr"], maxn=100)
+            .reduce_data(langs=["en", "it", "fr"], maxn=500)
         )
     else:
         raise NotImplementedError
@@ -52,6 +55,7 @@ def main(args):
     ):
         lX, lY = dataset.training()
         # lX_te, lY_te = dataset.test()
+        print("[NB: for debug purposes, training set is also used as test set]\n")
         lX_te, lY_te = dataset.training()
     else:
         _lX = dataset.dX
@@ -71,6 +75,7 @@ def main(args):
     ), "At least one of VGF must be True"

     gfun = GeneralizedFunnelling(
+        dataset_name=args.dataset,
         posterior=args.posteriors,
         multilingual=args.multilingual,
         wce=args.wce,
@@ -93,8 +98,7 @@ def main(args):
     # gfun.get_config()
     gfun.fit(lX, lY)

-    if args.load_trained is None:
-        print("[NB: FORCE-SKIPPING MODEL SAVE]")
+    if args.load_trained is not None:
         gfun.save()

     # if not args.load_model:
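
Note on the checkpoint round-trip introduced by this patch: EarlyStopping.save_model()
persists the model with save_pretrained() every time the watched metric improves, and
EarlyStopping.load_model() reads the best checkpoint back with from_pretrained() once
patience is exhausted. A minimal sketch of that cycle outside the Trainer, assuming a
Hugging Face-style model; the experiment id and the evaluate_macro_f1() helper are
hypothetical placeholders:

    from transformers import AutoModelForSequenceClassification

    from gfun.vgfs.commons import EarlyStopping

    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
    earlystopping = EarlyStopping(
        patience=5,
        checkpoint_path="models/vgfs/transformer/",
        experiment_name="bert-base-uncased-10-4-rcv1-2",  # hypothetical experiment id
    )

    for epoch in range(10):
        # ... one epoch of training on the train split goes here ...
        macro_f1 = evaluate_macro_f1(model)  # hypothetical helper: the watched metric
        if earlystopping(macro_f1, model, epoch + 1):
            # patience exhausted: reload the best checkpoint saved so far
            model = earlystopping.load_model(model)
            break

Since load_model() calls from_pretrained() on the passed instance, it returns a new
model object rather than mutating its argument, which is why Trainer.train() reassigns
self.model (and moves it back to the device) when restoring after an early stop.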