Implemented save/load for MT5ForSequenceClassification. Moved torch Datasets to the dataManager module.
This commit is contained in:
parent 56faaf2615
commit 9d43ebb23b

@@ -1,2 +1,66 @@
-class TorchMultiNewsDataset:
-    pass
+import torch
+from torch.utils.data import Dataset
+
+
+class MultilingualDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+        return self
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
+
+
+class MultimodalDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([imgs for imgs in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
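
For context, here is a minimal usage sketch of the relocated MultilingualDatasetTorch. It assumes lX maps language codes to tokenizer output exposing .input_ids and lY maps the same languages to label lists; FakeBatch and the toy shapes below are illustrative, not part of the commit:

    import torch
    from torch.utils.data import DataLoader
    from dataManager.torchDataset import MultilingualDatasetTorch

    class FakeBatch:  # stand-in for a tokenizer's BatchEncoding (illustrative)
        def __init__(self, input_ids):
            self.input_ids = input_ids

    lX = {"en": FakeBatch(torch.randint(0, 100, (4, 16))),
          "it": FakeBatch(torch.randint(0, 100, (4, 16)))}
    lY = {"en": [[1, 0]] * 4, "it": [[0, 1]] * 4}

    ds = MultilingualDatasetTorch(lX, lY, split="train")
    xb, yb, langs = next(iter(DataLoader(ds, batch_size=2)))
    # with split="whole", label stacking is skipped and __getitem__ yields (x, lang) pairs
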
@@ -278,7 +278,7 @@ class Trainer:
             loss = self.loss_fn(y_hat, y.to(self.device))
             loss.backward()
             self.optimizer.step()
-            batch_losses.append(loss.item()) # TODO: is this still on gpu?
+            batch_losses.append(loss.item())
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                     print(
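
The TODO removed above has a definite answer: Tensor.item() copies the scalar to the host and returns a plain Python float, so batch_losses never holds GPU memory. A quick check:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss = torch.tensor(0.25, device=device)
    assert isinstance(loss.item(), float)  # host-side Python float, detached from any device
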
@@ -9,13 +9,13 @@ import torch
 import torch.nn as nn
 import transformers
 from transformers import MT5EncoderModel
-from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers.modeling_outputs import ModelOutput
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultilingualDatasetTorch
 
 transformers.logging.set_verbosity_error()
 
@@ -44,11 +44,12 @@ class MT5ForSequenceClassification(nn.Module):
         return ModelOutput(logits=logits)
 
     def save_pretrained(self, checkpoint_dir):
-        pass # TODO: implement
+        torch.save(self.state_dict(), checkpoint_dir + ".pt")
+        return
 
     def from_pretrained(self, checkpoint_dir):
-        # TODO: implement
-        return self
+        checkpoint_dir += ".pt"
+        return self.load_state_dict(torch.load(checkpoint_dir))
 
 
 class TextualTransformerGen(ViewGen, TransformerGen):
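
A hedged round-trip sketch of the new save/load pair, with nn.Linear standing in for the real model and checkpoint_dir treated as a path prefix (".pt" is appended by both methods). One caveat: nn.Module.load_state_dict returns a report of missing/unexpected keys rather than the module, so from_pretrained now hands callers that report instead of self:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # illustrative stand-in for MT5ForSequenceClassification
    torch.save(model.state_dict(), "checkpoint.pt")              # mirrors save_pretrained
    report = model.load_state_dict(torch.load("checkpoint.pt"))  # mirrors from_pretrained
    print(report)  # "<All keys matched successfully>", a key report, not the model object
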
@@ -165,9 +166,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )
 
-        experiment_name = (
-            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-        )
+        experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
 
         trainer = Trainer(
             model=self.model,
@@ -179,12 +178,17 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             evaluate_step=self.evaluate_step,
             patience=self.patience,
             experiment_name=experiment_name,
-            checkpoint_path="models/vgfs/transformer",
+            checkpoint_path=os.path.join(
+                "models",
+                "vgfs",
+                "transformer",
+                self._format_model_name(self.model_name),
+            ),
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            # scheduler_name="ReduceLROnPlateau",
-            scheduler_name=None,
+            scheduler_name="ReduceLROnPlateau",
+            # scheduler_name=None,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
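
Since scheduler_name is now "ReduceLROnPlateau", note that this scheduler, unlike most, must be stepped with the monitored metric. A minimal sketch of the pattern, assuming the Trainer steps it on a validation loss:

    import torch

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=5)
    for epoch in range(3):
        val_loss = 1.0  # placeholder; normally the eval-split loss
        sched.step(val_loss)  # step() takes the monitored value, not an epoch count
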
@@ -259,39 +263,17 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         for param in self.model.parameters():
             param.requires_grad = False
 
+    def _format_model_name(self, model_name):
+        if "mt5" in model_name:
+            return "google-mt5"
+        elif "bert" in model_name:
+            if "multilingual" in model_name:
+                return "mbert"
+        elif "xlm" in model_name:
+            return "xlm"
+        else:
+            return model_name
+
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
-
-
-class MultilingualDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-
-        return self
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
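
Expected mappings for the new _format_model_name helper, using common Hugging Face names purely as examples. Two implicit fall-throughs are worth noting: a "bert" name without "multilingual" returns None (the inner if has no else), and a name like "xlm-roberta-base" hits the "bert" branch first because "roberta" contains "bert":

    # "google/mt5-small"             -> "google-mt5"
    # "bert-base-multilingual-cased" -> "mbert"
    # "xlm-mlm-100-1280"             -> "xlm"
    # "distilgpt2"                   -> "distilgpt2"  (else branch: returned unchanged)
    # "bert-base-cased"              -> None          (falls through the inner if)
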
@@ -4,12 +4,12 @@ import numpy as np
 import torch
 import transformers
 from PIL import Image
-from torch.utils.data import Dataset
 from transformers import AutoImageProcessor, AutoModelForImageClassification
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultilingualDatasetTorch
 
 transformers.logging.set_verbosity_error()
 
@@ -186,63 +186,3 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
-
-
-class MultimodalDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([imgs for imgs in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
-
-
-if __name__ == "__main__":
-    from os.path import expanduser
-
-    from dataManager.gFunDataset import gFunDataset
-
-    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
-    dataset = gFunDataset(
-        dataset_dir=GLAMI_DATAPATH,
-        is_textual=True,
-        is_visual=True,
-        is_multilabel=False,
-        nrows=50,
-    )
-
-    vg = VisualTransformerGen(
-        dataset_name=dataset.dataset_name,
-        model_name="vit",
-        device="cuda",
-        epochs=5,
-        evaluate_step=10,
-        patience=10,
-        probabilistic=True,
-    )
-    lX, lY = dataset.training()
-    vg.fit(lX, lY)
-    out = vg.transform(lX)
-    exit(0)
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
beautifulsoup4==4.11.2
|
beautifulsoup4==4.11.2
|
||||||
joblib==1.2.0
|
joblib==1.2.0
|
||||||
matplotlib==3.7.1
|
matplotlib==3.6.3
|
||||||
numpy==1.24.2
|
numpy==1.24.1
|
||||||
pandas==1.5.3
|
pandas==1.5.3
|
||||||
Pillow==9.4.0
|
Pillow==9.4.0
|
||||||
requests==2.28.2
|
requests==2.28.2
|
||||||
scikit_learn==1.2.1
|
scikit_learn==1.2.2
|
||||||
scipy==1.10.1
|
scipy==1.10.1
|
||||||
torch==1.13.1
|
torch==1.13.1
|
||||||
torchtext==0.14.1
|
torchtext==0.14.1
|
||||||
tqdm==4.65.0
|
tqdm==4.64.1
|
||||||
transformers==4.26.1
|
transformers==4.26.0
|
||||||
|
|
|
||||||