branching for rai

parent fbd740fabd
commit 234b6031b1
@@ -5,12 +5,6 @@ from sklearn.metrics import mean_absolute_error
 
 from os.path import join
 
-"""
-MEA and classification is meaningful only in "ordinal" tasks e.g., sentiment classification.
-Otherwise the distance between the categories has no semantics!
-
-- NB: we want to get the macro-averaged class specific MAE!
-"""
 
 def main():
     # SETTINGS = ["p", "m", "w", "t", "mp", "mpw", "mpt", "mptw"]
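The docstring removed above motivates the metric used here: plain MAE only makes sense for ordinal labels, and a macro-averaged, class-specific MAE keeps frequent classes from dominating the error. A minimal sketch of that metric, assuming integer ordinal labels (illustrative, not part of the commit):

import numpy as np

def macro_averaged_mae(y_true, y_pred):
    # MAE computed on the documents of each true class, then averaged over classes
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    per_class = [np.abs(y_pred[y_true == c] - c).mean() for c in np.unique(y_true)]
    return float(np.mean(per_class))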
@@ -24,11 +18,9 @@ def main():
     print(df)
 
 
-def evalaute(setting):
+def evalaute(setting, result_file="preds.gfun.csv"):
     result_dir = "results"
-    # result_file = f"lang-specific.gfun.{setting}.webis.csv"
-    result_file = f"lang-specific.mbert.webis.csv"
-    # print(f"- reading from: {result_file}")
+    print(f"- reading from: {result_file}")
     df = pd.read_csv(join(result_dir, result_file))
     langs = df.langs.unique()
     res = []
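For context, evalaute() reads a predictions CSV and scores it per language. A rough sketch of that flow, assuming the CSV has the langs/labels/preds columns written by save_preds elsewhere in this commit (the metric choice here is illustrative):

import pandas as pd
from sklearn.metrics import f1_score

df = pd.read_csv("results/preds.gfun.csv")
for lang in df.langs.unique():
    sub = df[df.langs == lang]
    print(lang, f1_score(sub.labels, sub.preds, average="macro"))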
@@ -14,7 +14,7 @@ class CsvLogger:
         # os.makedirs(self.outfile.replace(".csv", ".lang.csv"), exist_ok=True)
         # return
 
-    def log_lang_results(self, results: dict, config="gfun-lello"):
+    def log_lang_results(self, results: dict, config="gfun-default", notes=None):
         df = pd.DataFrame.from_dict(results, orient="columns")
         df["config"] = config["gFun"]["simple_id"]
         df["aggfunc"] = config["gFun"]["aggfunc"]
@@ -22,6 +22,7 @@ class CsvLogger:
         df["id"] = config["gFun"]["id"]
         df["optimc"] = config["gFun"]["optimc"]
         df["timing"] = config["gFun"]["timing"]
+        df["notes"] = notes
         with open(self.outfile, 'a') as f:
             df.to_csv(f, mode='a', header=f.tell()==0)
 
@@ -16,7 +16,17 @@ from dataManager.multilingualDataset import MultilingualDataset
 
 
 class SimpleGfunDataset:
-    def __init__(self, datadir="~/datasets/rai/csv/", textual=True, visual=False, multilabel=False, set_tr_langs=None, set_te_langs=None):
+    def __init__(
+        self,
+        dataset_name=None,
+        datadir="~/datasets/rai/csv/",
+        textual=True,
+        visual=False,
+        multilabel=False,
+        set_tr_langs=None,
+        set_te_langs=None
+    ):
+        self.name = dataset_name
         self.datadir = os.path.expanduser(datadir)
         self.textual = textual
         self.visual = visual
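With the expanded signature, the dataset is built roughly as follows; this sketch mirrors the __main__ block later in this file, and the paths are examples only:

dataset = SimpleGfunDataset(
    dataset_name="rai",
    datadir="~/datasets/rai/csv/",
    textual=True,
    visual=False,
    multilabel=False,
)
lXtr, lYtr = dataset.training(mask_number=False)
lXte, lYte = dataset.test(mask_number=False)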
@@ -41,11 +51,15 @@ class SimpleGfunDataset:
         print(f"tr: {tr} - va: {va} - te: {te}")
 
     def load_csv(self, set_tr_langs, set_te_langs):
-        # _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
-        _data_tr = pd.read_csv(os.path.join(self.datadir, "train.balanced.csv")).sample(100, random_state=42)
-        train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang)  # TODO stratify on lang or label?
-        # test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
-        test = pd.read_csv(os.path.join(self.datadir, "test.balanced.csv")).sample(100, random_state=42)
+        _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
+        try:
+            stratified = "class"
+            train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.label)
+        except:
+            stratified = "lang"
+            train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang)
+        print(f"- dataset stratified by {stratified}")
+        test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
         self._set_langs (train, test, set_tr_langs, set_te_langs)
         self._set_labels(_data_tr)
         self.full_train = _data_tr
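The try/except above falls back from label- to language-stratification because sklearn's train_test_split typically raises a ValueError when some label has fewer than two rows. A self-contained sketch of the same fallback in isolation (toy data, column names follow the CSVs used here):

import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.DataFrame({"text": ["a", "b", "c", "d"],
                   "label": [0, 1, 0, 1],
                   "lang": ["en", "en", "it", "it"]})
try:
    # preferred: stratify on the class label
    train, val = train_test_split(df, test_size=0.5, random_state=42, stratify=df.label)
except ValueError:
    # fallback when some class is too small: stratify on language instead
    train, val = train_test_split(df, test_size=0.5, random_state=42, stratify=df.lang)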
@@ -56,8 +70,6 @@ class SimpleGfunDataset:
         return
 
     def _set_labels(self, data):
-        # self.labels = [i for i in range(28)] # todo hard-coded for rai
-        # self.labels = [i for i in range(3)] # TODO hard coded for sentimnet
         self.labels = sorted(list(data.label.unique()))
 
     def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
@@ -73,14 +85,14 @@ class SimpleGfunDataset:
 
         return self.tr_langs, self.te_langs, self.all_langs
 
 
     def _set_datalang(self, data: pd.DataFrame):
         return {lang: data[data.lang == lang] for lang in self.all_langs}
 
     def training(self, merge_validation=False, mask_number=False, target_as_csr=False):
-        # TODO some additional pre-processing on the textual data?
         apply_mask = lambda x: _mask_numbers(x) if _mask_numbers else x
         lXtr = {
-            lang: {"text": apply_mask(self.train_data[lang].text.tolist())} # TODO inserting dict for textual data - we still have to manage visual
+            lang: {"text": apply_mask(self.train_data[lang].text.tolist())}
             for lang in self.tr_langs
         }
         if merge_validation:
@@ -103,7 +115,6 @@ class SimpleGfunDataset:
         return lXtr, lYtr
 
     def test(self, mask_number=False, target_as_csr=False):
-        # TODO some additional pre-processing on the textual data?
         apply_mask = lambda x: _mask_numbers(x) if _mask_numbers else x
         lXte = {
             lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
@@ -162,62 +173,12 @@ class gFunDataset:
         return mlb
 
     def _load_dataset(self):
-        if "glami" in self.dataset_dir.lower():
-            print(f"- Loading GLAMI dataset from {self.dataset_dir}")
-            self.dataset_name = "glami"
-            self.dataset, self.labels, self.data_langs = self._load_glami(
-                self.dataset_dir, self.nrows
-            )
-            self.mlb = self.get_label_binarizer(self.labels)
-
-        elif "rcv" in self.dataset_dir.lower():
-            print(f"- Loading RCV1-2 dataset from {self.dataset_dir}")
-            self.dataset_name = "rcv1-2"
-            self.dataset, self.labels, self.data_langs = self._load_multilingual(
-                self.dataset_name, self.dataset_dir, self.nrows
-            )
-            self.mlb = self.get_label_binarizer(self.labels)
-
-        elif "jrc" in self.dataset_dir.lower():
-            print(f"- Loading JRC dataset from {self.dataset_dir}")
-            self.dataset_name = "jrc"
-            self.dataset, self.labels, self.data_langs = self._load_multilingual(
-                self.dataset_name, self.dataset_dir, self.nrows
-            )
-            self.mlb = self.get_label_binarizer(self.labels)
-
-        # WEBIS-CLS (processed)
-        elif (
-            "cls" in self.dataset_dir.lower()
-            and "unprocessed" not in self.dataset_dir.lower()
-        ):
-            print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}")
-            self.dataset_name = "cls"
-            self.dataset, self.labels, self.data_langs = self._load_multilingual(
-                self.dataset_name, self.dataset_dir, self.nrows
-            )
-            self.mlb = self.get_label_binarizer(self.labels)
-
-        # WEBIS-CLS (unprocessed)
-        elif (
-            "cls" in self.dataset_dir.lower()
-            and "unprocessed" in self.dataset_dir.lower()
-        ):
-            print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}")
-            self.dataset_name = "cls"
-            self.dataset, self.labels, self.data_langs = self._load_multilingual(
-                self.dataset_name, self.dataset_dir, self.nrows
-            )
-            self.mlb = self.get_label_binarizer(self.labels)
-
-        elif "rai" in self.dataset_dir.lower():
-            print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}")
+        print(f"- Loading dataset from {self.dataset_dir}")
         self.dataset_name = "rai"
         self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
-                                                                             dataset_dir="~/datasets/rai/csv/train-split-rai.csv",
+                                                                             dataset_dir=self.dataset_dir,
                                                                              nrows=self.nrows)
         self.mlb = self.get_label_binarizer(self.labels)
 
         self.show_dimension()
         return
 
@@ -232,15 +193,12 @@ class gFunDataset:
         else:
             print(f"-- Labels: {len(self.labels)}")
 
-    def _load_multilingual(self, dataset_name, dataset_dir, nrows):
+    def _load_multilingual(self, dataset_dir, nrows, dataset_name="rai"):
         if "csv" in dataset_dir:
-            old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv(
-                # path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
-                #path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv")
-                path_tr="~/datasets/rai/csv/train-split-rai.csv",
-                path_te="~/datasets/rai/csv/test-split-rai.csv")
-        else:
-            old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
+            old_dataset = MultilingualDataset(dataset_name="rai").from_csv(
+                path_tr=os.path.expanduser(os.path.join(dataset_dir, "train.small.csv")),
+                path_te=os.path.expanduser(os.path.join(dataset_dir, "test.small.csv"))
+            )
         if nrows is not None:
             if dataset_name == "cls":
                 old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
@@ -366,26 +324,3 @@ if __name__ == "__main__":
     lXtr, lYtr = data_rai.training(mask_number=False)
     lXte, lYte = data_rai.test(mask_number=False)
     exit()
-    # import os
-
-    # GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
-    # RCV_DATAPATH = os.path.expanduser(
-    #     "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
-    # )
-    # JRC_DATAPATH = os.path.expanduser(
-    #     "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
-    # )
-
-    # print("Hello gFunDataset")
-    # dataset = gFunDataset(
-    #     dataset_dir=JRC_DATAPATH,
-    #     data_langs=None,
-    #     is_textual=True,
-    #     is_visual=True,
-    #     is_multilabel=False,
-    #     labels=None,
-    #     nrows=13,
-    # )
-    # lXtr, lYtr = dataset.training()
-    # lXte, lYte = dataset.test()
-    # exit(0)
@@ -1,7 +1,5 @@
 from os.path import expanduser, join
-from dataManager.gFunDataset import gFunDataset, SimpleGfunDataset
-from dataManager.multiNewsDataset import MultiNewsDataset
-from dataManager.amazonDataset import AmazonDataset
+from dataManager.gFunDataset import SimpleGfunDataset
 
 
 def load_from_pickle(path, dataset_name, nrows):
@@ -16,119 +14,14 @@ def load_from_pickle(path, dataset_name, nrows):
     return loaded
 
 
-def get_dataset(dataset_name, args):
-    assert dataset_name in [
-        "multinews",
-        "amazon",
-        "rcv1-2",
-        "glami",
-        "cls",
-        "webis",
-        "rai",
-    ], "dataset not supported"
-
-    RCV_DATAPATH = expanduser(
-        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
-    )
-    JRC_DATAPATH = expanduser(
-        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
-    )
-    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
-
-    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
-
-    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
-
-    WEBIS_CLS = expanduser(
-        # "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
-        "~/datasets/cls-acl10-unprocessed/csv"
-    )
-
-    RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl")
-
-    if dataset_name == "multinews":
-        # TODO: convert to gFunDataset
-        raise NotImplementedError
-        dataset = MultiNewsDataset(
-            expanduser(MULTINEWS_DATAPATH),
-            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
-        )
-    elif dataset_name == "amazon":
-        # TODO: convert to gFunDataset
-        raise NotImplementedError
-        dataset = AmazonDataset(
-            domains=args.domains,
-            nrows=args.nrows,
-            min_count=args.min_count,
-            max_labels=args.max_labels,
-        )
-    elif dataset_name == "jrc":
-        dataset = gFunDataset(
-            dataset_dir=JRC_DATAPATH,
-            is_textual=True,
-            is_visual=False,
-            is_multilabel=True,
-            nrows=args.nrows,
-        )
-    elif dataset_name == "rcv1-2":
-        dataset = gFunDataset(
-            dataset_dir=RCV_DATAPATH,
-            is_textual=True,
-            is_visual=False,
-            is_multilabel=True,
-            nrows=args.nrows,
-        )
-    elif dataset_name == "glami":
-        if args.save_dataset is False:
-            dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
-        else:
-            dataset = gFunDataset(
-                dataset_dir=GLAMI_DATAPATH,
-                is_textual=True,
-                is_visual=True,
-                is_multilabel=False,
-                nrows=args.nrows,
-            )
-
-            dataset.save_as_pickle(GLAMI_DATAPATH)
-    elif dataset_name == "cls":
-        dataset = gFunDataset(
-            dataset_dir=CLS_DATAPATH,
-            is_textual=True,
-            is_visual=False,
-            is_multilabel=False,
-            nrows=args.nrows,
-        )
-    elif dataset_name == "webis":
-        dataset = SimpleGfunDataset(
-            datadir=WEBIS_CLS,
+def get_dataset(datasetp_path, args):
+    dataset = SimpleGfunDataset(
+        dataset_name="rai",
+        datadir=datasetp_path,
         textual=True,
         visual=False,
         multilabel=False,
         set_tr_langs=args.tr_langs,
         set_te_langs=args.te_langs
     )
-        # dataset = gFunDataset(
-        #     dataset_dir=WEBIS_CLS,
-        #     is_textual=True,
-        #     is_visual=False,
-        #     is_multilabel=False,
-        #     nrows=args.nrows,
-        # )
-    elif dataset_name == "rai":
-        dataset = SimpleGfunDataset(
-            datadir="~/datasets/rai/csv",
-            textual=True,
-            visual=False,
-            multilabel=False
-        )
-        # dataset = gFunDataset(
-        #     dataset_dir=RAI_DATAPATH,
-        #     is_textual=True,
-        #     is_visual=False,
-        #     is_multilabel=False,
-        #     nrows=args.nrows
-        # )
-    else:
-        raise NotImplementedError
     return dataset
@@ -8,7 +8,6 @@ from gfun.vgfs.learners.svms import MetaClassifier, get_learner
 from gfun.vgfs.multilingualGen import MultilingualGen
 from gfun.vgfs.textualTransformerGen import TextualTransformerGen
 from gfun.vgfs.vanillaFun import VanillaFunGen
-from gfun.vgfs.visualTransformerGen import VisualTransformerGen
 from gfun.vgfs.wceGen import WceGen
 
 
@@ -19,7 +18,6 @@ class GeneralizedFunnelling:
         wce,
         multilingual,
         textual_transformer,
-        visual_transformer,
         langs,
         num_labels,
         classification_type,
@@ -29,12 +27,9 @@ class GeneralizedFunnelling:
         eval_batch_size,
         max_length,
         textual_lr,
-        visual_lr,
         epochs,
         patience,
         evaluate_step,
-        textual_transformer_name,
-        visual_transformer_name,
         optimc,
         device,
         load_trained,
@@ -42,13 +37,14 @@ class GeneralizedFunnelling:
         probabilistic,
         aggfunc,
         load_meta,
+        trained_text_trf=None,
+        textual_transformer_name=None,
     ):
         # Setting VFGs -----------
         self.posteriors_vgf = posterior
         self.wce_vgf = wce
         self.multilingual_vgf = multilingual
         self.textual_trf_vgf = textual_transformer
-        self.visual_trf_vgf = visual_transformer
         self.probabilistic = probabilistic
         self.num_labels = num_labels
         self.clf_type = classification_type
@@ -58,6 +54,7 @@ class GeneralizedFunnelling:
         self.cached = True
         # Textual Transformer VGF params ----------
         self.textual_trf_name = textual_transformer_name
+        self.trained_text_trf = trained_text_trf
         self.epochs = epochs
         self.textual_trf_lr = textual_lr
         self.textual_scheduler = "ReduceLROnPlateau"
@@ -68,10 +65,6 @@ class GeneralizedFunnelling:
         self.patience = patience
         self.evaluate_step = evaluate_step
         self.device = device
-        # Visual Transformer VGF params ----------
-        self.visual_trf_name = visual_transformer_name
-        self.visual_trf_lr = visual_lr
-        self.visual_scheduler = "ReduceLROnPlateau"
         # Metaclassifier params ------------
         self.optimc = optimc
         # -------------------
@@ -131,7 +124,6 @@ class GeneralizedFunnelling:
             self.multilingual_vgf,
             self.wce_vgf,
             self.textual_trf_vgf,
-            self.visual_trf_vgf,
             self.aggfunc,
         )
         return self
@@ -174,26 +166,10 @@ class GeneralizedFunnelling:
                 patience=self.patience,
                 device=self.device,
                 classification_type=self.clf_type,
+                saved_model=self.trained_text_trf,
             )
             self.first_tier_learners.append(transformer_vgf)
 
-        if self.visual_trf_vgf:
-            visual_trasformer_vgf = VisualTransformerGen(
-                dataset_name=self.dataset_name,
-                model_name="vit",
-                lr=self.visual_trf_lr,
-                scheduler=self.visual_scheduler,
-                epochs=self.epochs,
-                batch_size=self.batch_size_trf,
-                batch_size_eval=self.eval_batch_size_trf,
-                probabilistic=self.probabilistic,
-                evaluate_step=self.evaluate_step,
-                patience=self.patience,
-                device=self.device,
-                classification_type=self.clf_type,
-            )
-            self.first_tier_learners.append(visual_trasformer_vgf)
-
         if "attn" in self.aggfunc:
             attn_stacking = self.aggfunc.split("_")[1]
             self.attn_aggregator = AttentionAggregator(
@@ -219,7 +195,6 @@ class GeneralizedFunnelling:
             self.multilingual_vgf,
             self.wce_vgf,
             self.textual_trf_vgf,
-            self.visual_trf_vgf,
             self.aggfunc,
         )
         print(f"- model id: {self._model_id}")
@@ -309,16 +284,29 @@ class GeneralizedFunnelling:
         return aggregated
 
     def _aggregate_mean(self, first_tier_projections):
-        aggregated = {
-            lang: np.zeros(data.shape)
-            for lang, data in first_tier_projections[0].items()
-        }
+        # aggregated = {
+        #     lang: np.zeros(data.shape)
+        #     for lang, data in first_tier_projections[0].items()
+        # }
+        aggregated = {}
         for lang_projections in first_tier_projections:
             for lang, projection in lang_projections.items():
+                if lang not in aggregated:
+                    aggregated[lang] = np.zeros(projection.shape)
                 aggregated[lang] += projection
 
-        for lang, projection in aggregated.items():
-            aggregated[lang] /= len(first_tier_projections)
+        def get_denom(lang, projs):
+            den = 0
+            for proj in projs:
+                if lang in proj:
+                    den += 1
+            return den
+
+        for lang, _ in aggregated.items():
+            # aggregated[lang] /= len(first_tier_projections)
+            aggregated[lang] /= get_denom(lang, first_tier_projections)
 
 
         return aggregated
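The rewritten _aggregate_mean averages each language only over the view-generating functions that actually produced a projection for it, instead of dividing by the total number of VGFs; this matters when a VGF covers just a subset of the languages. A self-contained sketch of the same logic on toy arrays (illustrative only):

import numpy as np

projections = [
    {"en": np.ones(3), "it": np.ones(3)},  # a VGF covering en and it
    {"en": 3 * np.ones(3)},                # a VGF covering en only
]
aggregated = {}
for proj in projections:
    for lang, vec in proj.items():
        aggregated.setdefault(lang, np.zeros(vec.shape))
        aggregated[lang] += vec
for lang in aggregated:
    # divide by the number of VGFs that produced this language, not by len(projections)
    aggregated[lang] /= sum(lang in proj for proj in projections)
# en -> [2. 2. 2.] (averaged over two VGFs), it -> [1. 1. 1.] (only one VGF)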
@@ -416,18 +404,6 @@ class GeneralizedFunnelling:
             ) as vgf:
                 first_tier_learners.append(pickle.load(vgf))
             print(f"- loaded trained Textual Transformer VGF")
-        if self.visual_trf_vgf:
-            with open(
-                os.path.join(
-                    "models",
-                    "vgfs",
-                    "visual_transformer",
-                    f"visualTransformerGen_{model_id}.pkl",
-                ),
-                "rb",
-                print(f"- loaded trained Visual Transformer VGF"),
-            ) as vgf:
-                first_tier_learners.append(pickle.load(vgf))
-
         if load_meta:
             with open(
@@ -482,7 +458,6 @@ def get_unique_id(
     multilingual,
     wce,
     textual_transformer,
-    visual_transformer,
     aggfunc,
 ):
     from datetime import datetime
@@ -493,6 +468,5 @@ def get_unique_id(
     model_id += "m" if multilingual else ""
     model_id += "w" if wce else ""
     model_id += "t" if textual_transformer else ""
-    model_id += "v" if visual_transformer else ""
     model_id += f"_{aggfunc}"
     return f"{model_id}_{now}"
@@ -173,9 +173,18 @@ class NaivePolylingualClassifier:
         :return: a dictionary of probabilities that each document belongs to each class
         """
         assert self.model is not None, "predict called before fit"
-        assert set(lX.keys()).issubset(
-            set(self.model.keys())
-        ), "unknown languages requested in decision function"
+        if not set(lX.keys()).issubset(set(self.model.keys())):
+            langs = set(lX.keys()).intersection(set(self.model.keys()))
+            scores = Parallel(n_jobs=self.n_jobs, max_nbytes=None)(
+                delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs)
+            # res = {lang: None for lang in lX.keys()}
+            # for i, lang in enumerate(langs):
+            #     res[lang] = scores[i]
+            # return res
+            return {lang: scores[i] for i, lang in enumerate(langs)}
+        # assert set(lX.keys()).issubset(
+        #     set(self.model.keys())
+        # ), "unknown languages requested in decision function"
         langs = list(lX.keys())
         scores = Parallel(n_jobs=self.n_jobs, max_nbytes=None)(
             delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs
@@ -73,6 +73,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         patience=5,
         classification_type="multilabel",
         scheduler="ReduceLROnPlateau",
+        saved_model = None,
     ):
         super().__init__(
             self._validate_model_name(model_name),
@@ -91,6 +92,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             n_jobs=n_jobs,
             verbose=verbose,
         )
+        self.saved_model = saved_model
         self.clf_type = classification_type
         self.fitted = False
         print(
@@ -109,26 +111,25 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         else:
             raise NotImplementedError
 
-    def load_pretrained_model(self, model_name, num_labels):
+    def load_pretrained_model(self, model_name, num_labels, saved_model=None):
         if model_name == "google/mt5-small":
             return MT5ForSequenceClassification(
                 model_name, num_labels=num_labels, output_hidden_states=True
             )
         else:
-            # model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-traiend mbert
-            # model_name = "hf_models/mbert-rai-fewshot-second/checkpoint-19000" # TODO hardcoded to pre-traiend mbert
-            # model_name = "hf_models/mbert-sentiment/checkpoint-1150" # TODO hardcoded to pre-traiend mbert
-            model_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
+            if saved_model:
+                model_name = saved_model
+            else:
+                model_name = "google/bert-base-multilingual-cased"
             return AutoModelForSequenceClassification.from_pretrained(
                 model_name, num_labels=num_labels, output_hidden_states=True
             )
 
     def load_tokenizer(self, model_name):
-        # model_name = "mbert-rai-multi-2000/checkpoint-1500" # TODO hardcoded to pre-traiend mbert
         return AutoTokenizer.from_pretrained(model_name)
 
-    def init_model(self, model_name, num_labels):
-        return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
+    def init_model(self, model_name, num_labels, saved_model):
+        return self.load_pretrained_model(model_name, num_labels, saved_model), self.load_tokenizer(
             model_name
         )
 
@@ -148,64 +149,11 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         _l = list(lX.keys())[0]
         self.num_labels = lY[_l].shape[-1]
         self.model, self.tokenizer = self.init_model(
-            self.model_name, num_labels=self.num_labels
+            self.model_name, num_labels=self.num_labels, saved_model=self.saved_model,
         )
 
         self.model.to("cuda")
 
-        # tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
-        #     lX, lY, split=0.2, seed=42, modality="text"
-        # )
-        #
-        # tra_dataloader = self.build_dataloader(
-        #     tr_lX,
-        #     tr_lY,
-        #     processor_fn=self._tokenize,
-        #     torchDataset=MultilingualDatasetTorch,
-        #     batch_size=self.batch_size,
-        #     split="train",
-        #     shuffle=True,
-        # )
-        #
-        # val_dataloader = self.build_dataloader(
-        #     val_lX,
-        #     val_lY,
-        #     processor_fn=self._tokenize,
-        #     torchDataset=MultilingualDatasetTorch,
-        #     batch_size=self.batch_size_eval,
-        #     split="val",
-        #     shuffle=False,
-        # )
-        #
-        # experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-        #
-        # trainer = Trainer(
-        #     model=self.model,
-        #     optimizer_name="adamW",
-        #     lr=self.lr,
-        #     device=self.device,
-        #     loss_fn=torch.nn.CrossEntropyLoss(),
-        #     print_steps=self.print_steps,
-        #     evaluate_step=self.evaluate_step,
-        #     patience=self.patience,
-        #     experiment_name=experiment_name,
-        #     checkpoint_path=os.path.join(
-        #         "models",
-        #         "vgfs",
-        #         "trained_transformer",
-        #         self._format_model_name(self.model_name),
-        #     ),
-        #     vgf_name="textual_trf",
-        #     classification_type=self.clf_type,
-        #     n_jobs=self.n_jobs,
-        #     scheduler_name=self.scheduler,
-        # )
-        # trainer.train(
-        #     train_dataloader=tra_dataloader,
-        #     eval_dataloader=val_dataloader,
-        #     epochs=self.epochs,
-        # )
-
         if self.probabilistic:
             transformed = self.transform(lX)
             self.feature2posterior_projector.fit(transformed, lY)
@@ -1,4 +1,4 @@
-from os.path import expanduser
+from os.path import expanduser, join
 
 import torch
 from transformers import (
@@ -19,13 +19,8 @@ import pandas as pd
 
 transformers.logging.set_verbosity_error()
 
-IWSLT_D_COLUMNS = ["text", "category", "rating", "summary", "title"]
-RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"] # "lang"
-WEBIS_D_COLUMNS = ['Unnamed: 0', 'asin', 'category', 'original_rating', 'label', 'title', 'text', 'summary'] # "lang"
+RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"]
 MAX_LEN = 128
-# DATASET_NAME = "rai"
-# DATASET_NAME = "rai-multilingual-2000"
-# DATASET_NAME = "webis-cls"
 
 
 def init_callbacks(patience=-1, nosave=False):
@@ -35,13 +30,17 @@ def init_callbacks(patience=-1, nosave=False):
     return callbacks
 
 
-def init_model(model_name, nlabels):
+def init_model(model_name, nlabels, saved_model=None):
     if model_name == "mbert":
-        # hf_name = "bert-base-multilingual-cased"
-        hf_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
-        # hf_name = "hf_models/mbert-rai-fewshot-second/checkpoint-9000"
+        if saved_model is None:
+            hf_name = "bert-base-multilingual-cased"
+        else:
+            hf_name = saved_model
     elif model_name == "xlm-roberta":
+        if saved_model is None:
             hf_name = "xlm-roberta-base"
+        else:
+            hf_name = saved_model
     else:
         raise NotImplementedError
     tokenizer = AutoTokenizer.from_pretrained(hf_name)
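With the new saved_model argument the helper either starts from the stock checkpoint or resumes from a locally fine-tuned one. Hypothetical calls (the label count is an example; the local path is the default of --savedmodel further below):

tokenizer, model = init_model("mbert", nlabels=3)
tokenizer, model = init_model("mbert", nlabels=3, saved_model="hf_models/mbert-rai-fewshot-second/checkpoint-9000")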
@@ -50,43 +49,41 @@ def init_model(model_name, nlabels):
 
 
 def main(args):
-    tokenizer, model = init_model(args.model, args.nlabels)
+    saved_model = args.savedmodel
+    trainlang = args.trainlangs
+    datapath = args.datapath
+
+    tokenizer, model = init_model(args.model, args.nlabels, saved_model=saved_model)
 
     data = load_dataset(
         "csv",
         data_files = {
-            "train": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/train.balanced.csv"),
-            "test": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/test.balanced.csv")
-            # "train": expanduser(f"~/datasets/rai/csv/train-{DATASET_NAME}.csv"),
-            # "test": expanduser(f"~/datasets/rai/csv/test-{DATASET_NAME}.csv")
-            # "train": expanduser(f"~/datasets/rai/csv/train.small.csv"),
-            # "test": expanduser(f"~/datasets/rai/csv/test.small.csv")
+            "train": expanduser(join(datapath, "train.csv")),
+            "test": expanduser(join(datapath, "test.small.csv"))
         }
     )
 
+    def filter_dataset(dataset, lang):
+        indices = [i for i, l in enumerate(dataset["lang"]) if l == lang]
+        dataset = dataset.select(indices)
+        return dataset
+
+    if trainlang is not None:
+        data["train"] = filter_dataset(data["train"], lang=trainlang)
+
     def process_sample_rai(sample):
         inputs = [f"{title}. {text}" for title, text in zip(sample["title"], sample["text"])]
         labels = sample["label"]
-        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) # TODO pre-process text cause there's a lot of noise in there...
+        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)
         model_inputs["labels"] = labels
         return model_inputs
 
-    def process_sample_webis(sample):
-        inputs = sample["text"]
-        labels = sample["label"]
-        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) # TODO pre-process text cause there's a lot of noise in there...
-        model_inputs["labels"] = labels
-        return model_inputs
-
 
     data = data.map(
-        # process_sample_rai,
-        process_sample_webis,
+        process_sample_rai,
         batched=True,
         num_proc=4,
         load_from_cache_file=True,
-        # remove_columns=RAI_D_COLUMNS,
-        remove_columns=WEBIS_D_COLUMNS,
+        remove_columns=RAI_D_COLUMNS,
     )
     train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
     data.set_format("torch")
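filter_dataset keeps only the rows of one language by building an index list and calling select. An arguably more idiomatic alternative with the datasets API would be a row-level filter on the lang column; this one-liner is a hypothetical equivalent, not part of the commit:

data["train"] = data["train"].filter(lambda row: row["lang"] == trainlang)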
@@ -107,8 +104,8 @@ def main(args):
     recall_metric = evaluate.load("recall")
 
     training_args = TrainingArguments(
-        # output_dir=f"hf_models/{args.model}-rai",
-        output_dir=f"hf_models/{args.model}-sentiment-balanced",
+        output_dir=f"hf_models/{args.model}-fewshot-full" if trainlang is None else f"hf_models/{args.model}-zeroshot-full",
+        run_name="model-zeroshot" if trainlang is not None else "model-fewshot",
         do_train=True,
         evaluation_strategy="steps",
         per_device_train_batch_size=args.batch,
@@ -130,8 +127,6 @@ def main(args):
         save_strategy="no" if args.nosave else "steps",
         save_total_limit=2,
         eval_steps=args.stepeval,
-        # run_name=f"{args.model}-rai-stratified",
-        run_name=f"{args.model}-sentiment",
         disable_tqdm=False,
         log_level="warning",
         report_to=["wandb"] if args.wandb else "none",
@@ -142,7 +137,6 @@ def main(args):
 
     def compute_metrics(eval_preds):
         preds = eval_preds.predictions.argmax(-1)
-        # targets = eval_preds.label_ids.argmax(-1)
         targets = eval_preds.label_ids
         setting = "macro"
         f1_score_macro = f1_metric.compute(
@@ -170,7 +164,9 @@ def main(args):
 
     if args.wandb:
         import wandb
-        wandb.init(entity="andreapdr", project=f"gfun-rai-hf", name="mbert-rai", config=vars(args))
+        wandb.init(entity="andreapdr", project=f"gfun",
+                   name="model-zeroshot-full" if trainlang is not None else "model-fewshot-full",
+                   config=vars(args))
 
     trainer = Trainer(
         model=model,
@@ -188,17 +184,21 @@ def main(args):
     trainer.train()
 
     print("- Testing:")
+    test_results = trainer.evaluate(eval_dataset=data["test"], metric_key_prefix="test")
     test_results = trainer.predict(test_dataset=data["test"], metric_key_prefix="test")
     pprint(test_results.metrics)
-    save_preds(data["test"], test_results.predictions)
+    save_preds(data["test"], test_results.predictions, trainlang)
     exit()
 
-def save_preds(dataset, predictions):
+def save_preds(dataset, predictions, trainlang=None):
     df = pd.DataFrame()
     df["langs"] = dataset["lang"]
     df["labels"] = dataset["labels"]
     df["preds"] = predictions.argmax(axis=1)
-    df.to_csv("results/lang-specific.mbert.webis.csv", index=False)
+    if trainlang is not None:
+        df.to_csv(f"results/zeroshot.{trainlang}.model.csv", index=False)
+    else:
+        df.to_csv("results/fewshot.model.csv", index=False)
     return
 
 
@@ -212,7 +212,7 @@ if __name__ == "__main__":
     parser.add_argument("--scheduler", type=str, metavar="", default="cosine", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
     parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
     parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
-    parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
+    parser.add_argument("--epochs", type=int, metavar="", default=10, help="Set epochs")
     parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
     parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps")
     parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
@@ -220,6 +220,8 @@ if __name__ == "__main__":
     parser.add_argument("--wandb", action="store_true", help="Log to wandb")
     parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
     parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
-    # parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
+    parser.add_argument("--trainlang", default=None, type=str, help="set training language for zero-shot experiments" )
+    parser.add_argument("--datapath", type=str, default="data", help="path to the csv dataset. Dir should contain both a train.csv and a test.csv file")
+    parser.add_argument("--savedmodel", type=str, default="hf_models/mbert-rai-fewshot-second/checkpoint-9000")
     args = parser.parse_args()
     main(args)
main.py
@@ -10,8 +10,6 @@ import pandas as pd
 
 """
 TODO:
-    - General:
-        [!] zero-shot setup
     - Docs:
         - add documentations sphinx
 """
@@ -27,13 +25,11 @@ def get_config_name(args):
         config_name += "M+"
     if args.textual_transformer:
         config_name += f"TT_{args.textual_trf_name}+"
-    if args.visual_transformer:
-        config_name += f"VT_{args.visual_trf_name}+"
     return config_name.rstrip("+")
 
 
 def main(args):
-    dataset = get_dataset(args.dataset, args)
+    dataset = get_dataset(args.datadir, args)
     lX, lY = dataset.training(merge_validation=True)
     lX_te, lY_te = dataset.test()
@@ -47,13 +43,12 @@ def main(args):
             args.multilingual,
             args.multilingual,
             args.textual_transformer,
-            args.visual_transformer,
         ]
     ), "At least one of VGF must be True"
 
     gfun = GeneralizedFunnelling(
         # dataset params ----------------------
-        dataset_name=args.dataset,
+        dataset_name=dataset,
         langs=dataset.langs(),
         num_labels=dataset.num_labels(),
         classification_type=args.clf_type,
@@ -67,24 +62,15 @@ def main(args):
         # Transformer VGF params --------------
         textual_transformer=args.textual_transformer,
         textual_transformer_name=args.textual_trf_name,
+        trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
         batch_size=args.batch_size,
         eval_batch_size=args.eval_batch_size,
         epochs=args.epochs,
         textual_lr=args.textual_lr,
-        visual_lr=args.visual_lr,
         max_length=args.max_length,
         patience=args.patience,
         evaluate_step=args.evaluate_step,
         device=args.device,
-        # Visual Transformer VGF params --------------
-        visual_transformer=args.visual_transformer,
-        visual_transformer_name=args.visual_trf_name,
-        # batch_size=args.batch_size,
-        # epochs=args.epochs,
-        # lr=args.lr,
-        # patience=args.patience,
-        # evaluate_step=args.evaluate_step,
-        # device="cuda",
         # General params ---------------------
         probabilistic=args.features,
         aggfunc=args.aggfunc,
@@ -145,7 +131,7 @@ def main(args):
         log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
 
     config["gFun"]["timing"] = f"{timeval - tinit:.2f}"
-    csvlogger = CsvLogger(outfile="results/random.log.csv").log_lang_results(lang_metrics_gfun, config)
+    csvlogger = CsvLogger(outfile="results/gfun.log.csv").log_lang_results(lang_metrics_gfun, config, notes="")
     save_preds(gfun_preds, lY_te, config=config["gFun"]["simple_id"], dataset=config["gFun"]["dataset"])
 
 
|
||||||
df["langs"] = _langs
|
df["langs"] = _langs
|
||||||
df["labels"] = _targets
|
df["labels"] = _targets
|
||||||
df["preds"] = _preds
|
df["preds"] = _preds
|
||||||
df.to_csv(f"results/lang-specific.gfun.{config}.{dataset}.csv", index=False)
|
df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.correct.csv", index=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
@ -174,7 +160,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--tr_langs", nargs="+", default=None)
|
parser.add_argument("--tr_langs", nargs="+", default=None)
|
||||||
parser.add_argument("--te_langs", nargs="+", default=None)
|
parser.add_argument("--te_langs", nargs="+", default=None)
|
||||||
# Dataset parameters -------------------
|
# Dataset parameters -------------------
|
||||||
parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
|
parser.add_argument("-d", "--datadir", type=str, default=None, help="dir to dataset. It should contain both a train.csv and a test.csv file")
|
||||||
parser.add_argument("--domains", type=str, default="all")
|
parser.add_argument("--domains", type=str, default="all")
|
||||||
parser.add_argument("--nrows", type=int, default=None)
|
parser.add_argument("--nrows", type=int, default=None)
|
||||||
parser.add_argument("--min_count", type=int, default=10)
|
parser.add_argument("--min_count", type=int, default=10)
|
||||||
|
|
@ -186,7 +172,6 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("-m", "--multilingual", action="store_true")
|
parser.add_argument("-m", "--multilingual", action="store_true")
|
||||||
parser.add_argument("-w", "--wce", action="store_true")
|
parser.add_argument("-w", "--wce", action="store_true")
|
||||||
parser.add_argument("-t", "--textual_transformer", action="store_true")
|
parser.add_argument("-t", "--textual_transformer", action="store_true")
|
||||||
parser.add_argument("-v", "--visual_transformer", action="store_true")
|
|
||||||
parser.add_argument("--n_jobs", type=int, default=-1)
|
parser.add_argument("--n_jobs", type=int, default=-1)
|
||||||
parser.add_argument("--optimc", action="store_true")
|
parser.add_argument("--optimc", action="store_true")
|
||||||
parser.add_argument("--features", action="store_false")
|
parser.add_argument("--features", action="store_false")
|
||||||
|
|
@ -200,9 +185,6 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--max_length", type=int, default=128)
|
parser.add_argument("--max_length", type=int, default=128)
|
||||||
parser.add_argument("--patience", type=int, default=5)
|
parser.add_argument("--patience", type=int, default=5)
|
||||||
parser.add_argument("--evaluate_step", type=int, default=10)
|
parser.add_argument("--evaluate_step", type=int, default=10)
|
||||||
# Visual Transformer parameters --------------
|
|
||||||
parser.add_argument("--visual_trf_name", type=str, default="vit")
|
|
||||||
parser.add_argument("--visual_lr", type=float, default=1e-4)
|
|
||||||
# logging
|
# logging
|
||||||
parser.add_argument("--wandb", action="store_true")
|
parser.add_argument("--wandb", action="store_true")
|
||||||
|
|
||||||
|
|
|
||||||