Model checkpointing during training. Restore the best model if early stopping is triggered

Andrea Pedrotti 2023-02-10 11:37:32 +01:00
parent 9c2c43dafb
commit 3f3e4982e4
6 changed files with 49 additions and 12 deletions
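
In short: the trainer now saves a checkpoint whenever the monitored validation metric improves, and when patience runs out it reloads that best checkpoint instead of keeping the last-epoch weights, makes one final pass over the evaluation data, and saves again. A minimal sketch of the pattern, simplified from the hunks below (not the repository's exact code; it assumes a Hugging Face-style model exposing save_pretrained / from_pretrained):

import os


class EarlyStopping:
    """Checkpoint on every improvement; signal stop after `patience` epochs without one."""

    def __init__(self, patience, checkpoint_path, experiment_name, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.checkpoint_dir = os.path.join(checkpoint_path, experiment_name)
        self.best_score = None
        self.best_epoch = None
        self.counter = 0

    def __call__(self, validation, model, epoch):
        if self.best_score is None or validation >= self.best_score + self.min_delta:
            # new best: remember the epoch, reset patience, checkpoint the weights
            self.best_score, self.best_epoch, self.counter = validation, epoch, 0
            self.save_model(model)
            return False
        self.counter += 1
        return self.counter >= self.patience  # True -> caller should stop training

    def save_model(self, model):
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        model.save_pretrained(self.checkpoint_dir)

    def load_model(self, model):
        return model.from_pretrained(self.checkpoint_dir)


# Schematic use inside the training loop:
#   stop = earlystopping(metric, model, epoch + 1)
#   if stop:
#       model = earlystopping.load_model(model).to(device)  # restore best weights
#       break
#   # after the loop: one last training pass on the eval set, then a final save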


@ -34,6 +34,7 @@ class GeneralizedFunnelling:
optimc,
device,
load_trained,
dataset_name,
):
# Setting VFGs -----------
self.posteriors_vgf = posterior
@ -63,6 +64,7 @@ class GeneralizedFunnelling:
self.metaclassifier = None
self.aggfunc = "mean"
self.load_trained = load_trained
self.dataset_name = dataset_name
self._init()
def _init(self):
@ -109,6 +111,7 @@ class GeneralizedFunnelling:
evaluate_step=self.evaluate_step,
verbose=True,
patience=self.patience,
dataset_name=self.dataset_name,
)
self.first_tier_learners.append(transformer_vgf)


@ -114,9 +114,11 @@ class Trainer:
self.evaluate_steps = evaluate_step
self.loss_fn = loss_fn.to(device)
self.print_steps = print_steps
self.experiment_name = experiment_name
self.patience = patience
self.earlystopping = EarlyStopping(
patience=patience,
checkpoint_path="models/vgfs/transformers/",
checkpoint_path="models/vgfs/transformer/",
verbose=True,
experiment_name=experiment_name,
)
@ -129,12 +131,13 @@ class Trainer:
def train(self, train_dataloader, eval_dataloader, epochs=10):
print(
f"""- Training params:
f"""- Training params for {self.experiment_name}:
- epochs: {epochs}
- learning rate: {self.optimizer.defaults['lr']}
- train batch size: {train_dataloader.batch_size}
- eval batch size: {eval_dataloader.batch_size}
- max len: {train_dataloader.dataset.X.shape[-1]}\n""",
- max len: {train_dataloader.dataset.X.shape[-1]}
- patience: {self.earlystopping.patience}\n"""
)
for epoch in range(epochs):
self.train_epoch(train_dataloader, epoch)
@ -142,7 +145,17 @@ class Trainer:
metric_watcher = self.evaluate(eval_dataloader)
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
if stop:
print(
f"- restoring best model from epoch {self.earlystopping.best_epoch}"
)
self.model = self.earlystopping.load_model(self.model).to(
self.device
)
break
# TODO: maybe a lower lr?
self.train_epoch(eval_dataloader, epoch=epoch)
print(f"\n- last swipe on eval set")
self.earlystopping.save_model(self.model)
return self.model
def train_epoch(self, dataloader, epoch):
@ -182,13 +195,14 @@ class Trainer:
class EarlyStopping:
# TODO: add checkpointing + restore model if early stopping + last swipe on validation set
def __init__(
self,
patience=5,
patience,
checkpoint_path,
experiment_name,
min_delta=0,
verbose=True,
checkpoint_path="checkpoint.pt",
experiment_name="experiment",
):
self.patience = patience
self.min_delta = min_delta
@ -206,7 +220,8 @@ class EarlyStopping:
)
self.best_score = validation
self.counter = 0
# self.save_model(model)
self.best_epoch = epoch
self.save_model(model)
elif validation < (self.best_score + self.min_delta):
self.counter += 1
print(
@ -219,6 +234,9 @@ class EarlyStopping:
def save_model(self, model):
_checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
print(f"- saving model to {_checkpoint_dir}")
os.makedirs(_checkpoint_dir, exist_ok=True)
model.save_pretrained(_checkpoint_dir)
def load_model(self, model):
_checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
return model.from_pretrained(_checkpoint_dir)
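
One detail worth noting: load_model calls from_pretrained on a model instance. In the transformers API from_pretrained is a classmethod, so the call resolves through the instance's class and returns a freshly loaded model rather than mutating the existing one; the caller has to keep the return value, as the Trainer does above. A quick illustration of the round trip (model name and directory are examples only):

from transformers import AutoModel

ckpt_dir = "models/vgfs/transformer/some-experiment"  # hypothetical experiment directory

model = AutoModel.from_pretrained("bert-base-multilingual-cased")
model.save_pretrained(ckpt_dir)             # writes config + weights into the directory
restored = model.from_pretrained(ckpt_dir)  # classmethod, so it also works on an instance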


@ -27,6 +27,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
def __init__(
self,
model_name,
dataset_name,
epochs=10,
lr=1e-5,
batch_size=4,
@ -42,6 +43,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
):
super().__init__(
model_name,
dataset_name,
epochs,
lr,
batch_size,
@ -135,7 +137,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
shuffle=False,
)
experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
experiment_name = (
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
)
trainer = Trainer(
model=self.model,
optimizer_name="adamW",
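
With the dataset name appended to the experiment name, each (model, dataset) run now gets its own checkpoint directory under the path configured for EarlyStopping above. Roughly (the values here are illustrative defaults, not taken from an actual run):

import os

checkpoint_path = "models/vgfs/transformer/"      # as configured in the Trainer above
model_name, epochs, batch_size = "mbert", 10, 4   # illustrative values
dataset_name = "rcv1-2"

experiment_name = f"{model_name}-{epochs}-{batch_size}-{dataset_name}"
print(os.path.join(checkpoint_path, experiment_name))
# -> models/vgfs/transformer/mbert-10-4-rcv1-2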


@ -12,6 +12,7 @@ class TransformerGen:
def __init__(
self,
model_name,
dataset_name,
epochs=10,
lr=1e-5,
batch_size=4,
@ -26,6 +27,7 @@ class TransformerGen:
patience=5,
):
self.model_name = model_name
self.dataset_name = dataset_name
self.device = device
self.model = None
self.lr = lr
@ -44,6 +46,9 @@ class TransformerGen:
self.patience = patience
self.datasets = {}
def make_probabilistic(self):
raise NotImplementedError
def build_dataloader(
self,
lX,


@ -16,9 +16,11 @@ transformers.logging.set_verbosity_error()
class VisualTransformerGen(ViewGen, TransformerGen):
# TODO: probabilistic behaviour
def __init__(
self,
model_name,
dataset_name,
lr=1e-5,
epochs=10,
batch_size=32,
@ -29,6 +31,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
):
super().__init__(
model_name,
dataset_name,
lr=lr,
epochs=epochs,
batch_size=batch_size,

main.py

@ -13,6 +13,9 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
TODO:
- add documentations sphinx
- zero-shot setup
- set probabilistic behaviour in Transformer parent-class
- pooling / attention aggregation
- test split in MultiNews dataset
"""
@ -38,7 +41,7 @@ def get_dataset(datasetname):
dataset = (
MultilingualDataset(dataset_name="rcv1-2")
.load(RCV_DATAPATH)
.reduce_data(langs=["en", "it", "fr"], maxn=100)
.reduce_data(langs=["en", "it", "fr"], maxn=500)
)
else:
raise NotImplementedError
@ -52,6 +55,7 @@ def main(args):
):
lX, lY = dataset.training()
# lX_te, lY_te = dataset.test()
print("[NB: for debug purposes, training set is also used as test set]\n")
lX_te, lY_te = dataset.training()
else:
_lX = dataset.dX
@ -71,6 +75,7 @@ def main(args):
), "At least one of VGF must be True"
gfun = GeneralizedFunnelling(
dataset_name=args.dataset,
posterior=args.posteriors,
multilingual=args.multilingual,
wce=args.wce,
@ -93,8 +98,7 @@ def main(args):
# gfun.get_config()
gfun.fit(lX, lY)
if args.load_trained is None:
print("[NB: FORCE-SKIPPING MODEL SAVE]")
if args.load_trained is not None:
gfun.save()
# if not args.load_model: