Compare commits

..

13 Commits

12 changed files with 476 additions and 99 deletions

1
.gitignore vendored
View File

@ -183,3 +183,4 @@ logger/*
explore_data.ipynb explore_data.ipynb
run.sh run.sh
wandb wandb
local_datasets

View File

@ -1,5 +1,6 @@
import sys import sys
import os import os
import xml.etree.ElementTree as ET
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
@ -8,13 +9,87 @@ import re
from dataManager.multilingualDataset import MultilingualDataset from dataManager.multilingualDataset import MultilingualDataset
CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/") CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
LANGS = ["de", "en", "fr", "jp"] CLS_UNPROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-unprocessed/")
# LANGS = ["de", "en", "fr", "jp"]
LANGS = ["de", "en", "fr"]
DOMAINS = ["books", "dvd", "music"] DOMAINS = ["books", "dvd", "music"]
regex = r":\d+" regex = r":\d+"
subst = "" subst = ""
def load_unprocessed_cls(reduce_target_space=False):
data = {}
data_tr = []
data_te = []
c_tr = 0
c_te = 0
for lang in LANGS:
data[lang] = {}
for domain in DOMAINS:
data[lang][domain] = {}
print(f"lang: {lang}, domain: {domain}")
for split in ["train", "test"]:
domain_data = []
fdir = os.path.join(
CLS_UNPROCESSED_DATA_DIR, lang, domain, f"{split}.review"
)
tree = ET.parse(fdir)
root = tree.getroot()
for child in root:
if reduce_target_space:
rating = np.zeros(3, dtype=int)
original_rating = int(float(child.find("rating").text))
if original_rating < 3:
new_rating = 1
elif original_rating > 3:
new_rating = 3
else:
new_rating = 2
rating[new_rating - 1] = 1
# rating = new_rating
else:
rating = np.zeros(5, dtype=int)
rating[int(float(child.find("rating").text)) - 1] = 1
# rating = new_rating
# if split == "train":
# target_data = data_tr
# current_count = len(target_data)
# c_tr = +1
# else:
# target_data = data_te
# current_count = len(target_data)
# c_te = +1
domain_data.append(
# target_data.append(
{
"asin": child.find("asin").text
if child.find("asin") is not None
else None,
# "category": child.find("category").text
# if child.find("category") is not None
# else None,
"category": domain,
# "rating": child.find("rating").text
# if child.find("rating") is not None
# else None,
"rating": rating,
"title": child.find("title").text
if child.find("title") is not None
else None,
"text": child.find("text").text
if child.find("text") is not None
else None,
"summary": child.find("summary").text
if child.find("summary") is not None
else None,
"lang": lang,
}
)
data[lang][domain].update({split: domain_data})
return data
def load_cls(): def load_cls():
data = {} data = {}
for lang in LANGS: for lang in LANGS:
@ -24,7 +99,7 @@ def load_cls():
train = ( train = (
open( open(
os.path.join( os.path.join(
CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed" CLS_UNPROCESSED_DATA_DIR, lang, domain, "train.processed"
), ),
"r", "r",
) )
@ -34,7 +109,7 @@ def load_cls():
test = ( test = (
open( open(
os.path.join( os.path.join(
CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed" CLS_UNPROCESSED_DATA_DIR, lang, domain, "test.processed"
), ),
"r", "r",
) )
@ -59,18 +134,33 @@ def process_data(line):
if __name__ == "__main__": if __name__ == "__main__":
print(f"datapath: {CLS_PROCESSED_DATA_DIR}") print(f"datapath: {CLS_UNPROCESSED_DATA_DIR}")
data = load_cls() # data = load_cls()
multilingualDataset = MultilingualDataset(dataset_name="cls") data = load_unprocessed_cls(reduce_target_space=True)
for lang in LANGS: multilingualDataset = MultilingualDataset(dataset_name="webis-cls-unprocessed")
# TODO: just using book domain atm
Xtr = [text[0] for text in data[lang]["books"]["train"]]
# Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
Xte = [text[0] for text in data[lang]["books"]["test"]] for lang in LANGS:
# Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1) # Xtr = [text["summary"] for text in data[lang]["books"]["train"]]
Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]]) Xtr = [text["text"] for text in data[lang]["books"]["train"]]
Xtr += [text["text"] for text in data[lang]["dvd"]["train"]]
Xtr += [text["text"] for text in data[lang]["music"]["train"]]
Ytr =[text["rating"] for text in data[lang]["books"]["train"]]
Ytr += [text["rating"] for text in data[lang]["dvd"]["train"]]
Ytr += [text["rating"] for text in data[lang]["music"]["train"]]
Ytr = np.vstack(Ytr)
Xte = [text["text"] for text in data[lang]["books"]["test"]]
Xte += [text["text"] for text in data[lang]["dvd"]["test"]]
Xte += [text["text"] for text in data[lang]["music"]["test"]]
Yte = [text["rating"] for text in data[lang]["books"]["test"]]
Yte += [text["rating"] for text in data[lang]["dvd"]["test"]]
Yte += [text["rating"] for text in data[lang]["music"]["test"]]
Yte = np.vstack(Yte)
multilingualDataset.add( multilingualDataset.add(
lang=lang, lang=lang,
@ -82,5 +172,7 @@ if __name__ == "__main__":
te_ids=None, te_ids=None,
) )
multilingualDataset.save( multilingualDataset.save(
os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl") os.path.expanduser(
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
)
) )

View File

@ -62,14 +62,29 @@ class gFunDataset:
) )
self.mlb = self.get_label_binarizer(self.labels) self.mlb = self.get_label_binarizer(self.labels)
elif "cls" in self.dataset_dir.lower(): # WEBIS-CLS (processed)
print(f"- Loading CLS dataset from {self.dataset_dir}") elif (
"cls" in self.dataset_dir.lower()
and "unprocessed" not in self.dataset_dir.lower()
):
print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}")
self.dataset_name = "cls" self.dataset_name = "cls"
self.dataset, self.labels, self.data_langs = self._load_multilingual( self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows self.dataset_name, self.dataset_dir, self.nrows
) )
self.mlb = self.get_label_binarizer(self.labels) self.mlb = self.get_label_binarizer(self.labels)
# WEBIS-CLS (unprocessed)
elif (
"cls" in self.dataset_dir.lower()
and "unprocessed" in self.dataset_dir.lower()
):
print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}")
self.dataset_name = "cls"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
self.show_dimension() self.show_dimension()
return return

View File

@ -23,6 +23,7 @@ def get_dataset(dataset_name, args):
"rcv1-2", "rcv1-2",
"glami", "glami",
"cls", "cls",
"webis",
], "dataset not supported" ], "dataset not supported"
RCV_DATAPATH = expanduser( RCV_DATAPATH = expanduser(
@ -37,6 +38,10 @@ def get_dataset(dataset_name, args):
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset") GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
WEBIS_CLS = expanduser(
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
)
if dataset_name == "multinews": if dataset_name == "multinews":
# TODO: convert to gFunDataset # TODO: convert to gFunDataset
raise NotImplementedError raise NotImplementedError
@ -91,6 +96,15 @@ def get_dataset(dataset_name, args):
is_multilabel=False, is_multilabel=False,
nrows=args.nrows, nrows=args.nrows,
) )
elif dataset_name == "webis":
dataset = gFunDataset(
dataset_dir=WEBIS_CLS,
is_textual=True,
is_visual=False,
is_multilabel=False,
nrows=args.nrows,
)
else: else:
raise NotImplementedError raise NotImplementedError
return dataset return dataset

View File

@ -1,8 +1,9 @@
from joblib import Parallel, delayed from joblib import Parallel, delayed
from collections import defaultdict from collections import defaultdict
from evaluation.metrics import * # from evaluation.metrics import *
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score import numpy as np
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score, precision_score, recall_score
def evaluation_metrics(y, y_, clf_type): def evaluation_metrics(y, y_, clf_type):
@ -13,13 +14,17 @@ def evaluation_metrics(y, y_, clf_type):
# TODO: we need logits top_k_accuracy_score(y, y_, k=10), # TODO: we need logits top_k_accuracy_score(y, y_, k=10),
f1_score(y, y_, average="macro", zero_division=1), f1_score(y, y_, average="macro", zero_division=1),
f1_score(y, y_, average="micro"), f1_score(y, y_, average="micro"),
precision_score(y, y_, zero_division=1, average="macro"),
recall_score(y, y_, zero_division=1, average="macro"),
) )
elif clf_type == "multilabel": elif clf_type == "multilabel":
return ( return (
macroF1(y, y_), f1_score(y, y_, average="macro", zero_division=1),
microF1(y, y_), f1_score(y, y_, average="micro"),
macroK(y, y_), 0,
microK(y, y_), 0,
# macroK(y, y_),
# microK(y, y_),
) )
else: else:
raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'") raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
@ -48,8 +53,10 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
if clf_type == "multilabel": if clf_type == "multilabel":
for lang in l_eval.keys(): for lang in l_eval.keys():
macrof1, microf1, macrok, microk = l_eval[lang] # macrof1, microf1, macrok, microk = l_eval[lang]
metrics.append([macrof1, microf1, macrok, microk]) # metrics.append([macrof1, microf1, macrok, microk])
macrof1, microf1, precision, recall = l_eval[lang]
metrics.append([macrof1, microf1, precision, recall])
if phase != "validation": if phase != "validation":
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}") print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
averages = np.mean(np.array(metrics), axis=0) averages = np.mean(np.array(metrics), axis=0)
@ -69,12 +76,15 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
# "acc10", # "accuracy-at-10", # "acc10", # "accuracy-at-10",
"MF1", # "macro-F1", "MF1", # "macro-F1",
"mF1", # "micro-F1", "mF1", # "micro-F1",
"precision",
"recall"
] ]
for lang in l_eval.keys(): for lang in l_eval.keys():
# acc, top5, top10, macrof1, microf1 = l_eval[lang] # acc, top5, top10, macrof1, microf1 = l_eval[lang]
acc, macrof1, microf1 = l_eval[lang] acc, macrof1, microf1, precision, recall= l_eval[lang]
# metrics.append([acc, top5, top10, macrof1, microf1]) # metrics.append([acc, top5, top10, macrof1, microf1])
metrics.append([acc, macrof1, microf1]) # metrics.append([acc, macrof1, microf1])
metrics.append([acc, macrof1, microf1, precision, recall])
for m, v in zip(_metrics, l_eval[lang]): for m, v in zip(_metrics, l_eval[lang]):
lang_metrics[m][lang] = v lang_metrics[m][lang] = v
@ -82,7 +92,8 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
if phase != "validation": if phase != "validation":
print( print(
# f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}" # f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}" # f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f} pr = {precision:.3f} re = {recall:.3f}"
) )
averages = np.mean(np.array(metrics), axis=0) averages = np.mean(np.array(metrics), axis=0)
if verbose: if verbose:

View File

@ -124,6 +124,16 @@ class GeneralizedFunnelling:
epochs=self.epochs, epochs=self.epochs,
attn_stacking_type=attn_stacking, attn_stacking_type=attn_stacking,
) )
self._model_id = get_unique_id(
self.dataset_name,
self.posteriors_vgf,
self.multilingual_vgf,
self.wce_vgf,
self.textual_trf_vgf,
self.visual_trf_vgf,
self.aggfunc,
)
return self return self
if self.posteriors_vgf: if self.posteriors_vgf:
@ -317,7 +327,7 @@ class GeneralizedFunnelling:
for vgf in self.first_tier_learners: for vgf in self.first_tier_learners:
vgf_config = vgf.get_config() vgf_config = vgf.get_config()
c.update(vgf_config) c.update({vgf_config["name"]: vgf_config})
gfun_config = { gfun_config = {
"id": self._model_id, "id": self._model_id,
@ -372,6 +382,7 @@ class GeneralizedFunnelling:
"rb", "rb",
) as vgf: ) as vgf:
first_tier_learners.append(pickle.load(vgf)) first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained VanillaFun VGF")
if self.multilingual_vgf: if self.multilingual_vgf:
with open( with open(
os.path.join( os.path.join(
@ -380,6 +391,7 @@ class GeneralizedFunnelling:
"rb", "rb",
) as vgf: ) as vgf:
first_tier_learners.append(pickle.load(vgf)) first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained Multilingual VGF")
if self.wce_vgf: if self.wce_vgf:
with open( with open(
os.path.join( os.path.join(
@ -388,20 +400,38 @@ class GeneralizedFunnelling:
"rb", "rb",
) as vgf: ) as vgf:
first_tier_learners.append(pickle.load(vgf)) first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained WCE VGF")
if self.textual_trf_vgf: if self.textual_trf_vgf:
with open( with open(
os.path.join( os.path.join(
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl" "models",
"vgfs",
"textual_transformer",
f"textualTransformerGen_{model_id}.pkl",
), ),
"rb", "rb",
) as vgf: ) as vgf:
first_tier_learners.append(pickle.load(vgf)) first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained Textual Transformer VGF")
if self.visual_trf_vgf:
with open(
os.path.join(
"models",
"vgfs",
"visual_transformer",
f"visualTransformerGen_{model_id}.pkl",
),
"rb",
print(f"- loaded trained Visual Transformer VGF"),
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if load_meta: if load_meta:
with open( with open(
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb" os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
) as f: ) as f:
metaclassifier = pickle.load(f) metaclassifier = pickle.load(f)
print(f"- loaded trained metaclassifier")
else: else:
metaclassifier = None metaclassifier = None
return first_tier_learners, metaclassifier, vectorizer return first_tier_learners, metaclassifier, vectorizer

View File

@ -45,11 +45,12 @@ class MT5ForSequenceClassification(nn.Module):
def save_pretrained(self, checkpoint_dir): def save_pretrained(self, checkpoint_dir):
torch.save(self.state_dict(), checkpoint_dir + ".pt") torch.save(self.state_dict(), checkpoint_dir + ".pt")
return return self
def from_pretrained(self, checkpoint_dir): def from_pretrained(self, checkpoint_dir):
checkpoint_dir += ".pt" checkpoint_dir += ".pt"
return self.load_state_dict(torch.load(checkpoint_dir)) self.load_state_dict(torch.load(checkpoint_dir))
return self
class TextualTransformerGen(ViewGen, TransformerGen): class TextualTransformerGen(ViewGen, TransformerGen):
@ -113,6 +114,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
model_name, num_labels=num_labels, output_hidden_states=True model_name, num_labels=num_labels, output_hidden_states=True
) )
else: else:
model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-traiend mbert
return AutoModelForSequenceClassification.from_pretrained( return AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels, output_hidden_states=True model_name, num_labels=num_labels, output_hidden_states=True
) )
@ -144,58 +146,60 @@ class TextualTransformerGen(ViewGen, TransformerGen):
self.model_name, num_labels=self.num_labels self.model_name, num_labels=self.num_labels
) )
tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data( self.model.to("cuda")
lX, lY, split=0.2, seed=42, modality="text"
)
tra_dataloader = self.build_dataloader( # tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
tr_lX, # lX, lY, split=0.2, seed=42, modality="text"
tr_lY, # )
processor_fn=self._tokenize, #
torchDataset=MultilingualDatasetTorch, # tra_dataloader = self.build_dataloader(
batch_size=self.batch_size, # tr_lX,
split="train", # tr_lY,
shuffle=True, # processor_fn=self._tokenize,
) # torchDataset=MultilingualDatasetTorch,
# batch_size=self.batch_size,
val_dataloader = self.build_dataloader( # split="train",
val_lX, # shuffle=True,
val_lY, # )
processor_fn=self._tokenize, #
torchDataset=MultilingualDatasetTorch, # val_dataloader = self.build_dataloader(
batch_size=self.batch_size_eval, # val_lX,
split="val", # val_lY,
shuffle=False, # processor_fn=self._tokenize,
) # torchDataset=MultilingualDatasetTorch,
# batch_size=self.batch_size_eval,
experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}" # split="val",
# shuffle=False,
trainer = Trainer( # )
model=self.model, #
optimizer_name="adamW", # experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
lr=self.lr, #
device=self.device, # trainer = Trainer(
loss_fn=torch.nn.CrossEntropyLoss(), # model=self.model,
print_steps=self.print_steps, # optimizer_name="adamW",
evaluate_step=self.evaluate_step, # lr=self.lr,
patience=self.patience, # device=self.device,
experiment_name=experiment_name, # loss_fn=torch.nn.CrossEntropyLoss(),
checkpoint_path=os.path.join( # print_steps=self.print_steps,
"models", # evaluate_step=self.evaluate_step,
"vgfs", # patience=self.patience,
"transformer", # experiment_name=experiment_name,
self._format_model_name(self.model_name), # checkpoint_path=os.path.join(
), # "models",
vgf_name="textual_trf", # "vgfs",
classification_type=self.clf_type, # "trained_transformer",
n_jobs=self.n_jobs, # self._format_model_name(self.model_name),
scheduler_name=self.scheduler, # ),
) # vgf_name="textual_trf",
trainer.train( # classification_type=self.clf_type,
train_dataloader=tra_dataloader, # n_jobs=self.n_jobs,
eval_dataloader=val_dataloader, # scheduler_name=self.scheduler,
epochs=self.epochs, # )
) # trainer.train(
# train_dataloader=tra_dataloader,
# eval_dataloader=val_dataloader,
# epochs=self.epochs,
# )
if self.probabilistic: if self.probabilistic:
self.feature2posterior_projector.fit(self.transform(lX), lY) self.feature2posterior_projector.fit(self.transform(lX), lY)
@ -224,7 +228,6 @@ class TextualTransformerGen(ViewGen, TransformerGen):
with torch.no_grad(): with torch.no_grad():
for input_ids, lang in dataloader: for input_ids, lang in dataloader:
input_ids = input_ids.to(self.device) input_ids = input_ids.to(self.device)
# TODO: check this
if isinstance(self.model, MT5ForSequenceClassification): if isinstance(self.model, MT5ForSequenceClassification):
batch_embeddings = self.model(input_ids).pooled.cpu().numpy() batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
else: else:
@ -277,4 +280,4 @@ class TextualTransformerGen(ViewGen, TransformerGen):
def get_config(self): def get_config(self):
c = super().get_config() c = super().get_config()
return {"textual_trf": c} return {"name": "textual-trasnformer VGF", "textual_trf": c}

View File

@ -65,3 +65,6 @@ class VanillaFunGen(ViewGen):
with open(_path, "wb") as f: with open(_path, "wb") as f:
pickle.dump(self, f) pickle.dump(self, f)
return self return self
def get_config(self):
return {"name": "Vanilla Funnelling VGF"}

View File

@ -186,4 +186,4 @@ class VisualTransformerGen(ViewGen, TransformerGen):
return self return self
def get_config(self): def get_config(self):
return {"visual_trf": super().get_config()} return {"name": "visual-transformer VGF", "visual_trf": super().get_config()}

191
hf_trainer.py Normal file
View File

@ -0,0 +1,191 @@
import torch
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
TrainingArguments,
)
from gfun.vgfs.commons import Trainer
from datasets import load_dataset, DatasetDict
from transformers import Trainer
import transformers
import evaluate
transformers.logging.set_verbosity_error()
def init_callbacks(patience=-1, nosave=False):
callbacks = []
if patience != -1 and not nosave:
callbacks.append(transformers.EarlyStoppingCallback(early_stopping_patience=patience))
return callbacks
def init_model(model_name):
if model_name == "mbert":
hf_name = "bert-base-multilingual-cased"
elif model_name == "xlm-roberta":
hf_name = "xlm-roberta-base"
else:
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained(hf_name)
model = AutoModelForSequenceClassification.from_pretrained(hf_name, num_labels=3)
return tokenizer, model
def main(args):
tokenizer, model = init_model(args.model)
data = load_dataset(
"json",
data_files={
"train": "local_datasets/webis-cls/all-domains/train.json",
"test": "local_datasets/webis-cls/all-domains/test.json",
},
)
def process_sample(sample):
inputs = sample["text"]
ratings = [r - 1 for r in sample["rating"]]
targets = torch.zeros((len(inputs), 3), dtype=float)
lang_mapper = {
lang: lang_id for lang_id, lang in enumerate(set(sample["lang"]))
}
lang_ids = [lang_mapper[l] for l in sample["lang"]]
for i, r in enumerate(ratings):
targets[i][r - 1] = 1
model_inputs = tokenizer(inputs, max_length=512, truncation=True)
model_inputs["labels"] = targets
model_inputs["lang_ids"] = torch.tensor(lang_ids)
return model_inputs
data = data.map(
process_sample,
batched=True,
num_proc=4,
load_from_cache_file=True,
remove_columns=["text", "category", "rating", "summary", "title"],
)
train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
data.set_format("torch")
data = DatasetDict(
{
"train": train_val_splits["train"],
"validation": train_val_splits["test"],
"test": data["test"],
}
)
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
callbacks = init_callbacks(args.patience, args.nosave)
f1_metric = evaluate.load("f1")
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
training_args = TrainingArguments(
output_dir=f"{args.model}-sentiment",
do_train=True,
evaluation_strategy="steps",
per_device_train_batch_size=args.batch,
per_device_eval_batch_size=args.batch,
gradient_accumulation_steps=args.gradacc,
eval_accumulation_steps=10,
learning_rate=args.lr,
weight_decay=0.1,
max_grad_norm=5.0,
num_train_epochs=args.epochs,
lr_scheduler_type=args.scheduler,
warmup_steps=1000,
logging_strategy="steps",
logging_first_step=True,
logging_steps=args.steplog,
seed=42,
fp16=args.fp16,
load_best_model_at_end=False if args.nosave else True,
save_strategy="no" if args.nosave else "steps",
save_total_limit=3,
eval_steps=args.stepeval,
run_name=f"{args.model}-sentiment-run",
disable_tqdm=False,
log_level="warning",
report_to=["wandb"] if args.wandb else "none",
optim="adamw_torch",
)
def compute_metrics(eval_preds):
preds = eval_preds.predictions.argmax(-1)
targets = eval_preds.label_ids.argmax(-1)
setting = "macro"
f1_score_macro = f1_metric.compute(
predictions=preds, references=targets, average="macro"
)
f1_score_micro = f1_metric.compute(
predictions=preds, references=targets, average="micro"
)
accuracy_score = accuracy_metric.compute(predictions=preds, references=targets)
precision_score = precision_metric.compute(
predictions=preds, references=targets, average=setting, zero_division=1
)
recall_score = recall_metric.compute(
predictions=preds, references=targets, average=setting, zero_division=1
)
results = {
"macro_f1score": f1_score_macro["f1"],
"micro_f1score": f1_score_micro["f1"],
"accuracy": accuracy_score["accuracy"],
"precision": precision_score["precision"],
"recall": recall_score["recall"],
}
results = {k: round(v, 4) for k, v in results.items()}
return results
if args.wandb:
import wandb
wandb.init(entity="andreapdr", project=f"gfun-senti-hf", name="mbert-sent", config=vars(args))
trainer = Trainer(
model=model,
args=training_args,
train_dataset=data["train"],
eval_dataset=data["validation"],
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
)
print("- Training:")
trainer.train()
print("- Testing:")
test_results = trainer.evaluate(eval_dataset=data["test"])
print(test_results)
exit()
if __name__ == "__main__":
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, metavar="", default="mbert")
parser.add_argument("--lr", type=float, metavar="", default=1e-5, help="Set learning rate",)
parser.add_argument("--scheduler", type=str, metavar="", default="linear", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
parser.add_argument("--batch", type=int, metavar="", default=16, help="Set batch size")
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
parser.add_argument("--steplog", type=int, metavar="", default=100, help="Log training every n steps")
parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
parser.add_argument("--wandb", action="store_true", help="Log to wandb")
parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
# parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
# parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
args = parser.parse_args()
main(args)

13
main.py
View File

@ -1,8 +1,5 @@
import os
import wandb import wandb
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from argparse import ArgumentParser from argparse import ArgumentParser
from time import time from time import time
@ -12,16 +9,9 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
""" """
TODO: TODO:
- Transformers VGFs:
- scheduler with warmup and cosine
- freeze params method
- General: - General:
[!] zero-shot setup [!] zero-shot setup
- CLS dataset is loading only "books" domain data - CLS dataset is loading only "books" domain data
- documents should be trimmed to the same length (for SVMs we are using way too long tokens)
- Attention Aggregator:
- experiment with weight init of Attention-aggregator
- FFNN posterior-probabilities' dependent
- Docs: - Docs:
- add documentations sphinx - add documentations sphinx
""" """
@ -150,7 +140,6 @@ def main(args):
wandb.log(gfun_res) wandb.log(gfun_res)
log_barplot_wandb(lang_metrics_gfun, title_affix="per language") log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
log_barplot_wandb(avg_metrics_gfun, title_affix="averages")
if __name__ == "__main__": if __name__ == "__main__":
@ -178,7 +167,7 @@ if __name__ == "__main__":
parser.add_argument("--features", action="store_false") parser.add_argument("--features", action="store_false")
parser.add_argument("--aggfunc", type=str, default="mean") parser.add_argument("--aggfunc", type=str, default="mean")
# transformer parameters --------------- # transformer parameters ---------------
parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--textual_trf_name", type=str, default="mbert") parser.add_argument("--textual_trf_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=128) parser.add_argument("--eval_batch_size", type=int, default=128)

28
run-senti.sh Normal file
View File

@ -0,0 +1,28 @@
#!bin/bash
config="-m"
echo "[Running gFun config: $config]"
epochs=100
njobs=-1
clf=singlelabel
patience=5
eval_every=5
text_len=256
text_lr=1e-4
bsize=64
txt_model=mbert
python main.py $config \
-d webis \
--epochs $epochs \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model \
--load_trained webis_pmwt_mean_230621