Visual VGF + MultiNewsDataset, working from data loading to testing

Andrea Pedrotti 2023-02-09 18:42:27 +01:00
parent 1a3f931c70
commit 9c2c43dafb
7 changed files with 156 additions and 228 deletions

View File

@@ -27,11 +27,11 @@ class MultiNewsDataset:
     def __init__(self, data_dir, excluded_langs=[], debug=False):
         self.debug = debug
         self.data_dir = data_dir
-        self.langs = self.get_langs()
+        self.dataset_langs = self.get_langs()
         self.excluded_langs = excluded_langs
         self.lang_multiModalDataset = {}
         print(
-            f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]"
+            f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {[l for l in self.dataset_langs if l not in self.excluded_langs]}]"
         )
         self.load_data()
         self.all_labels = self.get_labels()
@@ -39,12 +39,16 @@ class MultiNewsDataset:
         self.print_stats()

     def load_data(self):
-        for lang in self.langs:
+        for lang in self.dataset_langs:
             if lang not in self.excluded_langs:
                 self.lang_multiModalDataset[lang] = MultiModalDataset(
                     lang, join(self.data_dir, lang)
                 )

+    def langs(self):
+        return [l for l in self.dataset_langs if l not in self.excluded_langs]
+        return self.get_langs()
+
     def get_langs(self):
         from os import listdir
@@ -56,13 +60,14 @@ class MultiNewsDataset:
     def print_stats(self):
         print(f"[MultiNewsDataset stats]")
         total_docs = 0
-        for lang in self.langs:
-            _len = len(self.lang_multiModalDataset[lang].data)
-            total_docs += _len
-            print(
-                f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
-            )
-        print(f" - total docs: {total_docs}")
+        for lang in self.dataset_langs:
+            if lang not in self.excluded_langs:
+                _len = len(self.lang_multiModalDataset[lang].data)
+                total_docs += _len
+                print(
+                    f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
+                )
+        print(f" - total docs: {total_docs}\n")

     def _count_lang_labels(self, labels):
         lang_labels = set()
@@ -77,11 +82,16 @@ class MultiNewsDataset:
         raise NotImplementedError

     def training(self):
+        # TODO: this is a (working) mess - clean this up
         lXtr = {}
         lYtr = {}
         for lang, data in self.lang_multiModalDataset.items():
-            lXtr[lang] = data.data
-            lYtr[lang] = self.label_binarizer.transform(data.labels)
+            _data = [clean_text for _, clean_text, _, _ in data.data]
+            lXtr[lang] = _data
+        lYtr = {
+            lang: self.label_binarizer.transform(data.labels)
+            for lang, data in self.lang_multiModalDataset.items()
+        }
         return lXtr, lYtr
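The training() method now returns, for every language, the cleaned text of each document together with the binarized labels. A minimal, self-contained sketch of the resulting lXtr / lYtr structures (illustrative data only, using sklearn's MultiLabelBinarizer the same way label_binarizer is used above):

    # Illustrative only: shows the {lang: data} shape of lXtr / lYtr, not project code.
    from sklearn.preprocessing import MultiLabelBinarizer

    lXtr = {"en": ["first clean text", "second clean text"]}             # {lang: [clean_text, ...]}
    label_binarizer = MultiLabelBinarizer().fit([["politics", "sport"]])
    lYtr = {"en": label_binarizer.transform([["sport"], ["politics"]])}  # {lang: binary label matrix}
    print(lYtr["en"])  # [[0 1]
                       #  [1 0]]  one row per document, one column per label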

View File

@@ -78,7 +78,6 @@ class GeneralizedFunnelling:
         if self.posteriors_vgf:
             fun = VanillaFunGen(
                 base_learner=get_learner(calibrate=True),
-                first_tier_parameters=None,
                 n_jobs=self.n_jobs,
             )
             self.first_tier_learners.append(fun)

View File

@@ -13,8 +13,8 @@ from torch.utils.data import DataLoader, Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from vgfs.learners.svms import FeatureSet2Posteriors
 from vgfs.viewGen import ViewGen
-from evaluation.evaluate import evaluate, log_eval
+from vgfs.transformerGen import TransformerGen
+from vgfs.commons import Trainer, predict

 transformers.logging.set_verbosity_error()
@@ -23,7 +23,7 @@ transformers.logging.set_verbosity_error()

 # TODO: multiple inheritance - maybe define a superclass for TransformerGenerator, whether it is a Textual or a Visual one, implementing dataset creation functions
-class TextualTransformerGen(ViewGen):
+class TextualTransformerGen(ViewGen, TransformerGen):
     def __init__(
         self,
         model_name,
@@ -40,23 +40,22 @@ class TextualTransformerGen(ViewGen):
         verbose=False,
         patience=5,
     ):
-        self.model_name = model_name
-        self.device = device
-        self.model = None
-        self.lr = lr
-        self.epochs = epochs
-        self.tokenizer = None
-        self.max_length = max_length
-        self.batch_size = batch_size
-        self.batch_size_eval = batch_size_eval
-        self.print_steps = print_steps
-        self.probabilistic = probabilistic
-        self.n_jobs = n_jobs
+        super().__init__(
+            model_name,
+            epochs,
+            lr,
+            batch_size,
+            batch_size_eval,
+            max_length,
+            print_steps,
+            device,
+            probabilistic,
+            n_jobs,
+            evaluate_step,
+            verbose,
+            patience,
+        )
         self.fitted = False
-        self.datasets = {}
-        self.evaluate_step = evaluate_step
-        self.verbose = verbose
-        self.patience = patience
         self._init()

     def _init(self):
@@ -93,25 +92,6 @@ class TextualTransformerGen(ViewGen):
             model_name
         )

-    def get_train_val_data(self, lX, lY, split=0.2, seed=42):
-        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}
-        for lang in lX.keys():
-            tr_X, val_X, tr_Y, val_Y = train_test_split(
-                lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
-            )
-            tr_lX[lang] = tr_X
-            tr_lY[lang] = tr_Y
-            val_lX[lang] = val_X
-            val_lY[lang] = val_Y
-        return tr_lX, tr_lY, val_lX, val_lY
-
-    def build_dataloader(self, lX, lY, batch_size, split="train", shuffle=False):
-        l_tokenized = {lang: self._tokenize(data) for lang, data in lX.items()}
-        self.datasets[split] = MultilingualDatasetTorch(l_tokenized, lY, split=split)
-        return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)
-
     def _tokenize(self, X):
         return self.tokenizer(
             X,
@@ -136,11 +116,23 @@ class TextualTransformerGen(ViewGen):
         )

         tra_dataloader = self.build_dataloader(
-            tr_lX, tr_lY, self.batch_size, split="train", shuffle=True
+            tr_lX,
+            tr_lY,
+            processor_fn=self._tokenize,
+            torchDataset=MultilingualDatasetTorch,
+            batch_size=self.batch_size,
+            split="train",
+            shuffle=True,
         )

         val_dataloader = self.build_dataloader(
-            val_lX, val_lY, self.batch_size_eval, split="val", shuffle=False
+            val_lX,
+            val_lY,
+            processor_fn=self._tokenize,
+            torchDataset=MultilingualDatasetTorch,
+            batch_size=self.batch_size_eval,
+            split="val",
+            shuffle=False,
         )

         experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
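Both call sites above now delegate tokenisation and dataset construction to the build_dataloader helper inherited from TransformerGen, passing the processor function and the torch Dataset class explicitly. The helper itself is not shown in this diff; a plausible sketch, generalising the per-class version removed above, could look like this (the body is an assumption, only the keyword arguments are taken from the calls):

    # Assumed shape of the shared TransformerGen.build_dataloader (not part of this commit).
    from torch.utils.data import DataLoader

    def build_dataloader(self, lX, lY, processor_fn, torchDataset, batch_size, split="train", shuffle=False):
        # processor_fn maps raw per-language inputs to model-ready tensors (e.g. self._tokenize)
        l_processed = {lang: processor_fn(data) for lang, data in lX.items()}
        self.datasets[split] = torchDataset(l_processed, lY, split=split)
        return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)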
@@ -173,7 +165,13 @@ class TextualTransformerGen(ViewGen):
         l_embeds = defaultdict(list)

         dataloader = self.build_dataloader(
-            lX, lY=None, batch_size=self.batch_size_eval, split="whole", shuffle=False
+            lX,
+            lY=None,
+            processor_fn=self._tokenize,
+            torchDataset=MultilingualDatasetTorch,
+            batch_size=self.batch_size_eval,
+            split="whole",
+            shuffle=False,
         )

         self.model.eval()
@@ -245,146 +243,3 @@ class MultilingualDatasetTorch(Dataset):
         if self.split == "whole":
             return self.X[index], self.langs[index]
         return self.X[index], self.Y[index], self.langs[index]
-
-
-class Trainer:
-    def __init__(
-        self,
-        model,
-        optimizer_name,
-        device,
-        loss_fn,
-        lr,
-        print_steps,
-        evaluate_step,
-        patience,
-        experiment_name,
-    ):
-        self.device = device
-        self.model = model.to(device)
-        self.optimizer = self.init_optimizer(optimizer_name, lr)
-        self.evaluate_steps = evaluate_step
-        self.loss_fn = loss_fn.to(device)
-        self.print_steps = print_steps
-        self.earlystopping = EarlyStopping(
-            patience=patience,
-            checkpoint_path="models/vgfs/transformers/",
-            verbose=True,
-            experiment_name=experiment_name,
-        )
-
-    def init_optimizer(self, optimizer_name, lr):
-        if optimizer_name.lower() == "adamw":
-            return AdamW(self.model.parameters(), lr=lr)
-        else:
-            raise ValueError(f"Optimizer {optimizer_name} not supported")
-
-    def train(self, train_dataloader, eval_dataloader, epochs=10):
-        print(
-            f"""- Training params:
-            - epochs: {epochs}
-            - learning rate: {self.optimizer.defaults['lr']}
-            - train batch size: {train_dataloader.batch_size}
-            - eval batch size: {eval_dataloader.batch_size}
-            - max len: {train_dataloader.dataset.X.shape[-1]}\n""",
-        )
-        for epoch in range(epochs):
-            self.train_epoch(train_dataloader, epoch)
-            if (epoch + 1) % self.evaluate_steps == 0:
-                metric_watcher = self.evaluate(eval_dataloader)
-                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
-                if stop:
-                    break
-        return self.model
-
-    def train_epoch(self, dataloader, epoch):
-        self.model.train()
-        for b_idx, (x, y, lang) in enumerate(dataloader):
-            self.optimizer.zero_grad()
-            y_hat = self.model(x.to(self.device))
-            loss = self.loss_fn(y_hat.logits, y.to(self.device))
-            loss.backward()
-            self.optimizer.step()
-            if b_idx % self.print_steps == 0:
-                print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
-        return self
-
-    def evaluate(self, dataloader):
-        self.model.eval()
-        lY = defaultdict(list)
-        lY_hat = defaultdict(list)
-        for b_idx, (x, y, lang) in enumerate(dataloader):
-            y_hat = self.model(x.to(self.device))
-            loss = self.loss_fn(y_hat.logits, y.to(self.device))
-            predictions = predict(y_hat.logits, classification_type="multilabel")
-            for l, _true, _pred in zip(lang, y, predictions):
-                lY[l].append(_true.detach().cpu().numpy())
-                lY_hat[l].append(_pred)
-        for lang in lY:
-            lY[lang] = np.vstack(lY[lang])
-            lY_hat[lang] = np.vstack(lY_hat[lang])
-        l_eval = evaluate(lY, lY_hat)
-        average_metrics = log_eval(l_eval, phase="validation")
-        return average_metrics[0]  # macro-F1
-
-
-class EarlyStopping:
-    def __init__(
-        self,
-        patience=5,
-        min_delta=0,
-        verbose=True,
-        checkpoint_path="checkpoint.pt",
-        experiment_name="experiment",
-    ):
-        self.patience = patience
-        self.min_delta = min_delta
-        self.counter = 0
-        self.best_score = 0
-        self.best_epoch = None
-        self.verbose = verbose
-        self.checkpoint_path = checkpoint_path
-        self.experiment_name = experiment_name
-
-    def __call__(self, validation, model, epoch):
-        if validation > self.best_score:
-            print(
-                f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
-            )
-            self.best_score = validation
-            self.counter = 0
-            # self.save_model(model)
-        elif validation < (self.best_score + self.min_delta):
-            self.counter += 1
-            print(
-                f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
-            )
-            if self.counter >= self.patience:
-                if self.verbose:
-                    print(f"- earlystopping: Early stopping at epoch {epoch}")
-                return True
-
-    def save_model(self, model):
-        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
-        print(f"- saving model to {_checkpoint_dir}")
-        os.makedirs(_checkpoint_dir, exist_ok=True)
-        model.save_pretrained(_checkpoint_dir)
-
-
-def predict(logits, classification_type="multilabel"):
-    """
-    Converts soft predictions to hard predictions [0,1]
-    """
-    if classification_type == "multilabel":
-        prediction = torch.sigmoid(logits) > 0.5
-    elif classification_type == "singlelabel":
-        prediction = torch.argmax(logits, dim=1).view(-1, 1)
-    else:
-        print("unknown classification type")
-    return prediction.detach().cpu().numpy()
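Trainer, EarlyStopping, and predict are not dropped for good: they move to vgfs/commons.py, which is now imported at the top of this file and by the visual generator. As a reminder of what predict does in the multilabel case, it thresholds sigmoid-activated logits at 0.5; a standalone check with made-up logits:

    # Standalone illustration of the multilabel thresholding performed by predict().
    import torch

    logits = torch.tensor([[2.0, -1.0, 0.3]])
    hard = (torch.sigmoid(logits) > 0.5).int()
    print(hard.tolist())  # [[1, 0, 1]] since sigmoid(0.3) ≈ 0.57 clears the 0.5 threshold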

View File

@@ -9,7 +9,39 @@ class TransformerGen:
     form of dictionaries {lang: data}
     """

-    def __init__(self):
+    def __init__(
+        self,
+        model_name,
+        epochs=10,
+        lr=1e-5,
+        batch_size=4,
+        batch_size_eval=32,
+        max_length=512,
+        print_steps=50,
+        device="cpu",
+        probabilistic=False,
+        n_jobs=-1,
+        evaluate_step=10,
+        verbose=False,
+        patience=5,
+    ):
+        self.model_name = model_name
+        self.device = device
+        self.model = None
+        self.lr = lr
+        self.epochs = epochs
+        self.tokenizer = None
+        self.max_length = max_length
+        self.batch_size = batch_size
+        self.batch_size_eval = batch_size_eval
+        self.print_steps = print_steps
+        self.probabilistic = probabilistic
+        self.n_jobs = n_jobs
+        self.fitted = False
+        self.datasets = {}
+        self.evaluate_step = evaluate_step
+        self.verbose = verbose
+        self.patience = patience
         self.datasets = {}

     def build_dataloader(
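With the hyper-parameters now stored on the base class, a concrete view generator only has to forward its arguments to super().__init__(), as TextualTransformerGen and VisualTransformerGen do in this commit. A tiny hypothetical subclass (DummyGen is not in the repository; it only makes the inherited defaults visible):

    # Hypothetical subclass for illustration; only the import path and defaults come from the diff.
    from gfun.vgfs.transformerGen import TransformerGen

    class DummyGen(TransformerGen):
        def __init__(self, model_name, **kwargs):
            super().__init__(model_name, **kwargs)  # hyper-params and self.datasets live in the base class

    gen = DummyGen("bert-base-multilingual-cased", epochs=3, device="cpu")
    print(gen.batch_size, gen.evaluate_step, gen.patience)  # 4 10 5, i.e. the defaults defined above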

View File

@@ -22,7 +22,6 @@ class VanillaFunGen(ViewGen):
         self.n_jobs = n_jobs
         self.doc_projector = NaivePolylingualClassifier(
             base_learner=self.learners,
-            parameters=self.first_tier_parameters,
             n_jobs=self.n_jobs,
         )
         self.vectorizer = None

View File

@@ -10,21 +10,33 @@ from torch.utils.data import DataLoader, Dataset
 from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

 from gfun.vgfs.commons import Trainer, predict
 from gfun.vgfs.transformerGen import TransformerGen
-from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
+from transformers import AutoModelForImageClassification

 transformers.logging.set_verbosity_error()


 class VisualTransformerGen(ViewGen, TransformerGen):
     def __init__(
-        self, model_name, lr=1e-5, epochs=10, batch_size=32, batch_size_eval=128
+        self,
+        model_name,
+        lr=1e-5,
+        epochs=10,
+        batch_size=32,
+        batch_size_eval=128,
+        evaluate_step=10,
+        device="cpu",
+        patience=5,
     ):
-        self.model_name = model_name
-        self.datasets = {}
-        self.lr = lr
-        self.epochs = epochs
-        self.batch_size = batch_size
-        self.batch_size_eval = batch_size_eval
+        super().__init__(
+            model_name,
+            lr=lr,
+            epochs=epochs,
+            batch_size=batch_size,
+            batch_size_eval=batch_size_eval,
+            device=device,
+            evaluate_step=evaluate_step,
+            patience=patience,
+        )

     def _validate_model_name(self, model_name):
         if "vit" == model_name:
@@ -33,10 +45,8 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             raise NotImplementedError

     def init_model(self, model_name, num_labels):
-        model = (
-            AutoModelForImageClassification.from_pretrained(
-                model_name, num_labels=num_labels
-            ),
-        )
+        model = AutoModelForImageClassification.from_pretrained(
+            model_name, num_labels=num_labels
+        )
         image_processor = AutoImageProcessor.from_pretrained(model_name)
         transforms = self.init_preprocessor(image_processor)
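The init_model change above also fixes a subtle bug: the old expression ended with a trailing comma, so model was bound to a one-element tuple instead of the model object, and any later call such as model.to(device) would fail. A two-line reproduction of the pitfall, using a stand-in value:

    # The trailing comma turns the assignment into a tuple; "pretend-model" stands in for the HF model.
    model = ("pretend-model",)
    print(type(model))  # <class 'tuple'>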
@@ -100,9 +110,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",
-            lr=self.lr,
             device=self.device,
             loss_fn=torch.nn.CrossEntropyLoss(),
+            lr=self.lr,
             print_steps=self.print_steps,
             evaluate_step=self.evaluate_step,
             patience=self.patience,
@@ -111,7 +121,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         trainer.train(
             train_dataloader=tra_dataloader,
-            val_dataloader=val_dataloader,
+            eval_dataloader=val_dataloader,
             epochs=self.epochs,
         )
@@ -149,7 +159,7 @@ class MultimodalDatasetTorch(Dataset):
             ],
             [],
         )
-        print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")
+        # print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")

     def __len__(self):
         return len(self.X)
@@ -169,7 +179,8 @@ if __name__ == "__main__":
     dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
     lXtr, lYtr = dataset.training()

-    vg = VisualTransformerGen(model_name="vit")
+    vg = VisualTransformerGen(
+        model_name="vit", device="cuda", epochs=1000, evaluate_step=10, patience=100
+    )
     lX, lY = dataset.training()
     vg.fit(lX, lY)
-    print("lel")

main.py
View File

@@ -16,26 +16,46 @@ TODO:
 """

-def main(args):
-    # Loading dataset ------------------------
+
+def get_dataset(datasetname):
+    assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"
     RCV_DATAPATH = expanduser(
         "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
     )
-    # dataset = MultiNewsDataset(expanduser(args.dataset_path))
-    # dataset = AmazonDataset(domains=args.domains,nrows=args.nrows,min_count=args.min_count,max_labels=args.max_labels)
-    dataset = (
-        MultilingualDataset(dataset_name="rcv1-2")
-        .load(RCV_DATAPATH)
-        .reduce_data(langs=["en", "it", "fr"], maxn=100)
-    )
-
-    if isinstance(dataset, MultilingualDataset):
+    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
+    if datasetname == "multinews":
+        dataset = MultiNewsDataset(
+            expanduser(MULTINEWS_DATAPATH),
+            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
+        )
+    elif datasetname == "amazon":
+        dataset = AmazonDataset(
+            domains=args.domains,
+            nrows=args.nrows,
+            min_count=args.min_count,
+            max_labels=args.max_labels,
+        )
+    elif datasetname == "rcv1-2":
+        dataset = (
+            MultilingualDataset(dataset_name="rcv1-2")
+            .load(RCV_DATAPATH)
+            .reduce_data(langs=["en", "it", "fr"], maxn=100)
+        )
+    else:
+        raise NotImplementedError
+    return dataset
+
+
+def main(args):
+    dataset = get_dataset(args.dataset)
+
+    if isinstance(dataset, MultilingualDataset) or isinstance(
+        dataset, MultiNewsDataset
+    ):
         lX, lY = dataset.training()
-        lX_te, lY_te = dataset.test()
+        # lX_te, lY_te = dataset.test()
+        lX_te, lY_te = dataset.training()
     else:
         _lX = dataset.dX
         _lY = dataset.dY
-    # ----------------------------------------

     tinit = time()
@@ -74,6 +94,7 @@ def main(args):
     gfun.fit(lX, lY)

     if args.load_trained is None:
+        print("[NB: FORCE-SKIPPING MODEL SAVE]")
         gfun.save()

     # if not args.load_model:
@@ -98,6 +119,7 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-l", "--load_trained", type=str, default=None)
     # Dataset parameters -------------------
+    parser.add_argument("-d", "--dataset", type=str, default="multinews")
     parser.add_argument("--domains", type=str, default="all")
     parser.add_argument("--nrows", type=int, default=10000)
     parser.add_argument("--min_count", type=int, default=10)