Implemented save/load for MT5ForSequenceClassification; moved torch Datasets to the dataManager module

Andrea Pedrotti 2023-03-16 10:31:34 +01:00
parent 56faaf2615
commit 9d43ebb23b
5 changed files with 98 additions and 112 deletions

View File

@@ -1,2 +1,66 @@
-class TorchMultiNewsDataset:
-    pass
+import torch
+from torch.utils.data import Dataset
+
+
+class MultilingualDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+        return self
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
+
+
+class MultimodalDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([imgs for imgs in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
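For orientation, a minimal usage sketch of MultilingualDatasetTorch (not part of the commit). The shapes of lX and lY are assumptions read off the code above: lX maps language codes to tokenizer output exposing input_ids, and lY maps them to per-document label vectors; the tokenizer name is just an example. MultimodalDatasetTorch follows the same pattern with pre-stacked image tensors per language.

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from dataManager.torchDataset import MultilingualDatasetTorch

tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
# lX: one tokenized batch per language; lY: one label vector per document
lX = {
    lang: tokenizer(docs, padding="max_length", max_length=16, truncation=True, return_tensors="pt")
    for lang, docs in {"en": ["a short document"], "it": ["un breve documento"]}.items()
}
lY = {"en": [[0, 1]], "it": [[1, 0]]}

dataset = MultilingualDatasetTorch(lX, lY, split="train")
loader = DataLoader(dataset, batch_size=2)
for X, Y, langs in loader:
    print(X.shape, Y.shape, langs)  # input_ids batch, label batch, language of each row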

View File

@@ -278,7 +278,7 @@ class Trainer:
             loss = self.loss_fn(y_hat, y.to(self.device))
             loss.backward()
             self.optimizer.step()
-            batch_losses.append(loss.item())  # TODO: is this still on gpu?
+            batch_losses.append(loss.item())
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                     print(

View File

@@ -9,13 +9,13 @@ import torch
 import torch.nn as nn
 import transformers
 from transformers import MT5EncoderModel
-from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers.modeling_outputs import ModelOutput
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultilingualDatasetTorch
 
 transformers.logging.set_verbosity_error()
@@ -44,11 +44,12 @@ class MT5ForSequenceClassification(nn.Module):
         return ModelOutput(logits=logits)
 
     def save_pretrained(self, checkpoint_dir):
-        pass  # TODO: implement
+        torch.save(self.state_dict(), checkpoint_dir + ".pt")
+        return
 
     def from_pretrained(self, checkpoint_dir):
-        # TODO: implement
-        return self
+        checkpoint_dir += ".pt"
+        return self.load_state_dict(torch.load(checkpoint_dir))
 
 
 class TextualTransformerGen(ViewGen, TransformerGen):
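The new save/load pair is a plain state-dict round trip with a ".pt" suffix appended to checkpoint_dir. A self-contained sketch of the same pattern, using a hypothetical TinyClassifier stand-in since MT5ForSequenceClassification's constructor arguments are not shown in this diff:

import torch
import torch.nn as nn

class TinyClassifier(nn.Module):  # hypothetical stand-in for MT5ForSequenceClassification
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

    def save_pretrained(self, checkpoint_dir):
        torch.save(self.state_dict(), checkpoint_dir + ".pt")
        return

    def from_pretrained(self, checkpoint_dir):
        checkpoint_dir += ".pt"
        return self.load_state_dict(torch.load(checkpoint_dir))

model = TinyClassifier()
model.save_pretrained("checkpoint")           # writes checkpoint.pt
report = model.from_pretrained("checkpoint")  # restores weights in place
print(report)

Note that nn.Module.load_state_dict returns a report of missing and unexpected keys rather than the module itself, so from_pretrained as implemented returns that report; callers should keep using their existing model reference after loading.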
@@ -165,9 +166,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )
 
-        experiment_name = (
-            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-        )
+        experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
 
         trainer = Trainer(
             model=self.model,
@@ -179,12 +178,17 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             evaluate_step=self.evaluate_step,
             patience=self.patience,
             experiment_name=experiment_name,
-            checkpoint_path="models/vgfs/transformer",
+            checkpoint_path=os.path.join(
+                "models",
+                "vgfs",
+                "transformer",
+                self._format_model_name(self.model_name),
+            ),
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            # scheduler_name="ReduceLROnPlateau",
-            scheduler_name=None,
+            scheduler_name="ReduceLROnPlateau",
+            # scheduler_name=None,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
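The hunk above also flips the previously commented-out scheduler on. Assuming Trainer resolves scheduler_name="ReduceLROnPlateau" to torch.optim.lr_scheduler.ReduceLROnPlateau, this is the behavior it enables, sketched in isolation with hypothetical hyperparameters (Trainer's own defaults are not shown in this diff): the learning rate is cut once the monitored validation metric stops improving.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)

for epoch in range(10):
    val_loss = 1.0  # stand-in for the epoch's validation loss
    scheduler.step(val_loss)  # after `patience` non-improving epochs, lr is multiplied by `factor`
    print(optimizer.param_groups[0]["lr"])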
@@ -259,39 +263,17 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         for param in self.model.parameters():
             param.requires_grad = False
 
+    def _format_model_name(self, model_name):
+        if "mt5" in model_name:
+            return "google-mt5"
+        elif "bert" in model_name:
+            if "multilingual" in model_name:
+                return "mbert"
+        elif "xlm" in model_name:
+            return "xlm"
+        else:
+            return model_name
+
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
-
-
-class MultilingualDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-        return self
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
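With _format_model_name in place, each backbone gets its own checkpoint directory. A quick illustration of the resulting paths (model names are examples; the body is copied from the hunk above as a free function):

import os

def _format_model_name(model_name):
    if "mt5" in model_name:
        return "google-mt5"
    elif "bert" in model_name:
        if "multilingual" in model_name:
            return "mbert"
    elif "xlm" in model_name:
        return "xlm"
    else:
        return model_name

for name in ["google/mt5-small", "bert-base-multilingual-cased", "xlm-mlm-100-1280"]:
    print(os.path.join("models", "vgfs", "transformer", _format_model_name(name)))
# models/vgfs/transformer/google-mt5
# models/vgfs/transformer/mbert
# models/vgfs/transformer/xlm

One caveat, as laid out above: a name like "xlm-roberta-base" matches the "bert" branch first (via "roberta") and, lacking "multilingual", falls through to None, so the mapping only covers the backbones listed here.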

View File

@@ -4,12 +4,12 @@ import numpy as np
 import torch
 import transformers
 from PIL import Image
-from torch.utils.data import Dataset
 from transformers import AutoImageProcessor, AutoModelForImageClassification
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultilingualDatasetTorch
 
 transformers.logging.set_verbosity_error()
@@ -186,63 +186,3 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
-
-
-class MultimodalDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([imgs for imgs in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
-
-
-if __name__ == "__main__":
-    from os.path import expanduser
-    from dataManager.gFunDataset import gFunDataset
-
-    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
-
-    dataset = gFunDataset(
-        dataset_dir=GLAMI_DATAPATH,
-        is_textual=True,
-        is_visual=True,
-        is_multilabel=False,
-        nrows=50,
-    )
-
-    vg = VisualTransformerGen(
-        dataset_name=dataset.dataset_name,
-        model_name="vit",
-        device="cuda",
-        epochs=5,
-        evaluate_step=10,
-        patience=10,
-        probabilistic=True,
-    )
-
-    lX, lY = dataset.training()
-    vg.fit(lX, lY)
-    out = vg.transform(lX)
-
-    exit(0)

View File

@@ -1,13 +1,13 @@
 beautifulsoup4==4.11.2
 joblib==1.2.0
-matplotlib==3.7.1
-numpy==1.24.2
+matplotlib==3.6.3
+numpy==1.24.1
 pandas==1.5.3
 Pillow==9.4.0
 requests==2.28.2
-scikit_learn==1.2.1
+scikit_learn==1.2.2
 scipy==1.10.1
 torch==1.13.1
 torchtext==0.14.1
-tqdm==4.65.0
-transformers==4.26.1
+tqdm==4.64.1
+transformers==4.26.0