Model checkpointing during training; restore the best model if early stopping is triggered
This commit is contained in:
parent 9c2c43dafb
commit 3f3e4982e4
@@ -34,6 +34,7 @@ class GeneralizedFunnelling:
         optimc,
         device,
         load_trained,
+        dataset_name,
     ):
         # Setting VFGs -----------
         self.posteriors_vgf = posterior
@@ -63,6 +64,7 @@ class GeneralizedFunnelling:
         self.metaclassifier = None
         self.aggfunc = "mean"
         self.load_trained = load_trained
+        self.dataset_name = dataset_name
         self._init()

     def _init(self):
@@ -109,6 +111,7 @@ class GeneralizedFunnelling:
                 evaluate_step=self.evaluate_step,
                 verbose=True,
                 patience=self.patience,
+                dataset_name=self.dataset_name,
             )
             self.first_tier_learners.append(transformer_vgf)
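These three hunks thread dataset_name from the GeneralizedFunnelling constructor down into the transformer view-generating function (VGF), where it ends up in the experiment name used for checkpointing; see the sketch after the TextualTransformerGen hunk below.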
@@ -114,9 +114,11 @@ class Trainer:
         self.evaluate_steps = evaluate_step
         self.loss_fn = loss_fn.to(device)
         self.print_steps = print_steps
+        self.experiment_name = experiment_name
+        self.patience = patience
         self.earlystopping = EarlyStopping(
             patience=patience,
-            checkpoint_path="models/vgfs/transformers/",
+            checkpoint_path="models/vgfs/transformer/",
             verbose=True,
+            experiment_name=experiment_name,
         )
@@ -129,12 +131,13 @@ class Trainer:

     def train(self, train_dataloader, eval_dataloader, epochs=10):
         print(
-            f"""- Training params:
+            f"""- Training params for {self.experiment_name}:
             - epochs: {epochs}
             - learning rate: {self.optimizer.defaults['lr']}
             - train batch size: {train_dataloader.batch_size}
             - eval batch size: {eval_dataloader.batch_size}
-            - max len: {train_dataloader.dataset.X.shape[-1]}\n""",
+            - max len: {train_dataloader.dataset.X.shape[-1]}
+            - patience: {self.earlystopping.patience}\n"""
         )
         for epoch in range(epochs):
             self.train_epoch(train_dataloader, epoch)
@@ -142,7 +145,17 @@ class Trainer:
                 metric_watcher = self.evaluate(eval_dataloader)
                 stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                 if stop:
+                    print(
+                        f"- restoring best model from epoch {self.earlystopping.best_epoch}"
+                    )
+                    self.model = self.earlystopping.load_model(self.model).to(
+                        self.device
+                    )
                     break
+        # TODO: maybe a lower lr?
+        self.train_epoch(eval_dataloader, epoch=epoch)
+        print(f"\n- last swipe on eval set")
+        self.earlystopping.save_model(self.model)
         return self.model

     def train_epoch(self, dataloader, epoch):
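Taken together, the hunk gives Trainer.train the following control flow: evaluate after each epoch, feed the watched metric to the early stopper, and on a stop signal restore the best checkpoint before one final pass ("last swipe") over the evaluation data. A minimal, self-contained sketch of that flow (ToyEarlyStopping and every value below are illustrative stand-ins, not the repository's code):

    # Toy stand-in for the commit's behaviour: checkpoint on every new best
    # score, stop after `patience` non-improving epochs, restore on stop.
    class ToyEarlyStopping:
        def __init__(self, patience):
            self.patience, self.counter = patience, 0
            self.best_score, self.best_epoch = float("-inf"), 0

        def __call__(self, score, model, epoch):
            if score > self.best_score:
                self.best_score, self.best_epoch, self.counter = score, epoch, 0
                self._checkpoint = dict(model)    # stands in for save_model
                return False
            self.counter += 1
            return self.counter >= self.patience  # True -> stop training

        def load_model(self, model):
            return dict(self._checkpoint)         # stands in for from_pretrained

    model = {"epochs_seen": 0}
    stopper = ToyEarlyStopping(patience=2)
    for epoch, score in enumerate([0.3, 0.5, 0.4, 0.45], start=1):
        model["epochs_seen"] = epoch              # "training"
        if stopper(score, model, epoch):
            print(f"- restoring best model from epoch {stopper.best_epoch}")
            model = stopper.load_model(model)
            break
    # last swipe on the eval set, then persist (mirrors the tail of the hunk)
    print("final model:", model)                  # -> {'epochs_seen': 2}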
@@ -182,13 +195,14 @@ class Trainer:

 class EarlyStopping:
-    # TODO: add checkpointing + restore model if early stopping + last swipe on validation set
     def __init__(
         self,
-        patience=5,
+        patience,
+        checkpoint_path,
+        experiment_name,
         min_delta=0,
         verbose=True,
-        checkpoint_path="checkpoint.pt",
-        experiment_name="experiment",
     ):
         self.patience = patience
         self.min_delta = min_delta
@@ -206,7 +220,8 @@ class EarlyStopping:
             )
             self.best_score = validation
             self.counter = 0
-            # self.save_model(model)
+            self.best_epoch = epoch
+            self.save_model(model)
         elif validation < (self.best_score + self.min_delta):
             self.counter += 1
             print(
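Note the behavioural fix in this hunk: before the change, the checkpoint call was commented out, so no "best" model was ever written to disk and the restore path added in Trainer.train would have had nothing to load. Now every new best validation score records best_epoch and immediately writes a checkpoint.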
@@ -219,6 +234,9 @@ class EarlyStopping:

     def save_model(self, model):
         _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
         print(f"- saving model to {_checkpoint_dir}")
         os.makedirs(_checkpoint_dir, exist_ok=True)
         model.save_pretrained(_checkpoint_dir)
+
+    def load_model(self, model):
+        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
+        return model.from_pretrained(_checkpoint_dir)
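The save/load pair relies on the Hugging Face checkpoint format: save_pretrained writes the config and weights into a directory, and from_pretrained is a classmethod, so calling it on the instance itself is legal and returns a fresh model on CPU, which is why the restore in Trainer.train chains .to(self.device). A sketch of the round-trip (model choice and paths are hypothetical; the API calls are standard transformers):

    import os

    from transformers import AutoModelForSequenceClassification

    ckpt = os.path.join("models/vgfs/transformer/", "mbert-10-4-rcv1-2")
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-multilingual-cased", num_labels=3
    )
    os.makedirs(ckpt, exist_ok=True)
    model.save_pretrained(ckpt)            # writes config.json + model weights

    # classmethod called via the instance: returns a NEW model, loaded on CPU
    restored = model.from_pretrained(ckpt)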
@@ -27,6 +27,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     def __init__(
         self,
         model_name,
+        dataset_name,
         epochs=10,
         lr=1e-5,
         batch_size=4,
@@ -42,6 +43,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     ):
         super().__init__(
             model_name,
+            dataset_name,
             epochs,
             lr,
             batch_size,
@@ -135,7 +137,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )

-        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
+        experiment_name = (
+            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
+        )
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",
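Appending the dataset name keeps checkpoints from different datasets apart: previously, two runs sharing model, epochs, and batch size would have written their best models to the same directory. With hypothetical values, the experiment name and the resulting checkpoint directory look like this:

    import os

    model_name, epochs, batch_size, dataset_name = "mbert", 10, 4, "rcv1-2"
    experiment_name = f"{model_name}-{epochs}-{batch_size}-{dataset_name}"
    print(experiment_name)
    # mbert-10-4-rcv1-2
    print(os.path.join("models/vgfs/transformer/", experiment_name))
    # models/vgfs/transformer/mbert-10-4-rcv1-2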
@@ -12,6 +12,7 @@ class TransformerGen:
     def __init__(
         self,
         model_name,
+        dataset_name,
         epochs=10,
         lr=1e-5,
         batch_size=4,
@@ -26,6 +27,7 @@ class TransformerGen:
         patience=5,
     ):
         self.model_name = model_name
+        self.dataset_name = dataset_name
         self.device = device
         self.model = None
         self.lr = lr
@@ -44,6 +46,9 @@ class TransformerGen:
         self.patience = patience
         self.datasets = {}

+    def make_probabilistic(self):
+        raise NotImplementedError
+
     def build_dataloader(
         self,
         lX,
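make_probabilistic is declared in the parent as a raise-NotImplementedError stub, the usual plain-Python way to mark a method that concrete view generators must override (the main.py TODO list below tracks this). A hypothetical override, only to show the intended shape:

    class TransformerGen:
        def make_probabilistic(self):
            raise NotImplementedError

    class SomeTransformerVGF(TransformerGen):
        def make_probabilistic(self):
            # e.g. attach whatever maps raw logits to calibrated posteriors;
            # this body is invented for illustration
            self.probabilistic = True
            return self

    SomeTransformerVGF().make_probabilistic()    # fine
    # TransformerGen().make_probabilistic()      # raises NotImplementedError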
@@ -16,9 +16,11 @@ transformers.logging.set_verbosity_error()


 class VisualTransformerGen(ViewGen, TransformerGen):
+    # TODO: probabilistic behaviour
     def __init__(
         self,
         model_name,
+        dataset_name,
         lr=1e-5,
         epochs=10,
         batch_size=32,
@@ -29,6 +31,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     ):
         super().__init__(
             model_name,
+            dataset_name,
             lr=lr,
             epochs=epochs,
             batch_size=batch_size,
main.py (10 changed lines)
@@ -13,6 +13,9 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 TODO:
     - add documentations sphinx
+    - zero-shot setup
+    - set probabilistic behaviour in Transformer parent-class
+    - pooling / attention aggregation
     - test split in MultiNews dataset
 """

@@ -38,7 +41,7 @@ def get_dataset(datasetname):
         dataset = (
             MultilingualDataset(dataset_name="rcv1-2")
             .load(RCV_DATAPATH)
-            .reduce_data(langs=["en", "it", "fr"], maxn=100)
+            .reduce_data(langs=["en", "it", "fr"], maxn=500)
         )
     else:
         raise NotImplementedError
@@ -52,6 +55,7 @@ def main(args):
     ):
         lX, lY = dataset.training()
         # lX_te, lY_te = dataset.test()
+        print("[NB: for debug purposes, training set is also used as test set]\n")
         lX_te, lY_te = dataset.training()
     else:
         _lX = dataset.dX
@@ -71,6 +75,7 @@ def main(args):
     ), "At least one of VGF must be True"

     gfun = GeneralizedFunnelling(
+        dataset_name=args.dataset,
         posterior=args.posteriors,
         multilingual=args.multilingual,
         wce=args.wce,
@@ -93,8 +98,7 @@ def main(args):
     # gfun.get_config()
     gfun.fit(lX, lY)

-    if args.load_trained is None:
-        print("[NB: FORCE-SKIPPING MODEL SAVE]")
+    if args.load_trained is not None:
         gfun.save()

     # if not args.load_model: