branching for rai

Andrea Pedrotti 2023-10-05 15:39:49 +02:00
parent fbd740fabd
commit 234b6031b1
9 changed files with 147 additions and 411 deletions

View File

@@ -5,12 +5,6 @@ from sklearn.metrics import mean_absolute_error
from os.path import join

-"""
-MAE as a classification metric is meaningful only in "ordinal" tasks, e.g., sentiment classification.
-Otherwise the distance between the categories has no semantics!
-- NB: we want to get the macro-averaged, class-specific MAE!
-"""

def main():
    # SETTINGS = ["p", "m", "w", "t", "mp", "mpw", "mpt", "mptw"]

@@ -24,11 +18,9 @@ def main():
    print(df)

-def evalaute(setting):
+def evalaute(setting, result_file="preds.gfun.csv"):
    result_dir = "results"
-    # result_file = f"lang-specific.gfun.{setting}.webis.csv"
-    result_file = f"lang-specific.mbert.webis.csv"
-    # print(f"- reading from: {result_file}")
+    print(f"- reading from: {result_file}")
    df = pd.read_csv(join(result_dir, result_file))
    langs = df.langs.unique()
    res = []

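The docstring removed above concerned macro-averaged, class-specific MAE, which is only meaningful for ordinal label sets. As a reference, a minimal sketch of such a metric (the function name macro_mae and the toy call are illustrative, not part of the repository):

import numpy as np

def macro_mae(y_true, y_pred):
    """Macro-averaged, class-specific MAE: compute the MAE separately over the
    documents of each true class, then average over classes, so frequent classes
    do not dominate. Meaningful only when labels are ordinal."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    per_class = [np.abs(y_pred[y_true == c] - c).mean() for c in np.unique(y_true)]
    return float(np.mean(per_class))

# e.g. macro_mae([0, 0, 2, 1], [0, 1, 2, 2]) averages the per-class errors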
View File

@@ -14,7 +14,7 @@ class CsvLogger:
        # os.makedirs(self.outfile.replace(".csv", ".lang.csv"), exist_ok=True)
        # return

-    def log_lang_results(self, results: dict, config="gfun-lello"):
+    def log_lang_results(self, results: dict, config="gfun-default", notes=None):
        df = pd.DataFrame.from_dict(results, orient="columns")
        df["config"] = config["gFun"]["simple_id"]
        df["aggfunc"] = config["gFun"]["aggfunc"]

@@ -22,6 +22,7 @@ class CsvLogger:
        df["id"] = config["gFun"]["id"]
        df["optimc"] = config["gFun"]["optimc"]
        df["timing"] = config["gFun"]["timing"]
+        df["notes"] = notes
        with open(self.outfile, 'a') as f:
            df.to_csv(f, mode='a', header=f.tell()==0)

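The logger appends every run to the same CSV and relies on f.tell() == 0 to write the header only once. A self-contained sketch of that append pattern (the file name and columns are illustrative):

import pandas as pd

def append_log(df: pd.DataFrame, outfile: str = "gfun.log.csv") -> None:
    # Open in append mode; f.tell() is 0 only for a brand-new/empty file,
    # so the header row is written exactly once across repeated runs.
    with open(outfile, "a") as f:
        df.to_csv(f, mode="a", header=f.tell() == 0, index=False)

append_log(pd.DataFrame([{"config": "gfun-default", "notes": "first run"}]))
append_log(pd.DataFrame([{"config": "gfun-default", "notes": "second run"}]))  # no duplicate header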
View File

@@ -16,7 +16,17 @@ from dataManager.multilingualDataset import MultilingualDataset
class SimpleGfunDataset:
-    def __init__(self, datadir="~/datasets/rai/csv/", textual=True, visual=False, multilabel=False, set_tr_langs=None, set_te_langs=None):
+    def __init__(
+        self,
+        dataset_name=None,
+        datadir="~/datasets/rai/csv/",
+        textual=True,
+        visual=False,
+        multilabel=False,
+        set_tr_langs=None,
+        set_te_langs=None
+    ):
+        self.name = dataset_name
        self.datadir = os.path.expanduser(datadir)
        self.textual = textual
        self.visual = visual

@@ -41,11 +51,15 @@ class SimpleGfunDataset:
        print(f"tr: {tr} - va: {va} - te: {te}")

    def load_csv(self, set_tr_langs, set_te_langs):
-        # _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
-        _data_tr = pd.read_csv(os.path.join(self.datadir, "train.balanced.csv")).sample(100, random_state=42)
-        train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang)  # TODO stratify on lang or label?
-        # test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
-        test = pd.read_csv(os.path.join(self.datadir, "test.balanced.csv")).sample(100, random_state=42)
+        _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
+        try:
+            stratified = "class"
+            train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.label)
+        except:
+            stratified = "lang"
+            train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang)
+        print(f"- dataset stratified by {stratified}")
+        test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
        self._set_langs(train, test, set_tr_langs, set_te_langs)
        self._set_labels(_data_tr)
        self.full_train = _data_tr

@@ -56,8 +70,6 @@ class SimpleGfunDataset:
        return

    def _set_labels(self, data):
-        # self.labels = [i for i in range(28)]  # todo hard-coded for rai
-        # self.labels = [i for i in range(3)]  # TODO hard coded for sentimnet
        self.labels = sorted(list(data.label.unique()))

    def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):

@@ -70,17 +82,17 @@ class SimpleGfunDataset:
            print(f"-- [SETTING TESTING LANGS TO: {list(set_tr_langs)}]")
            self.te_langs = self.te_langs.intersection(set(set_te_langs))
        self.all_langs = self.tr_langs.union(self.te_langs)
        return self.tr_langs, self.te_langs, self.all_langs

    def _set_datalang(self, data: pd.DataFrame):
        return {lang: data[data.lang == lang] for lang in self.all_langs}

    def training(self, merge_validation=False, mask_number=False, target_as_csr=False):
-        # TODO some additional pre-processing on the textual data?
        apply_mask = lambda x: _mask_numbers(x) if _mask_numbers else x
        lXtr = {
-            lang: {"text": apply_mask(self.train_data[lang].text.tolist())}  # TODO inserting dict for textual data - we still have to manage visual
+            lang: {"text": apply_mask(self.train_data[lang].text.tolist())}
            for lang in self.tr_langs
        }
        if merge_validation:

@@ -103,7 +115,6 @@ class SimpleGfunDataset:
        return lXtr, lYtr

    def test(self, mask_number=False, target_as_csr=False):
-        # TODO some additional pre-processing on the textual data?
        apply_mask = lambda x: _mask_numbers(x) if _mask_numbers else x
        lXte = {
            lang: {"text": apply_mask(self.test_data[lang].text.tolist())}

@@ -162,62 +173,12 @@ class gFunDataset:
        return mlb

    def _load_dataset(self):
-        if "glami" in self.dataset_dir.lower():
-            print(f"- Loading GLAMI dataset from {self.dataset_dir}")
-            self.dataset_name = "glami"
-            self.dataset, self.labels, self.data_langs = self._load_glami(
-                self.dataset_dir, self.nrows
-            )
-            self.mlb = self.get_label_binarizer(self.labels)
-        elif "rcv" in self.dataset_dir.lower():
-            print(f"- Loading RCV1-2 dataset from {self.dataset_dir}")
-            self.dataset_name = "rcv1-2"
-            self.dataset, self.labels, self.data_langs = self._load_multilingual(
-                self.dataset_name, self.dataset_dir, self.nrows
-            )
-            self.mlb = self.get_label_binarizer(self.labels)
-        elif "jrc" in self.dataset_dir.lower():
-            print(f"- Loading JRC dataset from {self.dataset_dir}")
-            self.dataset_name = "jrc"
-            self.dataset, self.labels, self.data_langs = self._load_multilingual(
-                self.dataset_name, self.dataset_dir, self.nrows
-            )
-            self.mlb = self.get_label_binarizer(self.labels)
-        # WEBIS-CLS (processed)
-        elif (
-            "cls" in self.dataset_dir.lower()
-            and "unprocessed" not in self.dataset_dir.lower()
-        ):
-            print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}")
-            self.dataset_name = "cls"
-            self.dataset, self.labels, self.data_langs = self._load_multilingual(
-                self.dataset_name, self.dataset_dir, self.nrows
-            )
-            self.mlb = self.get_label_binarizer(self.labels)
-        # WEBIS-CLS (unprocessed)
-        elif (
-            "cls" in self.dataset_dir.lower()
-            and "unprocessed" in self.dataset_dir.lower()
-        ):
-            print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}")
-            self.dataset_name = "cls"
-            self.dataset, self.labels, self.data_langs = self._load_multilingual(
-                self.dataset_name, self.dataset_dir, self.nrows
-            )
-            self.mlb = self.get_label_binarizer(self.labels)
-        elif "rai" in self.dataset_dir.lower():
-            print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}")
-            self.dataset_name = "rai"
-            self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
-                                                                                  dataset_dir="~/datasets/rai/csv/train-split-rai.csv",
-                                                                                  nrows=self.nrows)
-            self.mlb = self.get_label_binarizer(self.labels)
+        print(f"- Loading dataset from {self.dataset_dir}")
+        self.dataset_name = "rai"
+        self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
+                                                                              dataset_dir=self.dataset_dir,
+                                                                              nrows=self.nrows)
+        self.mlb = self.get_label_binarizer(self.labels)
        self.show_dimension()
        return

@@ -232,15 +193,12 @@ class gFunDataset:
        else:
            print(f"-- Labels: {len(self.labels)}")

-    def _load_multilingual(self, dataset_name, dataset_dir, nrows):
+    def _load_multilingual(self, dataset_dir, nrows, dataset_name="rai"):
        if "csv" in dataset_dir:
-            old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv(
-                # path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
-                # path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv")
-                path_tr="~/datasets/rai/csv/train-split-rai.csv",
-                path_te="~/datasets/rai/csv/test-split-rai.csv")
-        else:
-            old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
+            old_dataset = MultilingualDataset(dataset_name="rai").from_csv(
+                path_tr=os.path.expanduser(os.path.join(dataset_dir, "train.small.csv")),
+                path_te=os.path.expanduser(os.path.join(dataset_dir, "test.small.csv"))
+            )
        if nrows is not None:
            if dataset_name == "cls":
                old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)

@@ -365,27 +323,4 @@ if __name__ == "__main__":
    data_rai = SimpleGfunDataset()
    lXtr, lYtr = data_rai.training(mask_number=False)
    lXte, lYte = data_rai.test(mask_number=False)
    exit()
-    # import os
-    # GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
-    # RCV_DATAPATH = os.path.expanduser(
-    #     "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
-    # )
-    # JRC_DATAPATH = os.path.expanduser(
-    #     "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
-    # )
-    # print("Hello gFunDataset")
-    # dataset = gFunDataset(
-    #     dataset_dir=JRC_DATAPATH,
-    #     data_langs=None,
-    #     is_textual=True,
-    #     is_visual=True,
-    #     is_multilabel=False,
-    #     labels=None,
-    #     nrows=13,
-    # )
-    # lXtr, lYtr = dataset.training()
-    # lXte, lYte = dataset.test()
-    # exit(0)

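The new load_csv prefers a label-stratified train/validation split and falls back to stratifying by language when some class is too rare for stratification. The same idea as a standalone sketch (column names label and lang follow the diff; catching ValueError instead of a bare except is an editorial choice):

import pandas as pd
from sklearn.model_selection import train_test_split

def split_with_fallback(df: pd.DataFrame, test_size=0.2, seed=42):
    # Prefer a class-stratified split; sklearn raises ValueError when a class has
    # fewer than 2 samples, in which case we fall back to stratifying by language.
    try:
        train, val = train_test_split(df, test_size=test_size, random_state=seed, stratify=df.label)
        how = "class"
    except ValueError:
        train, val = train_test_split(df, test_size=test_size, random_state=seed, stratify=df.lang)
        how = "lang"
    print(f"- dataset stratified by {how}")
    return train, val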
View File

@@ -1,7 +1,5 @@
from os.path import expanduser, join
-from dataManager.gFunDataset import gFunDataset, SimpleGfunDataset
-from dataManager.multiNewsDataset import MultiNewsDataset
-from dataManager.amazonDataset import AmazonDataset
+from dataManager.gFunDataset import SimpleGfunDataset


def load_from_pickle(path, dataset_name, nrows):

@@ -16,119 +14,14 @@ def load_from_pickle(path, dataset_name, nrows):
    return loaded


-def get_dataset(dataset_name, args):
-    assert dataset_name in [
-        "multinews",
-        "amazon",
-        "rcv1-2",
-        "glami",
-        "cls",
-        "webis",
-        "rai",
-    ], "dataset not supported"
-
-    RCV_DATAPATH = expanduser(
-        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
-    )
-    JRC_DATAPATH = expanduser(
-        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
-    )
-    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
-    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
-    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
-    WEBIS_CLS = expanduser(
-        # "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
-        "~/datasets/cls-acl10-unprocessed/csv"
-    )
-    RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl")
-
-    if dataset_name == "multinews":
-        # TODO: convert to gFunDataset
-        raise NotImplementedError
-        dataset = MultiNewsDataset(
-            expanduser(MULTINEWS_DATAPATH),
-            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
-        )
-    elif dataset_name == "amazon":
-        # TODO: convert to gFunDataset
-        raise NotImplementedError
-        dataset = AmazonDataset(
-            domains=args.domains,
-            nrows=args.nrows,
-            min_count=args.min_count,
-            max_labels=args.max_labels,
-        )
-    elif dataset_name == "jrc":
-        dataset = gFunDataset(
-            dataset_dir=JRC_DATAPATH,
-            is_textual=True,
-            is_visual=False,
-            is_multilabel=True,
-            nrows=args.nrows,
-        )
-    elif dataset_name == "rcv1-2":
-        dataset = gFunDataset(
-            dataset_dir=RCV_DATAPATH,
-            is_textual=True,
-            is_visual=False,
-            is_multilabel=True,
-            nrows=args.nrows,
-        )
-    elif dataset_name == "glami":
-        if args.save_dataset is False:
-            dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
-        else:
-            dataset = gFunDataset(
-                dataset_dir=GLAMI_DATAPATH,
-                is_textual=True,
-                is_visual=True,
-                is_multilabel=False,
-                nrows=args.nrows,
-            )
-            dataset.save_as_pickle(GLAMI_DATAPATH)
-    elif dataset_name == "cls":
-        dataset = gFunDataset(
-            dataset_dir=CLS_DATAPATH,
-            is_textual=True,
-            is_visual=False,
-            is_multilabel=False,
-            nrows=args.nrows,
-        )
-    elif dataset_name == "webis":
-        dataset = SimpleGfunDataset(
-            datadir=WEBIS_CLS,
-            textual=True,
-            visual=False,
-            multilabel=False,
-            set_tr_langs=args.tr_langs,
-            set_te_langs=args.te_langs
-        )
-        # dataset = gFunDataset(
-        #     dataset_dir=WEBIS_CLS,
-        #     is_textual=True,
-        #     is_visual=False,
-        #     is_multilabel=False,
-        #     nrows=args.nrows,
-        # )
-    elif dataset_name == "rai":
-        dataset = SimpleGfunDataset(
-            datadir="~/datasets/rai/csv",
-            textual=True,
-            visual=False,
-            multilabel=False
-        )
-        # dataset = gFunDataset(
-        #     dataset_dir=RAI_DATAPATH,
-        #     is_textual=True,
-        #     is_visual=False,
-        #     is_multilabel=False,
-        #     nrows=args.nrows
-        # )
-    else:
-        raise NotImplementedError
+def get_dataset(datasetp_path, args):
+    dataset = SimpleGfunDataset(
+        dataset_name="rai",
+        datadir=datasetp_path,
+        textual=True,
+        visual=False,
+        multilabel=False,
+        set_tr_langs=args.tr_langs,
+        set_te_langs=args.te_langs
+    )
    return dataset

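get_dataset no longer dispatches on a dataset name: it simply wraps SimpleGfunDataset around whatever directory is passed in. A hedged usage sketch of the simplified loader (the path is a placeholder and assumes the RAI train.small.csv/test.small.csv files are available locally):

from dataManager.gFunDataset import SimpleGfunDataset  # import path as shown in the diff above

# Hypothetical usage: any directory laid out like ~/datasets/rai/csv can be passed directly.
dataset = SimpleGfunDataset(
    dataset_name="rai",
    datadir="~/datasets/rai/csv",
    textual=True,
    visual=False,
    multilabel=False,
    set_tr_langs=None,
    set_te_langs=None,
)
lX, lY = dataset.training(merge_validation=True)
lX_te, lY_te = dataset.test()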
View File

@@ -8,7 +8,6 @@ from gfun.vgfs.learners.svms import MetaClassifier, get_learner
from gfun.vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from gfun.vgfs.vanillaFun import VanillaFunGen
-from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.wceGen import WceGen

@@ -19,7 +18,6 @@ class GeneralizedFunnelling:
        wce,
        multilingual,
        textual_transformer,
-        visual_transformer,
        langs,
        num_labels,
        classification_type,

@@ -29,12 +27,9 @@ class GeneralizedFunnelling:
        eval_batch_size,
        max_length,
        textual_lr,
-        visual_lr,
        epochs,
        patience,
        evaluate_step,
-        textual_transformer_name,
-        visual_transformer_name,
        optimc,
        device,
        load_trained,

@@ -42,13 +37,14 @@ class GeneralizedFunnelling:
        probabilistic,
        aggfunc,
        load_meta,
+        trained_text_trf=None,
+        textual_transformer_name=None,
    ):
        # Setting VFGs -----------
        self.posteriors_vgf = posterior
        self.wce_vgf = wce
        self.multilingual_vgf = multilingual
        self.textual_trf_vgf = textual_transformer
-        self.visual_trf_vgf = visual_transformer
        self.probabilistic = probabilistic
        self.num_labels = num_labels
        self.clf_type = classification_type

@@ -58,6 +54,7 @@ class GeneralizedFunnelling:
        self.cached = True
        # Textual Transformer VGF params ----------
        self.textual_trf_name = textual_transformer_name
+        self.trained_text_trf = trained_text_trf
        self.epochs = epochs
        self.textual_trf_lr = textual_lr
        self.textual_scheduler = "ReduceLROnPlateau"

@@ -68,10 +65,6 @@ class GeneralizedFunnelling:
        self.patience = patience
        self.evaluate_step = evaluate_step
        self.device = device
-        # Visual Transformer VGF params ----------
-        self.visual_trf_name = visual_transformer_name
-        self.visual_trf_lr = visual_lr
-        self.visual_scheduler = "ReduceLROnPlateau"
        # Metaclassifier params ------------
        self.optimc = optimc
        # -------------------

@@ -131,7 +124,6 @@ class GeneralizedFunnelling:
            self.multilingual_vgf,
            self.wce_vgf,
            self.textual_trf_vgf,
-            self.visual_trf_vgf,
            self.aggfunc,
        )
        return self

@@ -174,26 +166,10 @@ class GeneralizedFunnelling:
                patience=self.patience,
                device=self.device,
                classification_type=self.clf_type,
+                saved_model=self.trained_text_trf,
            )
            self.first_tier_learners.append(transformer_vgf)

-        if self.visual_trf_vgf:
-            visual_trasformer_vgf = VisualTransformerGen(
-                dataset_name=self.dataset_name,
-                model_name="vit",
-                lr=self.visual_trf_lr,
-                scheduler=self.visual_scheduler,
-                epochs=self.epochs,
-                batch_size=self.batch_size_trf,
-                batch_size_eval=self.eval_batch_size_trf,
-                probabilistic=self.probabilistic,
-                evaluate_step=self.evaluate_step,
-                patience=self.patience,
-                device=self.device,
-                classification_type=self.clf_type,
-            )
-            self.first_tier_learners.append(visual_trasformer_vgf)

        if "attn" in self.aggfunc:
            attn_stacking = self.aggfunc.split("_")[1]
            self.attn_aggregator = AttentionAggregator(

@@ -219,7 +195,6 @@ class GeneralizedFunnelling:
            self.multilingual_vgf,
            self.wce_vgf,
            self.textual_trf_vgf,
-            self.visual_trf_vgf,
            self.aggfunc,
        )
        print(f"- model id: {self._model_id}")

@@ -309,16 +284,29 @@ class GeneralizedFunnelling:
        return aggregated

    def _aggregate_mean(self, first_tier_projections):
-        aggregated = {
-            lang: np.zeros(data.shape)
-            for lang, data in first_tier_projections[0].items()
-        }
+        # aggregated = {
+        #     lang: np.zeros(data.shape)
+        #     for lang, data in first_tier_projections[0].items()
+        # }
+        aggregated = {}
        for lang_projections in first_tier_projections:
            for lang, projection in lang_projections.items():
+                if lang not in aggregated:
+                    aggregated[lang] = np.zeros(projection.shape)
                aggregated[lang] += projection

-        for lang, projection in aggregated.items():
-            aggregated[lang] /= len(first_tier_projections)
+        def get_denom(lang, projs):
+            den = 0
+            for proj in projs:
+                if lang in proj:
+                    den += 1
+            return den
+
+        for lang, _ in aggregated.items():
+            # aggregated[lang] /= len(first_tier_projections)
+            aggregated[lang] /= get_denom(lang, first_tier_projections)

        return aggregated

@@ -416,18 +404,6 @@ class GeneralizedFunnelling:
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
            print(f"- loaded trained Textual Transformer VGF")
-        if self.visual_trf_vgf:
-            with open(
-                os.path.join(
-                    "models",
-                    "vgfs",
-                    "visual_transformer",
-                    f"visualTransformerGen_{model_id}.pkl",
-                ),
-                "rb",
-                print(f"- loaded trained Visual Transformer VGF"),
-            ) as vgf:
-                first_tier_learners.append(pickle.load(vgf))

        if load_meta:
            with open(

@@ -482,7 +458,6 @@ def get_unique_id(
    multilingual,
    wce,
    textual_transformer,
-    visual_transformer,
    aggfunc,
):
    from datetime import datetime

@@ -493,6 +468,5 @@ def get_unique_id(
    model_id += "m" if multilingual else ""
    model_id += "w" if wce else ""
    model_id += "t" if textual_transformer else ""
-    model_id += "v" if visual_transformer else ""
    model_id += f"_{aggfunc}"
    return f"{model_id}_{now}"

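_aggregate_mean now averages each language's projections over only the view generators that actually produced that language, instead of dividing by the total number of views. The same averaging logic as a standalone sketch (names and toy data are illustrative):

import numpy as np

def aggregate_mean(projections):
    """Average per-language projections over only the views that cover that language.
    `projections` is a list of dicts, one per view generator, mapping lang -> array."""
    summed, counts = {}, {}
    for view in projections:
        for lang, proj in view.items():
            summed[lang] = summed.get(lang, np.zeros(proj.shape)) + proj
            counts[lang] = counts.get(lang, 0) + 1
    return {lang: summed[lang] / counts[lang] for lang in summed}

# a language present in 2 of 3 views is divided by 2, not by 3
views = [{"en": np.ones(3), "it": np.ones(3)}, {"en": np.ones(3)}, {"en": np.ones(3), "it": np.ones(3)}]
print(aggregate_mean(views)["it"])  # -> [1. 1. 1.]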
View File

@@ -173,9 +173,18 @@ class NaivePolylingualClassifier:
        :return: a dictionary of probabilities that each document belongs to each class
        """
        assert self.model is not None, "predict called before fit"
-        assert set(lX.keys()).issubset(
-            set(self.model.keys())
-        ), "unknown languages requested in decision function"
+        if not set(lX.keys()).issubset(set(self.model.keys())):
+            langs = set(lX.keys()).intersection(set(self.model.keys()))
+            scores = Parallel(n_jobs=self.n_jobs, max_nbytes=None)(
+                delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs)
+            # res = {lang: None for lang in lX.keys()}
+            # for i, lang in enumerate(langs):
+            #     res[lang] = scores[i]
+            # return res
+            return {lang: scores[i] for i, lang in enumerate(langs)}
+        # assert set(lX.keys()).issubset(
+        #     set(self.model.keys())
+        # ), "unknown languages requested in decision function"
        langs = list(lX.keys())
        scores = Parallel(n_jobs=self.n_jobs, max_nbytes=None)(
            delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs

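With this change the decision function degrades gracefully: when prediction is requested for languages the classifier was never fitted on, it restricts itself to the intersection of requested and trained languages instead of failing an assertion. A minimal sketch of that behaviour (the toy class is illustrative; only the joblib calls mirror the diff):

from joblib import Parallel, delayed

class PerLanguagePredictor:
    """Toy stand-in: one fitted estimator per language, predict only where a model exists."""
    def __init__(self, models, n_jobs=2):
        self.model = models          # dict: lang -> fitted estimator with predict_proba
        self.n_jobs = n_jobs

    def predict_proba(self, lX):
        langs = sorted(set(lX.keys()) & set(self.model.keys()))  # silently drop unseen languages
        scores = Parallel(n_jobs=self.n_jobs)(
            delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs
        )
        return dict(zip(langs, scores))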
View File

@@ -73,6 +73,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
        patience=5,
        classification_type="multilabel",
        scheduler="ReduceLROnPlateau",
+        saved_model=None,
    ):
        super().__init__(
            self._validate_model_name(model_name),

@@ -91,6 +92,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
            n_jobs=n_jobs,
            verbose=verbose,
        )
+        self.saved_model = saved_model
        self.clf_type = classification_type
        self.fitted = False
        print(

@@ -109,26 +111,25 @@ class TextualTransformerGen(ViewGen, TransformerGen):
        else:
            raise NotImplementedError

-    def load_pretrained_model(self, model_name, num_labels):
+    def load_pretrained_model(self, model_name, num_labels, saved_model=None):
        if model_name == "google/mt5-small":
            return MT5ForSequenceClassification(
                model_name, num_labels=num_labels, output_hidden_states=True
            )
        else:
-            # model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500"  # TODO hardcoded to pre-trained mbert
-            # model_name = "hf_models/mbert-rai-fewshot-second/checkpoint-19000"  # TODO hardcoded to pre-trained mbert
-            # model_name = "hf_models/mbert-sentiment/checkpoint-1150"  # TODO hardcoded to pre-trained mbert
-            model_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
+            if saved_model:
+                model_name = saved_model
+            else:
+                model_name = "google/bert-base-multilingual-cased"
            return AutoModelForSequenceClassification.from_pretrained(
                model_name, num_labels=num_labels, output_hidden_states=True
            )

    def load_tokenizer(self, model_name):
-        # model_name = "mbert-rai-multi-2000/checkpoint-1500"  # TODO hardcoded to pre-trained mbert
        return AutoTokenizer.from_pretrained(model_name)

-    def init_model(self, model_name, num_labels):
-        return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
+    def init_model(self, model_name, num_labels, saved_model):
+        return self.load_pretrained_model(model_name, num_labels, saved_model), self.load_tokenizer(
            model_name
        )

@@ -148,64 +149,11 @@ class TextualTransformerGen(ViewGen, TransformerGen):
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.tokenizer = self.init_model(
-            self.model_name, num_labels=self.num_labels
+            self.model_name, num_labels=self.num_labels, saved_model=self.saved_model,
        )
        self.model.to("cuda")

-        # tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
-        #     lX, lY, split=0.2, seed=42, modality="text"
-        # )
-        #
-        # tra_dataloader = self.build_dataloader(
-        #     tr_lX,
-        #     tr_lY,
-        #     processor_fn=self._tokenize,
-        #     torchDataset=MultilingualDatasetTorch,
-        #     batch_size=self.batch_size,
-        #     split="train",
-        #     shuffle=True,
-        # )
-        #
-        # val_dataloader = self.build_dataloader(
-        #     val_lX,
-        #     val_lY,
-        #     processor_fn=self._tokenize,
-        #     torchDataset=MultilingualDatasetTorch,
-        #     batch_size=self.batch_size_eval,
-        #     split="val",
-        #     shuffle=False,
-        # )
-        #
-        # experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-        #
-        # trainer = Trainer(
-        #     model=self.model,
-        #     optimizer_name="adamW",
-        #     lr=self.lr,
-        #     device=self.device,
-        #     loss_fn=torch.nn.CrossEntropyLoss(),
-        #     print_steps=self.print_steps,
-        #     evaluate_step=self.evaluate_step,
-        #     patience=self.patience,
-        #     experiment_name=experiment_name,
-        #     checkpoint_path=os.path.join(
-        #         "models",
-        #         "vgfs",
-        #         "trained_transformer",
-        #         self._format_model_name(self.model_name),
-        #     ),
-        #     vgf_name="textual_trf",
-        #     classification_type=self.clf_type,
-        #     n_jobs=self.n_jobs,
-        #     scheduler_name=self.scheduler,
-        # )
-        # trainer.train(
-        #     train_dataloader=tra_dataloader,
-        #     eval_dataloader=val_dataloader,
-        #     epochs=self.epochs,
-        # )

        if self.probabilistic:
            transformed = self.transform(lX)
            self.feature2posterior_projector.fit(transformed, lY)

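The new saved_model argument lets the textual VGF start from an already fine-tuned checkpoint instead of stock multilingual BERT. A hedged sketch of the loading logic in isolation (the default model id and the checkpoint path in the comment are assumptions, not repository values):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

def load_classifier(num_labels, saved_model=None, default="bert-base-multilingual-cased"):
    # Prefer a locally fine-tuned checkpoint when one is provided; otherwise fall
    # back to the stock multilingual BERT weights from the Hub.
    name = saved_model if saved_model else default
    model = AutoModelForSequenceClassification.from_pretrained(
        name, num_labels=num_labels, output_hidden_states=True
    )
    tokenizer = AutoTokenizer.from_pretrained(name)
    return model, tokenizer

# model, tok = load_classifier(28, saved_model="hf_models/mbert-zeroshot-rai/checkpoint-1350")  # hypothetical path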
View File

@@ -1,4 +1,4 @@
-from os.path import expanduser
+from os.path import expanduser, join

import torch
from transformers import (

@@ -19,13 +19,8 @@ import pandas as pd
transformers.logging.set_verbosity_error()

-IWSLT_D_COLUMNS = ["text", "category", "rating", "summary", "title"]
-RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"]  # "lang"
-WEBIS_D_COLUMNS = ['Unnamed: 0', 'asin', 'category', 'original_rating', 'label', 'title', 'text', 'summary']  # "lang"
+RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"]
MAX_LEN = 128
-# DATASET_NAME = "rai"
-# DATASET_NAME = "rai-multilingual-2000"
-# DATASET_NAME = "webis-cls"


def init_callbacks(patience=-1, nosave=False):

@@ -35,13 +30,17 @@ def init_callbacks(patience=-1, nosave=False):
    return callbacks


-def init_model(model_name, nlabels):
+def init_model(model_name, nlabels, saved_model=None):
    if model_name == "mbert":
-        # hf_name = "bert-base-multilingual-cased"
-        hf_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
-        # hf_name = "hf_models/mbert-rai-fewshot-second/checkpoint-9000"
+        if saved_model is None:
+            hf_name = "bert-base-multilingual-cased"
+        else:
+            hf_name = saved_model
    elif model_name == "xlm-roberta":
-        hf_name = "xlm-roberta-base"
+        if saved_model is None:
+            hf_name = "xlm-roberta-base"
+        else:
+            hf_name = saved_model
    else:
        raise NotImplementedError
    tokenizer = AutoTokenizer.from_pretrained(hf_name)

@@ -50,43 +49,41 @@ def init_model(model_name, nlabels):
def main(args):
-    tokenizer, model = init_model(args.model, args.nlabels)
+    saved_model = args.savedmodel
+    trainlang = args.trainlangs
+    datapath = args.datapath
+    tokenizer, model = init_model(args.model, args.nlabels, saved_model=saved_model)
    data = load_dataset(
        "csv",
        data_files = {
-            "train": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/train.balanced.csv"),
-            "test": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/test.balanced.csv")
-            # "train": expanduser(f"~/datasets/rai/csv/train-{DATASET_NAME}.csv"),
-            # "test": expanduser(f"~/datasets/rai/csv/test-{DATASET_NAME}.csv")
-            # "train": expanduser(f"~/datasets/rai/csv/train.small.csv"),
-            # "test": expanduser(f"~/datasets/rai/csv/test.small.csv")
+            "train": expanduser(join(datapath, "train.csv")),
+            "test": expanduser(join(datapath, "test.small.csv"))
        }
    )

+    def filter_dataset(dataset, lang):
+        indices = [i for i, l in enumerate(dataset["lang"]) if l == lang]
+        dataset = dataset.select(indices)
+        return dataset
+
+    if trainlang is not None:
+        data["train"] = filter_dataset(data["train"], lang=trainlang)

    def process_sample_rai(sample):
        inputs = [f"{title}. {text}" for title, text in zip(sample["title"], sample["text"])]
        labels = sample["label"]
-        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)  # TODO pre-process text cause there's a lot of noise in there...
+        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)
        model_inputs["labels"] = labels
        return model_inputs

-    def process_sample_webis(sample):
-        inputs = sample["text"]
-        labels = sample["label"]
-        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)  # TODO pre-process text cause there's a lot of noise in there...
-        model_inputs["labels"] = labels
-        return model_inputs

    data = data.map(
-        # process_sample_rai,
-        process_sample_webis,
+        process_sample_rai,
        batched=True,
        num_proc=4,
        load_from_cache_file=True,
-        # remove_columns=RAI_D_COLUMNS,
-        remove_columns=WEBIS_D_COLUMNS,
+        remove_columns=RAI_D_COLUMNS,
    )
    train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
    data.set_format("torch")

@@ -107,8 +104,8 @@ def main(args):
    recall_metric = evaluate.load("recall")

    training_args = TrainingArguments(
-        # output_dir=f"hf_models/{args.model}-rai",
-        output_dir=f"hf_models/{args.model}-sentiment-balanced",
+        output_dir=f"hf_models/{args.model}-fewshot-full" if trainlang is None else f"hf_models/{args.model}-zeroshot-full",
+        run_name="model-zeroshot" if trainlang is not None else "model-fewshot",
        do_train=True,
        evaluation_strategy="steps",
        per_device_train_batch_size=args.batch,

@@ -130,8 +127,6 @@ def main(args):
        save_strategy="no" if args.nosave else "steps",
        save_total_limit=2,
        eval_steps=args.stepeval,
-        # run_name=f"{args.model}-rai-stratified",
-        run_name=f"{args.model}-sentiment",
        disable_tqdm=False,
        log_level="warning",
        report_to=["wandb"] if args.wandb else "none",

@@ -142,7 +137,6 @@ def main(args):
    def compute_metrics(eval_preds):
        preds = eval_preds.predictions.argmax(-1)
-        # targets = eval_preds.label_ids.argmax(-1)
        targets = eval_preds.label_ids
        setting = "macro"
        f1_score_macro = f1_metric.compute(

@@ -170,7 +164,9 @@ def main(args):
    if args.wandb:
        import wandb
-        wandb.init(entity="andreapdr", project=f"gfun-rai-hf", name="mbert-rai", config=vars(args))
+        wandb.init(entity="andreapdr", project=f"gfun",
+                   name="model-zeroshot-full" if trainlang is not None else "model-fewshot-full",
+                   config=vars(args))

    trainer = Trainer(
        model=model,

@@ -188,17 +184,21 @@ def main(args):
    trainer.train()

    print("- Testing:")
+    test_results = trainer.evaluate(eval_dataset=data["test"], metric_key_prefix="test")
    test_results = trainer.predict(test_dataset=data["test"], metric_key_prefix="test")
    pprint(test_results.metrics)
-    save_preds(data["test"], test_results.predictions)
+    save_preds(data["test"], test_results.predictions, trainlang)
    exit()


-def save_preds(dataset, predictions):
+def save_preds(dataset, predictions, trainlang=None):
    df = pd.DataFrame()
    df["langs"] = dataset["lang"]
    df["labels"] = dataset["labels"]
    df["preds"] = predictions.argmax(axis=1)
-    df.to_csv("results/lang-specific.mbert.webis.csv", index=False)
+    if trainlang is not None:
+        df.to_csv(f"results/zeroshot.{trainlang}.model.csv", index=False)
+    else:
+        df.to_csv("results/fewshot.model.csv", index=False)
    return

@@ -210,16 +210,18 @@ if __name__ == "__main__":
    parser.add_argument("--nlabels", type=int, metavar="", default=28)
    parser.add_argument("--lr", type=float, metavar="", default=5e-5, help="Set learning rate",)
    parser.add_argument("--scheduler", type=str, metavar="", default="cosine", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
    parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
    parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
-    parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
+    parser.add_argument("--epochs", type=int, metavar="", default=10, help="Set epochs")
    parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
    parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps")
    parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
    parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
    parser.add_argument("--wandb", action="store_true", help="Log to wandb")
    parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
    parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
-    # parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
+    parser.add_argument("--trainlang", default=None, type=str, help="set training language for zero-shot experiments")
+    parser.add_argument("--datapath", type=str, default="data", help="path to the csv dataset. Dir should contain both a train.csv and a test.csv file")
+    parser.add_argument("--savedmodel", type=str, default="hf_models/mbert-rai-fewshot-second/checkpoint-9000")
    args = parser.parse_args()
    main(args)

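When --trainlang is set, the training split is restricted to a single language so the fine-tuned model can be evaluated zero-shot on the others. A self-contained sketch of that filtering step with Hugging Face datasets (the toy data is illustrative; Dataset.select mirrors the diff, and Dataset.filter is an equivalent alternative):

from datasets import Dataset

data = Dataset.from_dict({
    "text": ["ciao", "hello", "hallo"],
    "lang": ["it", "en", "de"],
})

def filter_dataset(dataset, lang):
    # Keep only the rows whose "lang" column matches the requested training language.
    indices = [i for i, l in enumerate(dataset["lang"]) if l == lang]
    return dataset.select(indices)

train_it = filter_dataset(data, "it")      # zero-shot: train on Italian only
# equivalently: data.filter(lambda row: row["lang"] == "it")
print(train_it["text"])                    # ['ciao']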
main.py
View File

@@ -10,8 +10,6 @@ import pandas as pd
"""
TODO:
-    - General:
-        [!] zero-shot setup
    - Docs:
        - add documentations sphinx
"""

@@ -27,13 +25,11 @@ def get_config_name(args):
        config_name += "M+"
    if args.textual_transformer:
        config_name += f"TT_{args.textual_trf_name}+"
-    if args.visual_transformer:
-        config_name += f"VT_{args.visual_trf_name}+"
    return config_name.rstrip("+")


def main(args):
-    dataset = get_dataset(args.dataset, args)
+    dataset = get_dataset(args.datadir, args)
    lX, lY = dataset.training(merge_validation=True)
    lX_te, lY_te = dataset.test()

@@ -47,13 +43,12 @@ def main(args):
            args.multilingual,
            args.multilingual,
            args.textual_transformer,
-            args.visual_transformer,
        ]
    ), "At least one of VGF must be True"

    gfun = GeneralizedFunnelling(
        # dataset params ----------------------
-        dataset_name=args.dataset,
+        dataset_name=dataset,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        classification_type=args.clf_type,

@@ -67,24 +62,15 @@ def main(args):
        # Transformer VGF params --------------
        textual_transformer=args.textual_transformer,
        textual_transformer_name=args.textual_trf_name,
+        trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        epochs=args.epochs,
        textual_lr=args.textual_lr,
-        visual_lr=args.visual_lr,
        max_length=args.max_length,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device=args.device,
-        # Visual Transformer VGF params --------------
-        visual_transformer=args.visual_transformer,
-        visual_transformer_name=args.visual_trf_name,
-        # batch_size=args.batch_size,
-        # epochs=args.epochs,
-        # lr=args.lr,
-        # patience=args.patience,
-        # evaluate_step=args.evaluate_step,
-        # device="cuda",
        # General params ---------------------
        probabilistic=args.features,
        aggfunc=args.aggfunc,

@@ -145,7 +131,7 @@ def main(args):
        log_barplot_wandb(lang_metrics_gfun, title_affix="per language")

    config["gFun"]["timing"] = f"{timeval - tinit:.2f}"
-    csvlogger = CsvLogger(outfile="results/random.log.csv").log_lang_results(lang_metrics_gfun, config)
+    csvlogger = CsvLogger(outfile="results/gfun.log.csv").log_lang_results(lang_metrics_gfun, config, notes="")
    save_preds(gfun_preds, lY_te, config=config["gFun"]["simple_id"], dataset=config["gFun"]["dataset"])

@@ -162,7 +148,7 @@ def save_preds(preds, targets, config="unk", dataset="unk"):
    df["langs"] = _langs
    df["labels"] = _targets
    df["preds"] = _preds
-    df.to_csv(f"results/lang-specific.gfun.{config}.{dataset}.csv", index=False)
+    df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.correct.csv", index=False)


if __name__ == "__main__":

@@ -174,7 +160,7 @@ if __name__ == "__main__":
    parser.add_argument("--tr_langs", nargs="+", default=None)
    parser.add_argument("--te_langs", nargs="+", default=None)
    # Dataset parameters -------------------
-    parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
+    parser.add_argument("-d", "--datadir", type=str, default=None, help="dir to dataset. It should contain both a train.csv and a test.csv file")
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=None)
    parser.add_argument("--min_count", type=int, default=10)

@@ -186,7 +172,6 @@ if __name__ == "__main__":
    parser.add_argument("-m", "--multilingual", action="store_true")
    parser.add_argument("-w", "--wce", action="store_true")
    parser.add_argument("-t", "--textual_transformer", action="store_true")
-    parser.add_argument("-v", "--visual_transformer", action="store_true")
    parser.add_argument("--n_jobs", type=int, default=-1)
    parser.add_argument("--optimc", action="store_true")
    parser.add_argument("--features", action="store_false")

@@ -200,9 +185,6 @@ if __name__ == "__main__":
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)
-    # Visual Transformer parameters --------------
-    parser.add_argument("--visual_trf_name", type=str, default="vit")
-    parser.add_argument("--visual_lr", type=float, default=1e-4)
    # logging
    parser.add_argument("--wandb", action="store_true")