Checkpoint the model during training; restore the best model if early stopping is triggered

Andrea Pedrotti 2023-02-10 11:37:32 +01:00
parent 9c2c43dafb
commit 3f3e4982e4
6 changed files with 49 additions and 12 deletions
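
In short: each transformer VGF now checkpoints its best model under a per-experiment directory, and the experiment name includes the dataset. A minimal sketch of where a checkpoint ends up, based on values visible in this diff (the model name and hyperparameter values below are illustrative examples, not taken from the repo):

    import os

    checkpoint_path = "models/vgfs/transformer/"  # value Trainer passes to EarlyStopping (see diff below)
    # experiment_name is built as model_name-epochs-batch_size-dataset_name; example values only
    experiment_name = "bert-base-multilingual-cased-10-4-rcv1-2"

    # EarlyStopping.save_model() joins the two and calls save_pretrained() on that directory
    checkpoint_dir = os.path.join(checkpoint_path, experiment_name)
    print(checkpoint_dir)  # models/vgfs/transformer/bert-base-multilingual-cased-10-4-rcv1-2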

View File

@@ -34,6 +34,7 @@ class GeneralizedFunnelling:
         optimc,
         device,
         load_trained,
+        dataset_name,
     ):
         # Setting VFGs -----------
         self.posteriors_vgf = posterior
@@ -63,6 +64,7 @@
         self.metaclassifier = None
         self.aggfunc = "mean"
         self.load_trained = load_trained
+        self.dataset_name = dataset_name
         self._init()

     def _init(self):
@@ -109,6 +111,7 @@
                 evaluate_step=self.evaluate_step,
                 verbose=True,
                 patience=self.patience,
+                dataset_name=self.dataset_name,
             )
             self.first_tier_learners.append(transformer_vgf)

View File

@@ -114,9 +114,11 @@ class Trainer:
         self.evaluate_steps = evaluate_step
         self.loss_fn = loss_fn.to(device)
         self.print_steps = print_steps
+        self.experiment_name = experiment_name
+        self.patience = patience
         self.earlystopping = EarlyStopping(
             patience=patience,
-            checkpoint_path="models/vgfs/transformers/",
+            checkpoint_path="models/vgfs/transformer/",
             verbose=True,
             experiment_name=experiment_name,
         )
@@ -129,12 +131,13 @@ class Trainer:
     def train(self, train_dataloader, eval_dataloader, epochs=10):
         print(
-            f"""- Training params:
+            f"""- Training params for {self.experiment_name}:
             - epochs: {epochs}
             - learning rate: {self.optimizer.defaults['lr']}
             - train batch size: {train_dataloader.batch_size}
             - eval batch size: {eval_dataloader.batch_size}
-            - max len: {train_dataloader.dataset.X.shape[-1]}\n""",
+            - max len: {train_dataloader.dataset.X.shape[-1]}
+            - patience: {self.earlystopping.patience}\n"""
         )

         for epoch in range(epochs):
             self.train_epoch(train_dataloader, epoch)
@@ -142,7 +145,17 @@ class Trainer:
                 metric_watcher = self.evaluate(eval_dataloader)
                 stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                 if stop:
+                    print(
+                        f"- restoring best model from epoch {self.earlystopping.best_epoch}"
+                    )
+                    self.model = self.earlystopping.load_model(self.model).to(
+                        self.device
+                    )
                     break
+        # TODO: maybe a lower lr?
+        self.train_epoch(eval_dataloader, epoch=epoch)
+        print(f"\n- last swipe on eval set")
+        self.earlystopping.save_model(self.model)
         return self.model

     def train_epoch(self, dataloader, epoch):
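
The hunk above is the core of the change: the early-stopping callback now checkpoints on every improvement, and when it fires the trainer reloads the best checkpoint, runs one final pass over the evaluation data (the "last swipe"), and saves that model. A minimal sketch of the same control flow, with illustrative names rather than the project's API:

    # illustrative sketch of the control flow added above, not the project's API
    for epoch in range(epochs):
        train_one_epoch(model, train_loader)
        score = evaluate(model, eval_loader)
        stop = early_stopping(score, model, epoch + 1)  # checkpoints internally when the score improves
        if stop:
            model = early_stopping.load_model(model)    # restore the best checkpoint
            break

    train_one_epoch(model, eval_loader)                 # "last swipe" on the evaluation set
    early_stopping.save_model(model)                    # persist the final weights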
@@ -182,13 +195,14 @@ class Trainer:
 class EarlyStopping:
+    # TODO: add checkpointing + restore model if early stopping + last swipe on validation set
     def __init__(
         self,
-        patience=5,
+        patience,
+        checkpoint_path,
+        experiment_name,
         min_delta=0,
         verbose=True,
-        checkpoint_path="checkpoint.pt",
-        experiment_name="experiment",
     ):
         self.patience = patience
         self.min_delta = min_delta
@@ -206,7 +220,8 @@ class EarlyStopping:
             )
             self.best_score = validation
             self.counter = 0
-            # self.save_model(model)
+            self.best_epoch = epoch
+            self.save_model(model)
         elif validation < (self.best_score + self.min_delta):
             self.counter += 1
             print(
@@ -219,6 +234,9 @@ class EarlyStopping:
     def save_model(self, model):
         _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
+        print(f"- saving model to {_checkpoint_dir}")
         os.makedirs(_checkpoint_dir, exist_ok=True)
         model.save_pretrained(_checkpoint_dir)

+    def load_model(self, model):
+        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
+        return model.from_pretrained(_checkpoint_dir)
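
Note that save_model/load_model delegate to the Hugging Face save_pretrained/from_pretrained pair, and from_pretrained is a classmethod, so model.from_pretrained(dir) returns a new model object rather than loading weights in place; that is why Trainer reassigns self.model after restoring. A minimal round-trip sketch (the model id and directory are examples, not taken from this repo):

    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained("bert-base-multilingual-cased")
    model.save_pretrained("models/vgfs/transformer/example-experiment")             # writes config + weights
    restored = model.from_pretrained("models/vgfs/transformer/example-experiment")  # new instance loaded from the checkpoint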

View File

@@ -27,6 +27,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     def __init__(
         self,
         model_name,
+        dataset_name,
         epochs=10,
         lr=1e-5,
         batch_size=4,
@@ -42,6 +43,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     ):
         super().__init__(
             model_name,
+            dataset_name,
             epochs,
             lr,
             batch_size,
@@ -135,7 +137,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )

-        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
+        experiment_name = (
+            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
+        )
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",

View File

@@ -12,6 +12,7 @@ class TransformerGen:
     def __init__(
         self,
         model_name,
+        dataset_name,
         epochs=10,
         lr=1e-5,
         batch_size=4,
@@ -26,6 +27,7 @@ class TransformerGen:
         patience=5,
     ):
         self.model_name = model_name
+        self.dataset_name = dataset_name
         self.device = device
         self.model = None
         self.lr = lr
@@ -44,6 +46,9 @@ class TransformerGen:
         self.patience = patience
         self.datasets = {}

+    def make_probabilistic(self):
+        raise NotImplementedError
+
     def build_dataloader(
         self,
         lX,

View File

@@ -16,9 +16,11 @@ transformers.logging.set_verbosity_error()
 class VisualTransformerGen(ViewGen, TransformerGen):
+    # TODO: probabilistic behaviour
     def __init__(
         self,
         model_name,
+        dataset_name,
         lr=1e-5,
         epochs=10,
         batch_size=32,
@@ -29,6 +31,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     ):
         super().__init__(
             model_name,
+            dataset_name,
             lr=lr,
             epochs=epochs,
             batch_size=batch_size,

main.py (10 changes)
View File

@@ -13,6 +13,9 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 TODO:
 - add documentations sphinx
 - zero-shot setup
+- set probabilistic behaviour in Transformer parent-class
+- pooling / attention aggregation
+- test split in MultiNews dataset
 """
@@ -38,7 +41,7 @@ def get_dataset(datasetname):
         dataset = (
             MultilingualDataset(dataset_name="rcv1-2")
             .load(RCV_DATAPATH)
-            .reduce_data(langs=["en", "it", "fr"], maxn=100)
+            .reduce_data(langs=["en", "it", "fr"], maxn=500)
         )
     else:
         raise NotImplementedError
@@ -52,6 +55,7 @@ def main(args):
     ):
         lX, lY = dataset.training()
         # lX_te, lY_te = dataset.test()
+        print("[NB: for debug purposes, training set is also used as test set]\n")
         lX_te, lY_te = dataset.training()
     else:
         _lX = dataset.dX
@@ -71,6 +75,7 @@ def main(args):
     ), "At least one of VGF must be True"

     gfun = GeneralizedFunnelling(
+        dataset_name=args.dataset,
         posterior=args.posteriors,
         multilingual=args.multilingual,
         wce=args.wce,
@@ -93,8 +98,7 @@ def main(args):
     # gfun.get_config()
     gfun.fit(lX, lY)

-    if args.load_trained is None:
-        print("[NB: FORCE-SKIPPING MODEL SAVE]")
+    if args.load_trained is not None:
         gfun.save()

     # if not args.load_model: