Checkpoint the model during training; restore the best model if early stopping is triggered

parent 9c2c43dafb
commit 3f3e4982e4

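In outline: EarlyStopping now checkpoints the model every time the validation metric improves, records the epoch it happened at, and Trainer reloads that best checkpoint once patience runs out. A minimal sketch of the save-on-improvement half, assuming a HuggingFace-style model that exposes save_pretrained (class and attribute names here are illustrative, not the repo's):

    import os

    class EarlyStopperSketch:
        # Tracks a higher-is-better validation metric; checkpoints on improvement.
        def __init__(self, patience, checkpoint_dir):
            self.patience = patience
            self.checkpoint_dir = checkpoint_dir
            self.best_score = None
            self.best_epoch = 0
            self.counter = 0

        def __call__(self, score, model, epoch):
            if self.best_score is None or score > self.best_score:
                # New best: remember the epoch and persist the weights.
                self.best_score, self.best_epoch, self.counter = score, epoch, 0
                os.makedirs(self.checkpoint_dir, exist_ok=True)
                model.save_pretrained(self.checkpoint_dir)  # HF-style serialization
                return False
            self.counter += 1
            return self.counter >= self.patience  # True -> caller should stop
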
@@ -34,6 +34,7 @@ class GeneralizedFunnelling:
         optimc,
         device,
         load_trained,
+        dataset_name,
     ):
         # Setting VFGs -----------
         self.posteriors_vgf = posterior
@@ -63,6 +64,7 @@ class GeneralizedFunnelling:
         self.metaclassifier = None
         self.aggfunc = "mean"
         self.load_trained = load_trained
+        self.dataset_name = dataset_name
         self._init()

     def _init(self):
@@ -109,6 +111,7 @@ class GeneralizedFunnelling:
                 evaluate_step=self.evaluate_step,
                 verbose=True,
                 patience=self.patience,
+                dataset_name=self.dataset_name,
             )
             self.first_tier_learners.append(transformer_vgf)

@@ -114,9 +114,11 @@ class Trainer:
         self.evaluate_steps = evaluate_step
         self.loss_fn = loss_fn.to(device)
         self.print_steps = print_steps
+        self.experiment_name = experiment_name
+        self.patience = patience
         self.earlystopping = EarlyStopping(
             patience=patience,
-            checkpoint_path="models/vgfs/transformers/",
+            checkpoint_path="models/vgfs/transformer/",
             verbose=True,
             experiment_name=experiment_name,
         )
@@ -129,12 +131,13 @@ class Trainer:

     def train(self, train_dataloader, eval_dataloader, epochs=10):
         print(
-            f"""- Training params:
+            f"""- Training params for {self.experiment_name}:
             - epochs: {epochs}
             - learning rate: {self.optimizer.defaults['lr']}
             - train batch size: {train_dataloader.batch_size}
             - eval batch size: {eval_dataloader.batch_size}
-            - max len: {train_dataloader.dataset.X.shape[-1]}\n""",
+            - max len: {train_dataloader.dataset.X.shape[-1]}
+            - patience: {self.earlystopping.patience}\n"""
         )
         for epoch in range(epochs):
             self.train_epoch(train_dataloader, epoch)
@@ -142,7 +145,17 @@ class Trainer:
                 metric_watcher = self.evaluate(eval_dataloader)
                 stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                 if stop:
+                    print(
+                        f"- restoring best model from epoch {self.earlystopping.best_epoch}"
+                    )
+                    self.model = self.earlystopping.load_model(self.model).to(
+                        self.device
+                    )
                     break
+        # TODO: maybe a lower lr?
+        self.train_epoch(eval_dataloader, epoch=epoch)
+        print(f"\n- last swipe on eval set")
+        self.earlystopping.save_model(self.model)
         return self.model

     def train_epoch(self, dataloader, epoch):
@@ -182,13 +195,14 @@ class Trainer:


 class EarlyStopping:
+    # TODO: add checkpointing + restore model if early stopping + last swipe on validation set
     def __init__(
         self,
-        patience=5,
+        patience,
+        checkpoint_path,
+        experiment_name,
         min_delta=0,
         verbose=True,
-        checkpoint_path="checkpoint.pt",
-        experiment_name="experiment",
     ):
         self.patience = patience
         self.min_delta = min_delta
@@ -206,7 +220,8 @@ class EarlyStopping:
             )
             self.best_score = validation
             self.counter = 0
-            # self.save_model(model)
+            self.best_epoch = epoch
+            self.save_model(model)
         elif validation < (self.best_score + self.min_delta):
             self.counter += 1
             print(
@@ -219,6 +234,9 @@ class EarlyStopping:

     def save_model(self, model):
         _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
-        print(f"- saving model to {_checkpoint_dir}")
         os.makedirs(_checkpoint_dir, exist_ok=True)
         model.save_pretrained(_checkpoint_dir)
+
+    def load_model(self, model):
+        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
+        return model.from_pretrained(_checkpoint_dir)

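Read together, the Trainer hunks above make train() restore the best checkpoint when the stopper fires, and then run one extra pass over the eval split before returning. Consolidated into one view, with the surrounding loop structure assumed rather than copied from the repo, the resulting control flow is roughly:

    # Hedged sketch of the resulting Trainer.train control flow
    # (names as in the diff; evaluation cadence assumed):
    def train(self, train_dataloader, eval_dataloader, epochs=10):
        for epoch in range(epochs):
            self.train_epoch(train_dataloader, epoch)
            metric_watcher = self.evaluate(eval_dataloader)
            if self.earlystopping(metric_watcher, self.model, epoch + 1):
                # Patience exhausted: roll back to the best checkpoint.
                self.model = self.earlystopping.load_model(self.model).to(self.device)
                break
        # "Last swipe": one extra training pass over the eval split, then re-save.
        self.train_epoch(eval_dataloader, epoch=epoch)
        self.earlystopping.save_model(self.model)
        return self.model
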
@@ -27,6 +27,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     def __init__(
         self,
         model_name,
+        dataset_name,
         epochs=10,
         lr=1e-5,
         batch_size=4,
@@ -42,6 +43,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     ):
         super().__init__(
             model_name,
+            dataset_name,
             epochs,
             lr,
             batch_size,
@@ -135,7 +137,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )

-        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
+        experiment_name = (
+            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
+        )
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",

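Folding the dataset into experiment_name keeps checkpoints from different datasets in separate directories, since EarlyStopping joins checkpoint_path with the experiment name. For instance (values illustrative):

    # e.g. model_name="mbert", epochs=100, batch_size=4, dataset_name="rcv1-2"
    # gives experiment_name "mbert-100-4-rcv1-2", so checkpoints land in
    # models/vgfs/transformer/mbert-100-4-rcv1-2/
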
@@ -12,6 +12,7 @@ class TransformerGen:
     def __init__(
         self,
         model_name,
+        dataset_name,
         epochs=10,
         lr=1e-5,
         batch_size=4,
@@ -26,6 +27,7 @@ class TransformerGen:
         patience=5,
     ):
         self.model_name = model_name
+        self.dataset_name = dataset_name
         self.device = device
         self.model = None
         self.lr = lr
@@ -44,6 +46,9 @@ class TransformerGen:
         self.patience = patience
         self.datasets = {}

+    def make_probabilistic(self):
+        raise NotImplementedError
+
     def build_dataloader(
         self,
         lX,

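make_probabilistic is declared on the parent class as a stub that raises NotImplementedError, so each view generator must supply its own probabilistic behaviour (per the main.py TODO below, this is still pending). The pattern, in a minimal illustrative form unrelated to the repo's actual subclasses:

    class Base:
        def make_probabilistic(self):
            raise NotImplementedError  # subclasses must override

    class Child(Base):
        def make_probabilistic(self):
            self.probabilistic = True  # subclass-specific behaviour goes here

    Child().make_probabilistic()       # fine; Base().make_probabilistic() raises
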
@@ -16,9 +16,11 @@ transformers.logging.set_verbosity_error()


 class VisualTransformerGen(ViewGen, TransformerGen):
+    # TODO: probabilistic behaviour
     def __init__(
         self,
         model_name,
+        dataset_name,
         lr=1e-5,
         epochs=10,
         batch_size=32,
@@ -29,6 +31,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     ):
         super().__init__(
             model_name,
+            dataset_name,
             lr=lr,
             epochs=epochs,
             batch_size=batch_size,

main.py

@@ -13,6 +13,9 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 TODO:
     - add documentations sphinx
     - zero-shot setup
+    - set probabilistic behaviour in Transformer parent-class
+    - pooling / attention aggregation
+    - test split in MultiNews dataset
 """


@@ -38,7 +41,7 @@ def get_dataset(datasetname):
         dataset = (
             MultilingualDataset(dataset_name="rcv1-2")
             .load(RCV_DATAPATH)
-            .reduce_data(langs=["en", "it", "fr"], maxn=100)
+            .reduce_data(langs=["en", "it", "fr"], maxn=500)
         )
     else:
         raise NotImplementedError
@@ -52,6 +55,7 @@ def main(args):
     ):
         lX, lY = dataset.training()
         # lX_te, lY_te = dataset.test()
+        print("[NB: for debug purposes, training set is also used as test set]\n")
         lX_te, lY_te = dataset.training()
     else:
         _lX = dataset.dX
@@ -71,6 +75,7 @@ def main(args):
     ), "At least one of VGF must be True"

     gfun = GeneralizedFunnelling(
+        dataset_name=args.dataset,
         posterior=args.posteriors,
         multilingual=args.multilingual,
         wce=args.wce,
@@ -93,8 +98,7 @@ def main(args):
     # gfun.get_config()
     gfun.fit(lX, lY)

-    if args.load_trained is None:
-        print("[NB: FORCE-SKIPPING MODEL SAVE]")
+    if args.load_trained is not None:
         gfun.save()

     # if not args.load_model: