Implement save/load for MT5ForSequenceClassification; move torch Datasets to the dataManager module

Andrea Pedrotti 2023-03-16 10:31:34 +01:00
parent 56faaf2615
commit 9d43ebb23b
5 changed files with 98 additions and 112 deletions

dataManager/torchDataset.py

@@ -1,2 +1,66 @@
-class TorchMultiNewsDataset:
-    pass
+import torch
+from torch.utils.data import Dataset
+
+
+class MultilingualDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+        return self
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
+
+
+class MultimodalDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([imgs for imgs in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
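
For reference, a minimal sketch of how the moved class is meant to be consumed (variable names are hypothetical; it assumes lX maps language codes to tokenizer batches exposing input_ids, and lY maps them to label matrices):

    from torch.utils.data import DataLoader
    from dataManager.torchDataset import MultilingualDatasetTorch

    dataset = MultilingualDatasetTorch(lX, lY, split="train")  # lX/lY as described above
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    for X, y, langs in loader:  # the "whole" split yields (X, langs) pairs instead
        ...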

gfun/vgfs/commons.py

@@ -278,7 +278,7 @@ class Trainer:
loss = self.loss_fn(y_hat, y.to(self.device))
loss.backward()
self.optimizer.step()
-            batch_losses.append(loss.item())  # TODO: is this still on gpu?
+            batch_losses.append(loss.item())
if (epoch + 1) % PRINT_ON_EPOCH == 0:
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
print(
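
The deleted TODO was safe to resolve this way: Tensor.item() returns a plain Python number copied off the device, so batch_losses keeps no GPU tensors alive.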

gfun/vgfs/textualTransformerGen.py

@@ -9,13 +9,13 @@ import torch
import torch.nn as nn
import transformers
from transformers import MT5EncoderModel
-from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.modeling_outputs import ModelOutput

from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultilingualDatasetTorch

transformers.logging.set_verbosity_error()
@@ -44,11 +44,12 @@ class MT5ForSequenceClassification(nn.Module):
        return ModelOutput(logits=logits)

    def save_pretrained(self, checkpoint_dir):
-        pass  # TODO: implement
+        torch.save(self.state_dict(), checkpoint_dir + ".pt")
+        return

    def from_pretrained(self, checkpoint_dir):
-        # TODO: implement
-        return self
+        checkpoint_dir += ".pt"
+        return self.load_state_dict(torch.load(checkpoint_dir))
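
One caveat on the new from_pretrained: nn.Module.load_state_dict returns a NamedTuple of missing and unexpected keys, not the module itself, so callers should keep their own reference to the model. A round-trip sketch under that caveat (constructor arguments and paths are placeholders):

    model = MT5ForSequenceClassification(...)   # placeholder args
    model.save_pretrained("models/mt5-clf")     # writes models/mt5-clf.pt
    restored = MT5ForSequenceClassification(...)
    restored.from_pretrained("models/mt5-clf")  # loads the weights in place
    # keep using `restored`, not the NamedTuple that from_pretrained returns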
class TextualTransformerGen(ViewGen, TransformerGen):
@@ -165,9 +166,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
            shuffle=False,
        )

-        experiment_name = (
-            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-        )
+        experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
trainer = Trainer(
model=self.model,
@@ -179,12 +178,17 @@ class TextualTransformerGen(ViewGen, TransformerGen):
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
-            checkpoint_path="models/vgfs/transformer",
+            checkpoint_path=os.path.join(
+                "models",
+                "vgfs",
+                "transformer",
+                self._format_model_name(self.model_name),
+            ),
            vgf_name="textual_trf",
            classification_type=self.clf_type,
            n_jobs=self.n_jobs,
-            # scheduler_name="ReduceLROnPlateau",
-            scheduler_name=None,
+            scheduler_name="ReduceLROnPlateau",
+            # scheduler_name=None,
)
trainer.train(
train_dataloader=tra_dataloader,
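
This hunk also flips the default scheduler_name to "ReduceLROnPlateau". The Trainer internals are not shown in this diff, but a plateau scheduler is conventionally stepped on a validation metric, along these lines (helper names are placeholders):

    from torch.optim.lr_scheduler import ReduceLROnPlateau

    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=5)
    for epoch in range(epochs):
        train_one_epoch(model, train_loader)     # placeholder helper
        val_loss = evaluate(model, eval_loader)  # placeholder helper
        scheduler.step(val_loss)                 # reduces the LR when val_loss plateaus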
@@ -259,39 +263,17 @@ class TextualTransformerGen(ViewGen, TransformerGen):
for param in self.model.parameters():
param.requires_grad = False
def _format_model_name(self, model_name):
if "mt5" in model_name:
return "google-mt5"
elif "bert" in model_name:
if "multilingual" in model_name:
return "mbert"
elif "xlm" in model_name:
return "xlm"
else:
return model_name
def __str__(self):
str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
return str
-class MultilingualDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-        return self
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
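
With _format_model_name in the checkpoint path, each backbone now gets its own checkpoint directory. A rough illustration of the mapping, read off the helper shown in the context above (the example model names are hypothetical inputs):

    self._format_model_name("google/mt5-small")             # -> "google-mt5"
    self._format_model_name("bert-base-multilingual-cased") # -> "mbert"
    self._format_model_name("distilgpt2")                   # -> None: falls through every branch

Note that a model_name containing neither "mt5" nor "bert" implicitly returns None, which would make the os.path.join above raise a TypeError; a top-level `return model_name` fallback would make the helper total.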

gfun/vgfs/visualTransformerGen.py

@@ -4,12 +4,12 @@ import numpy as np
import torch
import transformers
from PIL import Image
-from torch.utils.data import Dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification

from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultimodalDatasetTorch

transformers.logging.set_verbosity_error()
@@ -186,63 +186,3 @@ class VisualTransformerGen(ViewGen, TransformerGen):
def __str__(self):
str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
return str
-class MultimodalDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([imgs for imgs in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
-
-
-if __name__ == "__main__":
-    from os.path import expanduser
-
-    from dataManager.gFunDataset import gFunDataset
-
-    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
-
-    dataset = gFunDataset(
-        dataset_dir=GLAMI_DATAPATH,
-        is_textual=True,
-        is_visual=True,
-        is_multilabel=False,
-        nrows=50,
-    )
-
-    vg = VisualTransformerGen(
-        dataset_name=dataset.dataset_name,
-        model_name="vit",
-        device="cuda",
-        epochs=5,
-        evaluate_step=10,
-        patience=10,
-        probabilistic=True,
-    )
-
-    lX, lY = dataset.training()
-    vg.fit(lX, lY)
-    out = vg.transform(lX)
-    exit(0)
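
With the class now living in dataManager, the removed smoke test translates directly; a minimal sketch of exercising the moved dataset (lX/lY shapes are assumptions, mirroring what VisualTransformerGen previously fed its local copy):

    from torch.utils.data import DataLoader
    from dataManager.torchDataset import MultimodalDatasetTorch

    # lX: {lang: stacked image tensors}, lY: {lang: label matrix} -- assumed layout
    ds = MultimodalDatasetTorch(lX, lY, split="train")
    X, y, langs = next(iter(DataLoader(ds, batch_size=8)))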

requirements.txt

@@ -1,13 +1,13 @@
beautifulsoup4==4.11.2
joblib==1.2.0
-matplotlib==3.7.1
-numpy==1.24.2
+matplotlib==3.6.3
+numpy==1.24.1
pandas==1.5.3
Pillow==9.4.0
requests==2.28.2
-scikit_learn==1.2.1
+scikit_learn==1.2.2
scipy==1.10.1
torch==1.13.1
torchtext==0.14.1
-tqdm==4.65.0
-transformers==4.26.1
+tqdm==4.64.1
+transformers==4.26.0