Visual VGF + MultiNewsDataset, working from data loading to testing
This commit is contained in:
parent
1a3f931c70
commit
9c2c43dafb
|
@ -27,11 +27,11 @@ class MultiNewsDataset:
|
||||||
def __init__(self, data_dir, excluded_langs=[], debug=False):
|
def __init__(self, data_dir, excluded_langs=[], debug=False):
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.data_dir = data_dir
|
self.data_dir = data_dir
|
||||||
self.langs = self.get_langs()
|
self.dataset_langs = self.get_langs()
|
||||||
self.excluded_langs = excluded_langs
|
self.excluded_langs = excluded_langs
|
||||||
self.lang_multiModalDataset = {}
|
self.lang_multiModalDataset = {}
|
||||||
print(
|
print(
|
||||||
f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]"
|
f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {[l for l in self.dataset_langs if l not in self.excluded_langs]}]"
|
||||||
)
|
)
|
||||||
self.load_data()
|
self.load_data()
|
||||||
self.all_labels = self.get_labels()
|
self.all_labels = self.get_labels()
|
||||||
|
@ -39,12 +39,16 @@ class MultiNewsDataset:
|
||||||
self.print_stats()
|
self.print_stats()
|
||||||
|
|
||||||
def load_data(self):
|
def load_data(self):
|
||||||
for lang in self.langs:
|
for lang in self.dataset_langs:
|
||||||
if lang not in self.excluded_langs:
|
if lang not in self.excluded_langs:
|
||||||
self.lang_multiModalDataset[lang] = MultiModalDataset(
|
self.lang_multiModalDataset[lang] = MultiModalDataset(
|
||||||
lang, join(self.data_dir, lang)
|
lang, join(self.data_dir, lang)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def langs(self):
|
||||||
|
return [l for l in self.dataset_langs if l not in self.excluded_langs]
|
||||||
|
return self.get_langs()
|
||||||
|
|
||||||
def get_langs(self):
|
def get_langs(self):
|
||||||
from os import listdir
|
from os import listdir
|
||||||
|
|
||||||
|
@ -56,13 +60,14 @@ class MultiNewsDataset:
|
||||||
def print_stats(self):
|
def print_stats(self):
|
||||||
print(f"[MultiNewsDataset stats]")
|
print(f"[MultiNewsDataset stats]")
|
||||||
total_docs = 0
|
total_docs = 0
|
||||||
for lang in self.langs:
|
for lang in self.dataset_langs:
|
||||||
_len = len(self.lang_multiModalDataset[lang].data)
|
if lang not in self.excluded_langs:
|
||||||
total_docs += _len
|
_len = len(self.lang_multiModalDataset[lang].data)
|
||||||
print(
|
total_docs += _len
|
||||||
f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
|
print(
|
||||||
)
|
f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
|
||||||
print(f" - total docs: {total_docs}")
|
)
|
||||||
|
print(f" - total docs: {total_docs}\n")
|
||||||
|
|
||||||
def _count_lang_labels(self, labels):
|
def _count_lang_labels(self, labels):
|
||||||
lang_labels = set()
|
lang_labels = set()
|
||||||
|
@ -77,11 +82,16 @@ class MultiNewsDataset:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def training(self):
|
def training(self):
|
||||||
|
# TODO: this is a (working) mess - clean this up
|
||||||
lXtr = {}
|
lXtr = {}
|
||||||
lYtr = {}
|
lYtr = {}
|
||||||
for lang, data in self.lang_multiModalDataset.items():
|
for lang, data in self.lang_multiModalDataset.items():
|
||||||
lXtr[lang] = data.data
|
_data = [clean_text for _, clean_text, _, _ in data.data]
|
||||||
lYtr[lang] = self.label_binarizer.transform(data.labels)
|
lXtr[lang] = _data
|
||||||
|
lYtr = {
|
||||||
|
lang: self.label_binarizer.transform(data.labels)
|
||||||
|
for lang, data in self.lang_multiModalDataset.items()
|
||||||
|
}
|
||||||
|
|
||||||
return lXtr, lYtr
|
return lXtr, lYtr
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,6 @@ class GeneralizedFunnelling:
|
||||||
if self.posteriors_vgf:
|
if self.posteriors_vgf:
|
||||||
fun = VanillaFunGen(
|
fun = VanillaFunGen(
|
||||||
base_learner=get_learner(calibrate=True),
|
base_learner=get_learner(calibrate=True),
|
||||||
first_tier_parameters=None,
|
|
||||||
n_jobs=self.n_jobs,
|
n_jobs=self.n_jobs,
|
||||||
)
|
)
|
||||||
self.first_tier_learners.append(fun)
|
self.first_tier_learners.append(fun)
|
||||||
|
|
|
@ -13,8 +13,8 @@ from torch.utils.data import DataLoader, Dataset
|
||||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||||
from vgfs.learners.svms import FeatureSet2Posteriors
|
from vgfs.learners.svms import FeatureSet2Posteriors
|
||||||
from vgfs.viewGen import ViewGen
|
from vgfs.viewGen import ViewGen
|
||||||
|
from vgfs.transformerGen import TransformerGen
|
||||||
from evaluation.evaluate import evaluate, log_eval
|
from vgfs.commons import Trainer, predict
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ transformers.logging.set_verbosity_error()
|
||||||
# TODO: multiple inheritance - maybe define a superclass for TransformerGenerator, whether it is a Textual or a Visual one, implementing dataset creation functions
|
# TODO: multiple inheritance - maybe define a superclass for TransformerGenerator, whether it is a Textual or a Visual one, implementing dataset creation functions
|
||||||
|
|
||||||
|
|
||||||
class TextualTransformerGen(ViewGen):
|
class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name,
|
model_name,
|
||||||
|
@ -40,23 +40,22 @@ class TextualTransformerGen(ViewGen):
|
||||||
verbose=False,
|
verbose=False,
|
||||||
patience=5,
|
patience=5,
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
super().__init__(
|
||||||
self.device = device
|
model_name,
|
||||||
self.model = None
|
epochs,
|
||||||
self.lr = lr
|
lr,
|
||||||
self.epochs = epochs
|
batch_size,
|
||||||
self.tokenizer = None
|
batch_size_eval,
|
||||||
self.max_length = max_length
|
max_length,
|
||||||
self.batch_size = batch_size
|
print_steps,
|
||||||
self.batch_size_eval = batch_size_eval
|
device,
|
||||||
self.print_steps = print_steps
|
probabilistic,
|
||||||
self.probabilistic = probabilistic
|
n_jobs,
|
||||||
self.n_jobs = n_jobs
|
evaluate_step,
|
||||||
|
verbose,
|
||||||
|
patience,
|
||||||
|
)
|
||||||
self.fitted = False
|
self.fitted = False
|
||||||
self.datasets = {}
|
|
||||||
self.evaluate_step = evaluate_step
|
|
||||||
self.verbose = verbose
|
|
||||||
self.patience = patience
|
|
||||||
self._init()
|
self._init()
|
||||||
|
|
||||||
def _init(self):
|
def _init(self):
|
||||||
|
@ -93,25 +92,6 @@ class TextualTransformerGen(ViewGen):
|
||||||
model_name
|
model_name
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_train_val_data(self, lX, lY, split=0.2, seed=42):
|
|
||||||
tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}
|
|
||||||
|
|
||||||
for lang in lX.keys():
|
|
||||||
tr_X, val_X, tr_Y, val_Y = train_test_split(
|
|
||||||
lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
|
|
||||||
)
|
|
||||||
tr_lX[lang] = tr_X
|
|
||||||
tr_lY[lang] = tr_Y
|
|
||||||
val_lX[lang] = val_X
|
|
||||||
val_lY[lang] = val_Y
|
|
||||||
|
|
||||||
return tr_lX, tr_lY, val_lX, val_lY
|
|
||||||
|
|
||||||
def build_dataloader(self, lX, lY, batch_size, split="train", shuffle=False):
|
|
||||||
l_tokenized = {lang: self._tokenize(data) for lang, data in lX.items()}
|
|
||||||
self.datasets[split] = MultilingualDatasetTorch(l_tokenized, lY, split=split)
|
|
||||||
return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)
|
|
||||||
|
|
||||||
def _tokenize(self, X):
|
def _tokenize(self, X):
|
||||||
return self.tokenizer(
|
return self.tokenizer(
|
||||||
X,
|
X,
|
||||||
|
@ -136,11 +116,23 @@ class TextualTransformerGen(ViewGen):
|
||||||
)
|
)
|
||||||
|
|
||||||
tra_dataloader = self.build_dataloader(
|
tra_dataloader = self.build_dataloader(
|
||||||
tr_lX, tr_lY, self.batch_size, split="train", shuffle=True
|
tr_lX,
|
||||||
|
tr_lY,
|
||||||
|
processor_fn=self._tokenize,
|
||||||
|
torchDataset=MultilingualDatasetTorch,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
split="train",
|
||||||
|
shuffle=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
val_dataloader = self.build_dataloader(
|
val_dataloader = self.build_dataloader(
|
||||||
val_lX, val_lY, self.batch_size_eval, split="val", shuffle=False
|
val_lX,
|
||||||
|
val_lY,
|
||||||
|
processor_fn=self._tokenize,
|
||||||
|
torchDataset=MultilingualDatasetTorch,
|
||||||
|
batch_size=self.batch_size_eval,
|
||||||
|
split="val",
|
||||||
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
|
experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
|
||||||
|
@ -173,7 +165,13 @@ class TextualTransformerGen(ViewGen):
|
||||||
l_embeds = defaultdict(list)
|
l_embeds = defaultdict(list)
|
||||||
|
|
||||||
dataloader = self.build_dataloader(
|
dataloader = self.build_dataloader(
|
||||||
lX, lY=None, batch_size=self.batch_size_eval, split="whole", shuffle=False
|
lX,
|
||||||
|
lY=None,
|
||||||
|
processor_fn=self._tokenize,
|
||||||
|
torchDataset=MultilingualDatasetTorch,
|
||||||
|
batch_size=self.batch_size_eval,
|
||||||
|
split="whole",
|
||||||
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
@ -245,146 +243,3 @@ class MultilingualDatasetTorch(Dataset):
|
||||||
if self.split == "whole":
|
if self.split == "whole":
|
||||||
return self.X[index], self.langs[index]
|
return self.X[index], self.langs[index]
|
||||||
return self.X[index], self.Y[index], self.langs[index]
|
return self.X[index], self.Y[index], self.langs[index]
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
optimizer_name,
|
|
||||||
device,
|
|
||||||
loss_fn,
|
|
||||||
lr,
|
|
||||||
print_steps,
|
|
||||||
evaluate_step,
|
|
||||||
patience,
|
|
||||||
experiment_name,
|
|
||||||
):
|
|
||||||
self.device = device
|
|
||||||
self.model = model.to(device)
|
|
||||||
self.optimizer = self.init_optimizer(optimizer_name, lr)
|
|
||||||
self.evaluate_steps = evaluate_step
|
|
||||||
self.loss_fn = loss_fn.to(device)
|
|
||||||
self.print_steps = print_steps
|
|
||||||
self.earlystopping = EarlyStopping(
|
|
||||||
patience=patience,
|
|
||||||
checkpoint_path="models/vgfs/transformers/",
|
|
||||||
verbose=True,
|
|
||||||
experiment_name=experiment_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_optimizer(self, optimizer_name, lr):
|
|
||||||
if optimizer_name.lower() == "adamw":
|
|
||||||
return AdamW(self.model.parameters(), lr=lr)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Optimizer {optimizer_name} not supported")
|
|
||||||
|
|
||||||
def train(self, train_dataloader, eval_dataloader, epochs=10):
|
|
||||||
print(
|
|
||||||
f"""- Training params:
|
|
||||||
- epochs: {epochs}
|
|
||||||
- learning rate: {self.optimizer.defaults['lr']}
|
|
||||||
- train batch size: {train_dataloader.batch_size}
|
|
||||||
- eval batch size: {eval_dataloader.batch_size}
|
|
||||||
- max len: {train_dataloader.dataset.X.shape[-1]}\n""",
|
|
||||||
)
|
|
||||||
for epoch in range(epochs):
|
|
||||||
self.train_epoch(train_dataloader, epoch)
|
|
||||||
if (epoch + 1) % self.evaluate_steps == 0:
|
|
||||||
metric_watcher = self.evaluate(eval_dataloader)
|
|
||||||
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
|
|
||||||
if stop:
|
|
||||||
break
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
def train_epoch(self, dataloader, epoch):
|
|
||||||
self.model.train()
|
|
||||||
for b_idx, (x, y, lang) in enumerate(dataloader):
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
y_hat = self.model(x.to(self.device))
|
|
||||||
loss = self.loss_fn(y_hat.logits, y.to(self.device))
|
|
||||||
loss.backward()
|
|
||||||
self.optimizer.step()
|
|
||||||
if b_idx % self.print_steps == 0:
|
|
||||||
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
|
|
||||||
return self
|
|
||||||
|
|
||||||
def evaluate(self, dataloader):
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
lY = defaultdict(list)
|
|
||||||
lY_hat = defaultdict(list)
|
|
||||||
|
|
||||||
for b_idx, (x, y, lang) in enumerate(dataloader):
|
|
||||||
y_hat = self.model(x.to(self.device))
|
|
||||||
loss = self.loss_fn(y_hat.logits, y.to(self.device))
|
|
||||||
predictions = predict(y_hat.logits, classification_type="multilabel")
|
|
||||||
|
|
||||||
for l, _true, _pred in zip(lang, y, predictions):
|
|
||||||
lY[l].append(_true.detach().cpu().numpy())
|
|
||||||
lY_hat[l].append(_pred)
|
|
||||||
|
|
||||||
for lang in lY:
|
|
||||||
lY[lang] = np.vstack(lY[lang])
|
|
||||||
lY_hat[lang] = np.vstack(lY_hat[lang])
|
|
||||||
|
|
||||||
l_eval = evaluate(lY, lY_hat)
|
|
||||||
average_metrics = log_eval(l_eval, phase="validation")
|
|
||||||
return average_metrics[0] # macro-F1
|
|
||||||
|
|
||||||
|
|
||||||
class EarlyStopping:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
patience=5,
|
|
||||||
min_delta=0,
|
|
||||||
verbose=True,
|
|
||||||
checkpoint_path="checkpoint.pt",
|
|
||||||
experiment_name="experiment",
|
|
||||||
):
|
|
||||||
self.patience = patience
|
|
||||||
self.min_delta = min_delta
|
|
||||||
self.counter = 0
|
|
||||||
self.best_score = 0
|
|
||||||
self.best_epoch = None
|
|
||||||
self.verbose = verbose
|
|
||||||
self.checkpoint_path = checkpoint_path
|
|
||||||
self.experiment_name = experiment_name
|
|
||||||
|
|
||||||
def __call__(self, validation, model, epoch):
|
|
||||||
if validation > self.best_score:
|
|
||||||
print(
|
|
||||||
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
|
|
||||||
)
|
|
||||||
self.best_score = validation
|
|
||||||
self.counter = 0
|
|
||||||
# self.save_model(model)
|
|
||||||
elif validation < (self.best_score + self.min_delta):
|
|
||||||
self.counter += 1
|
|
||||||
print(
|
|
||||||
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
|
|
||||||
)
|
|
||||||
if self.counter >= self.patience:
|
|
||||||
if self.verbose:
|
|
||||||
print(f"- earlystopping: Early stopping at epoch {epoch}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def save_model(self, model):
|
|
||||||
_checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
|
|
||||||
print(f"- saving model to {_checkpoint_dir}")
|
|
||||||
os.makedirs(_checkpoint_dir, exist_ok=True)
|
|
||||||
model.save_pretrained(_checkpoint_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def predict(logits, classification_type="multilabel"):
|
|
||||||
"""
|
|
||||||
Converts soft precictions to hard predictions [0,1]
|
|
||||||
"""
|
|
||||||
if classification_type == "multilabel":
|
|
||||||
prediction = torch.sigmoid(logits) > 0.5
|
|
||||||
elif classification_type == "singlelabel":
|
|
||||||
prediction = torch.argmax(logits, dim=1).view(-1, 1)
|
|
||||||
else:
|
|
||||||
print("unknown classification type")
|
|
||||||
|
|
||||||
return prediction.detach().cpu().numpy()
|
|
||||||
|
|
|
@ -9,7 +9,39 @@ class TransformerGen:
|
||||||
form of dictioanries {lang: data}
|
form of dictioanries {lang: data}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name,
|
||||||
|
epochs=10,
|
||||||
|
lr=1e-5,
|
||||||
|
batch_size=4,
|
||||||
|
batch_size_eval=32,
|
||||||
|
max_length=512,
|
||||||
|
print_steps=50,
|
||||||
|
device="cpu",
|
||||||
|
probabilistic=False,
|
||||||
|
n_jobs=-1,
|
||||||
|
evaluate_step=10,
|
||||||
|
verbose=False,
|
||||||
|
patience=5,
|
||||||
|
):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.device = device
|
||||||
|
self.model = None
|
||||||
|
self.lr = lr
|
||||||
|
self.epochs = epochs
|
||||||
|
self.tokenizer = None
|
||||||
|
self.max_length = max_length
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.batch_size_eval = batch_size_eval
|
||||||
|
self.print_steps = print_steps
|
||||||
|
self.probabilistic = probabilistic
|
||||||
|
self.n_jobs = n_jobs
|
||||||
|
self.fitted = False
|
||||||
|
self.datasets = {}
|
||||||
|
self.evaluate_step = evaluate_step
|
||||||
|
self.verbose = verbose
|
||||||
|
self.patience = patience
|
||||||
self.datasets = {}
|
self.datasets = {}
|
||||||
|
|
||||||
def build_dataloader(
|
def build_dataloader(
|
||||||
|
|
|
@ -22,7 +22,6 @@ class VanillaFunGen(ViewGen):
|
||||||
self.n_jobs = n_jobs
|
self.n_jobs = n_jobs
|
||||||
self.doc_projector = NaivePolylingualClassifier(
|
self.doc_projector = NaivePolylingualClassifier(
|
||||||
base_learner=self.learners,
|
base_learner=self.learners,
|
||||||
parameters=self.first_tier_parameters,
|
|
||||||
n_jobs=self.n_jobs,
|
n_jobs=self.n_jobs,
|
||||||
)
|
)
|
||||||
self.vectorizer = None
|
self.vectorizer = None
|
||||||
|
|
|
@ -10,21 +10,33 @@ from torch.utils.data import DataLoader, Dataset
|
||||||
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
|
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
|
||||||
from gfun.vgfs.commons import Trainer, predict
|
from gfun.vgfs.commons import Trainer, predict
|
||||||
from gfun.vgfs.transformerGen import TransformerGen
|
from gfun.vgfs.transformerGen import TransformerGen
|
||||||
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
|
from transformers import AutoModelForImageClassification
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
|
||||||
class VisualTransformerGen(ViewGen, TransformerGen):
|
class VisualTransformerGen(ViewGen, TransformerGen):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_name, lr=1e-5, epochs=10, batch_size=32, batch_size_eval=128
|
self,
|
||||||
|
model_name,
|
||||||
|
lr=1e-5,
|
||||||
|
epochs=10,
|
||||||
|
batch_size=32,
|
||||||
|
batch_size_eval=128,
|
||||||
|
evaluate_step=10,
|
||||||
|
device="cpu",
|
||||||
|
patience=5,
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
super().__init__(
|
||||||
self.datasets = {}
|
model_name,
|
||||||
self.lr = lr
|
lr=lr,
|
||||||
self.epochs = epochs
|
epochs=epochs,
|
||||||
self.batch_size = batch_size
|
batch_size=batch_size,
|
||||||
self.batch_size_eval = batch_size_eval
|
batch_size_eval=batch_size_eval,
|
||||||
|
device=device,
|
||||||
|
evaluate_step=evaluate_step,
|
||||||
|
patience=patience,
|
||||||
|
)
|
||||||
|
|
||||||
def _validate_model_name(self, model_name):
|
def _validate_model_name(self, model_name):
|
||||||
if "vit" == model_name:
|
if "vit" == model_name:
|
||||||
|
@ -33,10 +45,8 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def init_model(self, model_name, num_labels):
|
def init_model(self, model_name, num_labels):
|
||||||
model = (
|
model = AutoModelForImageClassification.from_pretrained(
|
||||||
AutoModelForImageClassification.from_pretrained(
|
model_name, num_labels=num_labels
|
||||||
model_name, num_labels=num_labels
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
image_processor = AutoImageProcessor.from_pretrained(model_name)
|
image_processor = AutoImageProcessor.from_pretrained(model_name)
|
||||||
transforms = self.init_preprocessor(image_processor)
|
transforms = self.init_preprocessor(image_processor)
|
||||||
|
@ -100,9 +110,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
optimizer_name="adamW",
|
optimizer_name="adamW",
|
||||||
lr=self.lr,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
loss_fn=torch.nn.CrossEntropyLoss(),
|
loss_fn=torch.nn.CrossEntropyLoss(),
|
||||||
|
lr=self.lr,
|
||||||
print_steps=self.print_steps,
|
print_steps=self.print_steps,
|
||||||
evaluate_step=self.evaluate_step,
|
evaluate_step=self.evaluate_step,
|
||||||
patience=self.patience,
|
patience=self.patience,
|
||||||
|
@ -111,7 +121,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
||||||
|
|
||||||
trainer.train(
|
trainer.train(
|
||||||
train_dataloader=tra_dataloader,
|
train_dataloader=tra_dataloader,
|
||||||
val_dataloader=val_dataloader,
|
eval_dataloader=val_dataloader,
|
||||||
epochs=self.epochs,
|
epochs=self.epochs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -149,7 +159,7 @@ class MultimodalDatasetTorch(Dataset):
|
||||||
],
|
],
|
||||||
[],
|
[],
|
||||||
)
|
)
|
||||||
print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")
|
# print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.X)
|
return len(self.X)
|
||||||
|
@ -169,7 +179,8 @@ if __name__ == "__main__":
|
||||||
dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
|
dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
|
||||||
lXtr, lYtr = dataset.training()
|
lXtr, lYtr = dataset.training()
|
||||||
|
|
||||||
vg = VisualTransformerGen(model_name="vit")
|
vg = VisualTransformerGen(
|
||||||
|
model_name="vit", device="cuda", epochs=1000, evaluate_step=10, patience=100
|
||||||
|
)
|
||||||
lX, lY = dataset.training()
|
lX, lY = dataset.training()
|
||||||
vg.fit(lX, lY)
|
vg.fit(lX, lY)
|
||||||
print("lel")
|
|
||||||
|
|
46
main.py
46
main.py
|
@ -16,26 +16,46 @@ TODO:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def get_dataset(datasetname):
|
||||||
# Loading dataset ------------------------
|
assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"
|
||||||
RCV_DATAPATH = expanduser(
|
RCV_DATAPATH = expanduser(
|
||||||
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
||||||
)
|
)
|
||||||
# dataset = MultiNewsDataset(expanduser(args.dataset_path))
|
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
|
||||||
# dataset = AmazonDataset(domains=args.domains,nrows=args.nrows,min_count=args.min_count,max_labels=args.max_labels)
|
if datasetname == "multinews":
|
||||||
dataset = (
|
dataset = MultiNewsDataset(
|
||||||
MultilingualDataset(dataset_name="rcv1-2")
|
expanduser(MULTINEWS_DATAPATH),
|
||||||
.load(RCV_DATAPATH)
|
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
|
||||||
.reduce_data(langs=["en", "it", "fr"], maxn=100)
|
)
|
||||||
)
|
elif datasetname == "amazon":
|
||||||
|
dataset = AmazonDataset(
|
||||||
|
domains=args.domains,
|
||||||
|
nrows=args.nrows,
|
||||||
|
min_count=args.min_count,
|
||||||
|
max_labels=args.max_labels,
|
||||||
|
)
|
||||||
|
elif datasetname == "rcv1-2":
|
||||||
|
dataset = (
|
||||||
|
MultilingualDataset(dataset_name="rcv1-2")
|
||||||
|
.load(RCV_DATAPATH)
|
||||||
|
.reduce_data(langs=["en", "it", "fr"], maxn=100)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return dataset
|
||||||
|
|
||||||
if isinstance(dataset, MultilingualDataset):
|
|
||||||
|
def main(args):
|
||||||
|
dataset = get_dataset(args.dataset)
|
||||||
|
if isinstance(dataset, MultilingualDataset) or isinstance(
|
||||||
|
dataset, MultiNewsDataset
|
||||||
|
):
|
||||||
lX, lY = dataset.training()
|
lX, lY = dataset.training()
|
||||||
lX_te, lY_te = dataset.test()
|
# lX_te, lY_te = dataset.test()
|
||||||
|
lX_te, lY_te = dataset.training()
|
||||||
else:
|
else:
|
||||||
_lX = dataset.dX
|
_lX = dataset.dX
|
||||||
_lY = dataset.dY
|
_lY = dataset.dY
|
||||||
# ----------------------------------------
|
|
||||||
|
|
||||||
tinit = time()
|
tinit = time()
|
||||||
|
|
||||||
|
@ -74,6 +94,7 @@ def main(args):
|
||||||
gfun.fit(lX, lY)
|
gfun.fit(lX, lY)
|
||||||
|
|
||||||
if args.load_trained is None:
|
if args.load_trained is None:
|
||||||
|
print("[NB: FORCE-SKIPPING MODEL SAVE]")
|
||||||
gfun.save()
|
gfun.save()
|
||||||
|
|
||||||
# if not args.load_model:
|
# if not args.load_model:
|
||||||
|
@ -98,6 +119,7 @@ if __name__ == "__main__":
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument("-l", "--load_trained", type=str, default=None)
|
parser.add_argument("-l", "--load_trained", type=str, default=None)
|
||||||
# Dataset parameters -------------------
|
# Dataset parameters -------------------
|
||||||
|
parser.add_argument("-d", "--dataset", type=str, default="multinews")
|
||||||
parser.add_argument("--domains", type=str, default="all")
|
parser.add_argument("--domains", type=str, default="all")
|
||||||
parser.add_argument("--nrows", type=int, default=10000)
|
parser.add_argument("--nrows", type=int, default=10000)
|
||||||
parser.add_argument("--min_count", type=int, default=10)
|
parser.add_argument("--min_count", type=int, default=10)
|
||||||
|
|
Loading…
Reference in New Issue