branching for rai

Andrea Pedrotti 2023-10-05 15:39:49 +02:00
parent fbd740fabd
commit 234b6031b1
9 changed files with 147 additions and 411 deletions

View File

@ -5,12 +5,6 @@ from sklearn.metrics import mean_absolute_error
from os.path import join
"""
MAE for classification is meaningful only in "ordinal" tasks, e.g., sentiment classification.
Otherwise the distance between the categories has no semantics!
- NB: we want to get the macro-averaged, class-specific MAE!
"""
def main():
# SETTINGS = ["p", "m", "w", "t", "mp", "mpw", "mpt", "mptw"]
@ -24,11 +18,9 @@ def main():
print(df)
def evalaute(setting):
def evalaute(setting, result_file="preds.gfun.csv"):
result_dir = "results"
# result_file = f"lang-specific.gfun.{setting}.webis.csv"
result_file = f"lang-specific.mbert.webis.csv"
# print(f"- reading from: {result_file}")
print(f"- reading from: {result_file}")
df = pd.read_csv(join(result_dir, result_file))
langs = df.langs.unique()
res = []

View File

@ -14,7 +14,7 @@ class CsvLogger:
# os.makedirs(self.outfile.replace(".csv", ".lang.csv"), exist_ok=True)
# return
def log_lang_results(self, results: dict, config="gfun-lello"):
def log_lang_results(self, results: dict, config="gfun-default", notes=None):
df = pd.DataFrame.from_dict(results, orient="columns")
df["config"] = config["gFun"]["simple_id"]
df["aggfunc"] = config["gFun"]["aggfunc"]
@ -22,6 +22,7 @@ class CsvLogger:
df["id"] = config["gFun"]["id"]
df["optimc"] = config["gFun"]["optimc"]
df["timing"] = config["gFun"]["timing"]
df["notes"] = notes
with open(self.outfile, 'a') as f:
df.to_csv(f, mode='a', header=f.tell()==0)

View File

@ -16,7 +16,17 @@ from dataManager.multilingualDataset import MultilingualDataset
class SimpleGfunDataset:
def __init__(self, datadir="~/datasets/rai/csv/", textual=True, visual=False, multilabel=False, set_tr_langs=None, set_te_langs=None):
def __init__(
self,
dataset_name=None,
datadir="~/datasets/rai/csv/",
textual=True,
visual=False,
multilabel=False,
set_tr_langs=None,
set_te_langs=None
):
self.name = dataset_name
self.datadir = os.path.expanduser(datadir)
self.textual = textual
self.visual = visual
@ -41,11 +51,15 @@ class SimpleGfunDataset:
print(f"tr: {tr} - va: {va} - te: {te}")
def load_csv(self, set_tr_langs, set_te_langs):
# _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
_data_tr = pd.read_csv(os.path.join(self.datadir, "train.balanced.csv")).sample(100, random_state=42)
train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang) # TODO stratify on lang or label?
# test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
test = pd.read_csv(os.path.join(self.datadir, "test.balanced.csv")).sample(100, random_state=42)
_data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
try:
stratified = "class"
train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.label)
except ValueError:
# stratifying by label fails when some class has fewer than two samples; fall back to language
stratified = "lang"
train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang)
print(f"- dataset stratified by {stratified}")
test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
self._set_langs(train, test, set_tr_langs, set_te_langs)
self._set_labels(_data_tr)
self.full_train = _data_tr
@ -56,8 +70,6 @@ class SimpleGfunDataset:
return
def _set_labels(self, data):
# self.labels = [i for i in range(28)] # todo hard-coded for rai
# self.labels = [i for i in range(3)] # TODO hard coded for sentimnet
self.labels = sorted(list(data.label.unique()))
def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
@ -70,17 +82,17 @@ class SimpleGfunDataset:
print(f"-- [SETTING TESTING LANGS TO: {list(set_tr_langs)}]")
self.te_langs = self.te_langs.intersection(set(set_te_langs))
self.all_langs = self.tr_langs.union(self.te_langs)
return self.tr_langs, self.te_langs, self.all_langs
def _set_datalang(self, data: pd.DataFrame):
return {lang: data[data.lang == lang] for lang in self.all_langs}
def training(self, merge_validation=False, mask_number=False, target_as_csr=False):
# TODO some additional pre-processing on the textual data?
apply_mask = (lambda x: _mask_numbers(x)) if mask_number else (lambda x: x)
lXtr = {
lang: {"text": apply_mask(self.train_data[lang].text.tolist())} # TODO inserting dict for textual data - we still have to manage visual
lang: {"text": apply_mask(self.train_data[lang].text.tolist())}
for lang in self.tr_langs
}
if merge_validation:
@ -103,7 +115,6 @@ class SimpleGfunDataset:
return lXtr, lYtr
def test(self, mask_number=False, target_as_csr=False):
# TODO some additional pre-processing on the textual data?
apply_mask = (lambda x: _mask_numbers(x)) if mask_number else (lambda x: x)
lXte = {
lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
@ -162,62 +173,12 @@ class gFunDataset:
return mlb
def _load_dataset(self):
if "glami" in self.dataset_dir.lower():
print(f"- Loading GLAMI dataset from {self.dataset_dir}")
self.dataset_name = "glami"
self.dataset, self.labels, self.data_langs = self._load_glami(
self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
elif "rcv" in self.dataset_dir.lower():
print(f"- Loading RCV1-2 dataset from {self.dataset_dir}")
self.dataset_name = "rcv1-2"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
elif "jrc" in self.dataset_dir.lower():
print(f"- Loading JRC dataset from {self.dataset_dir}")
self.dataset_name = "jrc"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
# WEBIS-CLS (processed)
elif (
"cls" in self.dataset_dir.lower()
and "unprocessed" not in self.dataset_dir.lower()
):
print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}")
self.dataset_name = "cls"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
# WEBIS-CLS (unprocessed)
elif (
"cls" in self.dataset_dir.lower()
and "unprocessed" in self.dataset_dir.lower()
):
print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}")
self.dataset_name = "cls"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
elif "rai" in self.dataset_dir.lower():
print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}")
self.dataset_name = "rai"
self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
dataset_dir="~/datasets/rai/csv/train-split-rai.csv",
nrows=self.nrows)
self.mlb = self.get_label_binarizer(self.labels)
print(f"- Loading dataset from {self.dataset_dir}")
self.dataset_name = "rai"
self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
dataset_dir=self.dataset_dir,
nrows=self.nrows)
self.mlb = self.get_label_binarizer(self.labels)
self.show_dimension()
return
@ -232,15 +193,12 @@ class gFunDataset:
else:
print(f"-- Labels: {len(self.labels)}")
def _load_multilingual(self, dataset_name, dataset_dir, nrows):
def _load_multilingual(self, dataset_dir, nrows, dataset_name="rai"):
if "csv" in dataset_dir:
old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv(
# path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
#path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv")
path_tr="~/datasets/rai/csv/train-split-rai.csv",
path_te="~/datasets/rai/csv/test-split-rai.csv")
else:
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
old_dataset = MultilingualDataset(dataset_name="rai").from_csv(
path_tr=os.path.expanduser(os.path.join(dataset_dir, "train.small.csv")),
path_te=os.path.expanduser(os.path.join(dataset_dir, "test.small.csv"))
)
if nrows is not None:
if dataset_name == "cls":
old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
@ -365,27 +323,4 @@ if __name__ == "__main__":
data_rai = SimpleGfunDataset()
lXtr, lYtr = data_rai.training(mask_number=False)
lXte, lYte = data_rai.test(mask_number=False)
exit()
# import os
# GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
# RCV_DATAPATH = os.path.expanduser(
# "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
# )
# JRC_DATAPATH = os.path.expanduser(
# "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
# )
# print("Hello gFunDataset")
# dataset = gFunDataset(
# dataset_dir=JRC_DATAPATH,
# data_langs=None,
# is_textual=True,
# is_visual=True,
# is_multilabel=False,
# labels=None,
# nrows=13,
# )
# lXtr, lYtr = dataset.training()
# lXte, lYte = dataset.test()
# exit(0)
exit()

View File

@ -1,7 +1,5 @@
from os.path import expanduser, join
from dataManager.gFunDataset import gFunDataset, SimpleGfunDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset
from dataManager.gFunDataset import SimpleGfunDataset
def load_from_pickle(path, dataset_name, nrows):
@ -16,119 +14,14 @@ def load_from_pickle(path, dataset_name, nrows):
return loaded
def get_dataset(dataset_name, args):
assert dataset_name in [
"multinews",
"amazon",
"rcv1-2",
"glami",
"cls",
"webis",
"rai",
], "dataset not supported"
RCV_DATAPATH = expanduser(
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
def get_dataset(dataset_path, args):
dataset = SimpleGfunDataset(
dataset_name="rai",
datadir=dataset_path,
textual=True,
visual=False,
multilabel=False,
set_tr_langs=args.tr_langs,
set_te_langs=args.te_langs
)
JRC_DATAPATH = expanduser(
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
)
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
WEBIS_CLS = expanduser(
# "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
"~/datasets/cls-acl10-unprocessed/csv"
)
RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl")
if dataset_name == "multinews":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = MultiNewsDataset(
expanduser(MULTINEWS_DATAPATH),
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
)
elif dataset_name == "amazon":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = AmazonDataset(
domains=args.domains,
nrows=args.nrows,
min_count=args.min_count,
max_labels=args.max_labels,
)
elif dataset_name == "jrc":
dataset = gFunDataset(
dataset_dir=JRC_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=True,
nrows=args.nrows,
)
elif dataset_name == "rcv1-2":
dataset = gFunDataset(
dataset_dir=RCV_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=True,
nrows=args.nrows,
)
elif dataset_name == "glami":
if args.save_dataset is False:
dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
else:
dataset = gFunDataset(
dataset_dir=GLAMI_DATAPATH,
is_textual=True,
is_visual=True,
is_multilabel=False,
nrows=args.nrows,
)
dataset.save_as_pickle(GLAMI_DATAPATH)
elif dataset_name == "cls":
dataset = gFunDataset(
dataset_dir=CLS_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=False,
nrows=args.nrows,
)
elif dataset_name == "webis":
dataset = SimpleGfunDataset(
datadir=WEBIS_CLS,
textual=True,
visual=False,
multilabel=False,
set_tr_langs=args.tr_langs,
set_te_langs=args.te_langs
)
# dataset = gFunDataset(
# dataset_dir=WEBIS_CLS,
# is_textual=True,
# is_visual=False,
# is_multilabel=False,
# nrows=args.nrows,
# )
elif dataset_name == "rai":
dataset = SimpleGfunDataset(
datadir="~/datasets/rai/csv",
textual=True,
visual=False,
multilabel=False
)
# dataset = gFunDataset(
# dataset_dir=RAI_DATAPATH,
# is_textual=True,
# is_visual=False,
# is_multilabel=False,
# nrows=args.nrows
# )
else:
raise NotImplementedError
return dataset

View File

@ -8,7 +8,6 @@ from gfun.vgfs.learners.svms import MetaClassifier, get_learner
from gfun.vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from gfun.vgfs.vanillaFun import VanillaFunGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.wceGen import WceGen
@ -19,7 +18,6 @@ class GeneralizedFunnelling:
wce,
multilingual,
textual_transformer,
visual_transformer,
langs,
num_labels,
classification_type,
@ -29,12 +27,9 @@ class GeneralizedFunnelling:
eval_batch_size,
max_length,
textual_lr,
visual_lr,
epochs,
patience,
evaluate_step,
textual_transformer_name,
visual_transformer_name,
optimc,
device,
load_trained,
@ -42,13 +37,14 @@ class GeneralizedFunnelling:
probabilistic,
aggfunc,
load_meta,
trained_text_trf=None,
textual_transformer_name=None,
):
# Setting VFGs -----------
self.posteriors_vgf = posterior
self.wce_vgf = wce
self.multilingual_vgf = multilingual
self.textual_trf_vgf = textual_transformer
self.visual_trf_vgf = visual_transformer
self.probabilistic = probabilistic
self.num_labels = num_labels
self.clf_type = classification_type
@ -58,6 +54,7 @@ class GeneralizedFunnelling:
self.cached = True
# Textual Transformer VGF params ----------
self.textual_trf_name = textual_transformer_name
self.trained_text_trf = trained_text_trf
self.epochs = epochs
self.textual_trf_lr = textual_lr
self.textual_scheduler = "ReduceLROnPlateau"
@ -68,10 +65,6 @@ class GeneralizedFunnelling:
self.patience = patience
self.evaluate_step = evaluate_step
self.device = device
# Visual Transformer VGF params ----------
self.visual_trf_name = visual_transformer_name
self.visual_trf_lr = visual_lr
self.visual_scheduler = "ReduceLROnPlateau"
# Metaclassifier params ------------
self.optimc = optimc
# -------------------
@ -131,7 +124,6 @@ class GeneralizedFunnelling:
self.multilingual_vgf,
self.wce_vgf,
self.textual_trf_vgf,
self.visual_trf_vgf,
self.aggfunc,
)
return self
@ -174,26 +166,10 @@ class GeneralizedFunnelling:
patience=self.patience,
device=self.device,
classification_type=self.clf_type,
saved_model=self.trained_text_trf,
)
self.first_tier_learners.append(transformer_vgf)
if self.visual_trf_vgf:
visual_trasformer_vgf = VisualTransformerGen(
dataset_name=self.dataset_name,
model_name="vit",
lr=self.visual_trf_lr,
scheduler=self.visual_scheduler,
epochs=self.epochs,
batch_size=self.batch_size_trf,
batch_size_eval=self.eval_batch_size_trf,
probabilistic=self.probabilistic,
evaluate_step=self.evaluate_step,
patience=self.patience,
device=self.device,
classification_type=self.clf_type,
)
self.first_tier_learners.append(visual_trasformer_vgf)
if "attn" in self.aggfunc:
attn_stacking = self.aggfunc.split("_")[1]
self.attn_aggregator = AttentionAggregator(
@ -219,7 +195,6 @@ class GeneralizedFunnelling:
self.multilingual_vgf,
self.wce_vgf,
self.textual_trf_vgf,
self.visual_trf_vgf,
self.aggfunc,
)
print(f"- model id: {self._model_id}")
@ -309,16 +284,29 @@ class GeneralizedFunnelling:
return aggregated
def _aggregate_mean(self, first_tier_projections):
aggregated = {
lang: np.zeros(data.shape)
for lang, data in first_tier_projections[0].items()
}
# aggregated = {
# lang: np.zeros(data.shape)
# for lang, data in first_tier_projections[0].items()
# }
aggregated = {}
for lang_projections in first_tier_projections:
for lang, projection in lang_projections.items():
if lang not in aggregated:
aggregated[lang] = np.zeros(projection.shape)
aggregated[lang] += projection
for lang, projection in aggregated.items():
aggregated[lang] /= len(first_tier_projections)
def get_denom(lang, projs):
den = 0
for proj in projs:
if lang in proj:
den += 1
return den
for lang, _ in aggregated.items():
# aggregated[lang] /= len(first_tier_projections)
aggregated[lang] /= get_denom(lang, first_tier_projections)
return aggregated
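# Note on the change above: each language is now averaged over the number of
# first-tier views that actually produced a projection for it (get_denom) rather
# than over len(first_tier_projections), so languages covered by only a subset of
# the VGFs are no longer scaled down by the views that skipped them.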
@ -416,18 +404,6 @@ class GeneralizedFunnelling:
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained Textual Transformer VGF")
if self.visual_trf_vgf:
with open(
os.path.join(
"models",
"vgfs",
"visual_transformer",
f"visualTransformerGen_{model_id}.pkl",
),
"rb",
print(f"- loaded trained Visual Transformer VGF"),
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if load_meta:
with open(
@ -482,7 +458,6 @@ def get_unique_id(
multilingual,
wce,
textual_transformer,
visual_transformer,
aggfunc,
):
from datetime import datetime
@ -493,6 +468,5 @@ def get_unique_id(
model_id += "m" if multilingual else ""
model_id += "w" if wce else ""
model_id += "t" if textual_transformer else ""
model_id += "v" if visual_transformer else ""
model_id += f"_{aggfunc}"
return f"{model_id}_{now}"

View File

@ -173,9 +173,18 @@ class NaivePolylingualClassifier:
:return: a dictionary of probabilities that each document belongs to each class
"""
assert self.model is not None, "predict called before fit"
assert set(lX.keys()).issubset(
set(self.model.keys())
), "unknown languages requested in decision function"
if not set(lX.keys()).issubset(set(self.model.keys())):
langs = sorted(set(lX.keys()).intersection(set(self.model.keys())))  # only langs with a fitted model; sorted for a stable pairing with scores
scores = Parallel(n_jobs=self.n_jobs, max_nbytes=None)(
delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs)
# res = {lang: None for lang in lX.keys()}
# for i, lang in enumerate(langs):
# res[lang] = scores[i]
# return res
return {lang: scores[i] for i, lang in enumerate(langs)}
# assert set(lX.keys()).issubset(
# set(self.model.keys())
# ), "unknown languages requested in decision function"
langs = list(lX.keys())
scores = Parallel(n_jobs=self.n_jobs, max_nbytes=None)(
delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs

View File

@ -73,6 +73,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
patience=5,
classification_type="multilabel",
scheduler="ReduceLROnPlateau",
saved_model=None,
):
super().__init__(
self._validate_model_name(model_name),
@ -91,6 +92,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
n_jobs=n_jobs,
verbose=verbose,
)
self.saved_model = saved_model
self.clf_type = classification_type
self.fitted = False
print(
@ -109,26 +111,25 @@ class TextualTransformerGen(ViewGen, TransformerGen):
else:
raise NotImplementedError
def load_pretrained_model(self, model_name, num_labels):
def load_pretrained_model(self, model_name, num_labels, saved_model=None):
if model_name == "google/mt5-small":
return MT5ForSequenceClassification(
model_name, num_labels=num_labels, output_hidden_states=True
)
else:
# model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-traiend mbert
# model_name = "hf_models/mbert-rai-fewshot-second/checkpoint-19000" # TODO hardcoded to pre-traiend mbert
# model_name = "hf_models/mbert-sentiment/checkpoint-1150" # TODO hardcoded to pre-traiend mbert
model_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
if saved_model:
model_name = saved_model
else:
model_name = "google/bert-base-multilingual-cased"
return AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels, output_hidden_states=True
)
def load_tokenizer(self, model_name):
# model_name = "mbert-rai-multi-2000/checkpoint-1500" # TODO hardcoded to pre-traiend mbert
return AutoTokenizer.from_pretrained(model_name)
def init_model(self, model_name, num_labels):
return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
def init_model(self, model_name, num_labels, saved_model):
return self.load_pretrained_model(model_name, num_labels, saved_model), self.load_tokenizer(
model_name
)
@ -148,64 +149,11 @@ class TextualTransformerGen(ViewGen, TransformerGen):
_l = list(lX.keys())[0]
self.num_labels = lY[_l].shape[-1]
self.model, self.tokenizer = self.init_model(
self.model_name, num_labels=self.num_labels
self.model_name, num_labels=self.num_labels, saved_model=self.saved_model,
)
self.model.to("cuda")
# tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
# lX, lY, split=0.2, seed=42, modality="text"
# )
#
# tra_dataloader = self.build_dataloader(
# tr_lX,
# tr_lY,
# processor_fn=self._tokenize,
# torchDataset=MultilingualDatasetTorch,
# batch_size=self.batch_size,
# split="train",
# shuffle=True,
# )
#
# val_dataloader = self.build_dataloader(
# val_lX,
# val_lY,
# processor_fn=self._tokenize,
# torchDataset=MultilingualDatasetTorch,
# batch_size=self.batch_size_eval,
# split="val",
# shuffle=False,
# )
#
# experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
#
# trainer = Trainer(
# model=self.model,
# optimizer_name="adamW",
# lr=self.lr,
# device=self.device,
# loss_fn=torch.nn.CrossEntropyLoss(),
# print_steps=self.print_steps,
# evaluate_step=self.evaluate_step,
# patience=self.patience,
# experiment_name=experiment_name,
# checkpoint_path=os.path.join(
# "models",
# "vgfs",
# "trained_transformer",
# self._format_model_name(self.model_name),
# ),
# vgf_name="textual_trf",
# classification_type=self.clf_type,
# n_jobs=self.n_jobs,
# scheduler_name=self.scheduler,
# )
# trainer.train(
# train_dataloader=tra_dataloader,
# eval_dataloader=val_dataloader,
# epochs=self.epochs,
# )
if self.probabilistic:
transformed = self.transform(lX)
self.feature2posterior_projector.fit(transformed, lY)

View File

@ -1,4 +1,4 @@
from os.path import expanduser
from os.path import expanduser, join
import torch
from transformers import (
@ -19,13 +19,8 @@ import pandas as pd
transformers.logging.set_verbosity_error()
IWSLT_D_COLUMNS = ["text", "category", "rating", "summary", "title"]
RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"] # "lang"
WEBIS_D_COLUMNS = ['Unnamed: 0', 'asin', 'category', 'original_rating', 'label', 'title', 'text', 'summary'] # "lang"
RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"]
MAX_LEN = 128
# DATASET_NAME = "rai"
# DATASET_NAME = "rai-multilingual-2000"
# DATASET_NAME = "webis-cls"
def init_callbacks(patience=-1, nosave=False):
@ -35,13 +30,17 @@ def init_callbacks(patience=-1, nosave=False):
return callbacks
def init_model(model_name, nlabels):
def init_model(model_name, nlabels, saved_model=None):
if model_name == "mbert":
# hf_name = "bert-base-multilingual-cased"
hf_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
# hf_name = "hf_models/mbert-rai-fewshot-second/checkpoint-9000"
if saved_model is None:
hf_name = "bert-base-multilingual-cased"
else:
hf_name = saved_model
elif model_name == "xlm-roberta":
hf_name = "xlm-roberta-base"
if saved_model is None:
hf_name = "xlm-roberta-base"
else:
hf_name = saved_model
else:
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained(hf_name)
@ -50,43 +49,41 @@ def init_model(model_name, nlabels):
def main(args):
tokenizer, model = init_model(args.model, args.nlabels)
saved_model = args.savedmodel
trainlang = args.trainlang
datapath = args.datapath
tokenizer, model = init_model(args.model, args.nlabels, saved_model=saved_model)
data = load_dataset(
"csv",
data_files = {
"train": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/train.balanced.csv"),
"test": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/test.balanced.csv")
# "train": expanduser(f"~/datasets/rai/csv/train-{DATASET_NAME}.csv"),
# "test": expanduser(f"~/datasets/rai/csv/test-{DATASET_NAME}.csv")
# "train": expanduser(f"~/datasets/rai/csv/train.small.csv"),
# "test": expanduser(f"~/datasets/rai/csv/test.small.csv")
"train": expanduser(join(datapath, "train.csv")),
"test": expanduser(join(datapath, "test.small.csv"))
}
)
def filter_dataset(dataset, lang):
indices = [i for i, l in enumerate(dataset["lang"]) if l == lang]
dataset = dataset.select(indices)
return dataset
if trainlang is not None:
data["train"] = filter_dataset(data["train"], lang=trainlang)
def process_sample_rai(sample):
inputs = [f"{title}. {text}" for title, text in zip(sample["title"], sample["text"])]
labels = sample["label"]
model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) # TODO pre-process text cause there's a lot of noise in there...
model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)
model_inputs["labels"] = labels
return model_inputs
def process_sample_webis(sample):
inputs = sample["text"]
labels = sample["label"]
model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) # TODO pre-process text cause there's a lot of noise in there...
model_inputs["labels"] = labels
return model_inputs
data = data.map(
# process_sample_rai,
process_sample_webis,
process_sample_rai,
batched=True,
num_proc=4,
load_from_cache_file=True,
# remove_columns=RAI_D_COLUMNS,
remove_columns=WEBIS_D_COLUMNS,
remove_columns=RAI_D_COLUMNS,
)
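# remove_columns drops the raw RAI columns after tokenization; "lang" is not listed in
# RAI_D_COLUMNS, so it survives and remains available for filter_dataset and save_preds.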
train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
data.set_format("torch")
@ -107,8 +104,8 @@ def main(args):
recall_metric = evaluate.load("recall")
training_args = TrainingArguments(
# output_dir=f"hf_models/{args.model}-rai",
output_dir=f"hf_models/{args.model}-sentiment-balanced",
output_dir=f"hf_models/{args.model}-fewshot-full" if trainlang is None else f"hf_models/{args.model}-zeroshot-full",
run_name="model-zeroshot" if trainlang is not None else "model-fewshot",
do_train=True,
evaluation_strategy="steps",
per_device_train_batch_size=args.batch,
@ -130,8 +127,6 @@ def main(args):
save_strategy="no" if args.nosave else "steps",
save_total_limit=2,
eval_steps=args.stepeval,
# run_name=f"{args.model}-rai-stratified",
run_name=f"{args.model}-sentiment",
disable_tqdm=False,
log_level="warning",
report_to=["wandb"] if args.wandb else "none",
@ -142,7 +137,6 @@ def main(args):
def compute_metrics(eval_preds):
preds = eval_preds.predictions.argmax(-1)
# targets = eval_preds.label_ids.argmax(-1)
targets = eval_preds.label_ids
setting = "macro"
f1_score_macro = f1_metric.compute(
@ -170,7 +164,9 @@ def main(args):
if args.wandb:
import wandb
wandb.init(entity="andreapdr", project=f"gfun-rai-hf", name="mbert-rai", config=vars(args))
wandb.init(entity="andreapdr", project=f"gfun",
name="model-zeroshot-full" if trainlang is not None else "model-fewshot-full",
config=vars(args))
trainer = Trainer(
model=model,
@ -188,17 +184,21 @@ def main(args):
trainer.train()
print("- Testing:")
test_results = trainer.evaluate(eval_dataset=data["test"], metric_key_prefix="test")
test_results = trainer.predict(test_dataset=data["test"], metric_key_prefix="test")
pprint(test_results.metrics)
save_preds(data["test"], test_results.predictions)
save_preds(data["test"], test_results.predictions, trainlang)
exit()
def save_preds(dataset, predictions):
def save_preds(dataset, predictions, trainlang=None):
df = pd.DataFrame()
df["langs"] = dataset["lang"]
df["labels"] = dataset["labels"]
df["preds"] = predictions.argmax(axis=1)
df.to_csv("results/lang-specific.mbert.webis.csv", index=False)
if trainlang is not None:
df.to_csv(f"results/zeroshot.{trainlang}.model.csv", index=False)
else:
df.to_csv("results/fewshot.model.csv", index=False)
return
@ -210,16 +210,18 @@ if __name__ == "__main__":
parser.add_argument("--nlabels", type=int, metavar="", default=28)
parser.add_argument("--lr", type=float, metavar="", default=5e-5, help="Set learning rate",)
parser.add_argument("--scheduler", type=str, metavar="", default="cosine", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
parser.add_argument("--epochs", type=int, metavar="", default=10, help="Set epochs")
parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps")
parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps")
parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
parser.add_argument("--wandb", action="store_true", help="Log to wandb")
parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
# parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
parser.add_argument("--trainlang", default=None, type=str, help="set training language for zero-shot experiments" )
parser.add_argument("--datapath", type=str, default="data", help="path to the csv dataset. Dir should contain both a train.csv and a test.csv file")
parser.add_argument("--savedmodel", type=str, default="hf_models/mbert-rai-fewshot-second/checkpoint-9000")
args = parser.parse_args()
main(args)

main.py
View File

@ -10,8 +10,6 @@ import pandas as pd
"""
TODO:
- General:
[!] zero-shot setup
- Docs:
- add Sphinx documentation
"""
@ -27,13 +25,11 @@ def get_config_name(args):
config_name += "M+"
if args.textual_transformer:
config_name += f"TT_{args.textual_trf_name}+"
if args.visual_transformer:
config_name += f"VT_{args.visual_trf_name}+"
return config_name.rstrip("+")
def main(args):
dataset = get_dataset(args.dataset, args)
dataset = get_dataset(args.datadir, args)
lX, lY = dataset.training(merge_validation=True)
lX_te, lY_te = dataset.test()
@ -47,13 +43,12 @@ def main(args):
args.multilingual,
args.multilingual,
args.textual_transformer,
args.visual_transformer,
]
), "At least one of VGF must be True"
gfun = GeneralizedFunnelling(
# dataset params ----------------------
dataset_name=args.dataset,
dataset_name=dataset,
langs=dataset.langs(),
num_labels=dataset.num_labels(),
classification_type=args.clf_type,
@ -67,24 +62,15 @@ def main(args):
# Transformer VGF params --------------
textual_transformer=args.textual_transformer,
textual_transformer_name=args.textual_trf_name,
trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
batch_size=args.batch_size,
eval_batch_size=args.eval_batch_size,
epochs=args.epochs,
textual_lr=args.textual_lr,
visual_lr=args.visual_lr,
max_length=args.max_length,
patience=args.patience,
evaluate_step=args.evaluate_step,
device=args.device,
# Visual Transformer VGF params --------------
visual_transformer=args.visual_transformer,
visual_transformer_name=args.visual_trf_name,
# batch_size=args.batch_size,
# epochs=args.epochs,
# lr=args.lr,
# patience=args.patience,
# evaluate_step=args.evaluate_step,
# device="cuda",
# General params ---------------------
probabilistic=args.features,
aggfunc=args.aggfunc,
@ -145,7 +131,7 @@ def main(args):
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
config["gFun"]["timing"] = f"{timeval - tinit:.2f}"
csvlogger = CsvLogger(outfile="results/random.log.csv").log_lang_results(lang_metrics_gfun, config)
csvlogger = CsvLogger(outfile="results/gfun.log.csv").log_lang_results(lang_metrics_gfun, config, notes="")
save_preds(gfun_preds, lY_te, config=config["gFun"]["simple_id"], dataset=config["gFun"]["dataset"])
@ -162,7 +148,7 @@ def save_preds(preds, targets, config="unk", dataset="unk"):
df["langs"] = _langs
df["labels"] = _targets
df["preds"] = _preds
df.to_csv(f"results/lang-specific.gfun.{config}.{dataset}.csv", index=False)
df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.correct.csv", index=False)
if __name__ == "__main__":
@ -174,7 +160,7 @@ if __name__ == "__main__":
parser.add_argument("--tr_langs", nargs="+", default=None)
parser.add_argument("--te_langs", nargs="+", default=None)
# Dataset parameters -------------------
parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
parser.add_argument("-d", "--datadir", type=str, default=None, help="dir to dataset. It should contain both a train.csv and a test.csv file")
parser.add_argument("--domains", type=str, default="all")
parser.add_argument("--nrows", type=int, default=None)
parser.add_argument("--min_count", type=int, default=10)
@ -186,7 +172,6 @@ if __name__ == "__main__":
parser.add_argument("-m", "--multilingual", action="store_true")
parser.add_argument("-w", "--wce", action="store_true")
parser.add_argument("-t", "--textual_transformer", action="store_true")
parser.add_argument("-v", "--visual_transformer", action="store_true")
parser.add_argument("--n_jobs", type=int, default=-1)
parser.add_argument("--optimc", action="store_true")
parser.add_argument("--features", action="store_false")
@ -200,9 +185,6 @@ if __name__ == "__main__":
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--patience", type=int, default=5)
parser.add_argument("--evaluate_step", type=int, default=10)
# Visual Transformer parameters --------------
parser.add_argument("--visual_trf_name", type=str, default="vit")
parser.add_argument("--visual_lr", type=float, default=1e-4)
# logging
parser.add_argument("--wandb", action="store_true")