bulk update: zero-shot + csvlogger + simpler dataset class + rai experiments

This commit is contained in:
Andrea Pedrotti 2023-08-03 19:31:03 +02:00
parent ae92199613
commit fbd740fabd
16 changed files with 674 additions and 102 deletions

60
compute_results.py Normal file
View File

@ -0,0 +1,60 @@
from argparse import ArgumentParser
from csvlogger import CsvLogger
import pandas as pd
from sklearn.metrics import mean_absolute_error
from os.path import join
"""
MAE (mean absolute error) is meaningful only in "ordinal" classification tasks, e.g., sentiment classification.
Otherwise the distance between the categories has no semantics!
- NB: we want to get the macro-averaged class specific MAE!
"""
def main():
    """Print the per-language, class-specific MAE table for every setting.

    Each setting in ``SETTINGS`` is passed to ``evalaute``; the rows it
    returns are gathered into a single DataFrame, which is printed.
    """
    # Other settings used so far: "p", "m", "w", "t", "mp", "mpw", "mpt", "mptw"
    SETTINGS = ["mbert"]

    # DataFrame.append no longer exists in pandas >= 2.0, so the table is
    # built in one shot from the flattened list of rows.
    rows = [row for setting in SETTINGS for row in evalaute(setting)]
    df = pd.DataFrame(
        rows,
        columns=["lang", "neg_mae", "neutral_mae", "pos_mae", "setting"],
    )
    print(df)
def evalaute(setting):
    """Compute the class-specific MAE of every language in a prediction file.

    Predictions are read from ``results/lang-specific.mbert.webis.csv``,
    which must provide the columns ``langs``, ``labels`` and ``preds``.
    Labels 0, 1 and 2 are treated as negative, neutral and positive.

    Args:
        setting: Tag copied verbatim into every returned row.

    Returns:
        One ``[lang, neg_mae, neutral_mae, pos_mae, setting]`` list per
        language, in order of first appearance in the file; every MAE is
        rounded to 3 decimals. The macro-averaged MAE of a language is the
        mean of its three class-specific values and is not part of the row.
    """
    result_dir = "results"
    # NOTE(review): `setting` does not select the input file. The gFun variant
    # of this name was f"lang-specific.gfun.{setting}.webis.csv" -- confirm
    # which file each setting is expected to read.
    result_file = "lang-specific.mbert.webis.csv"
    df = pd.read_csv(join(result_dir, result_file))

    # The class masks do not depend on the language: build them once.
    class_masks = [df.labels == class_id for class_id in (0, 1, 2)]

    res = []
    for lang in df.langs.unique():
        in_lang = df.langs == lang
        class_maes = []
        for in_class in class_masks:
            subset = df[in_lang & in_class]
            # Built-in round() works whether the metric returns a NumPy
            # scalar or a plain Python float.
            class_maes.append(
                round(mean_absolute_error(subset.labels, subset.preds), 3)
            )
        res.append([lang, *class_maes, setting])
    return res
# Run the evaluation only when the file is executed as a script.
if __name__ == "__main__":
    main()

31
csvlogger.py Normal file
View File

@ -0,0 +1,31 @@
import csv
import pandas as pd
import os
class CsvLogger:
    """Append evaluation results, tagged with the run configuration, to a CSV file."""

    def __init__(self, outfile="log.csv"):
        """Remember where to log; nothing is written until the first log call.

        Args:
            outfile: Path of the CSV file that results are appended to.
        """
        self.outfile = outfile

    def log_lang_results(self, results: dict, config="gfun-lello"):
        """Append ``results`` to the log, one CSV column per key of ``results``.

        Args:
            results: Mapping turned into a DataFrame with
                ``DataFrame.from_dict(results, orient="columns")``.
            config: Mapping whose ``"gFun"`` entry provides the keys
                ``simple_id``, ``aggfunc``, ``dataset``, ``id``, ``optimc``
                and ``timing``; each value is written as a constant column.
                The string default is only a placeholder and is rejected.

        Raises:
            TypeError: If ``config`` is a string (e.g. the default value).
            KeyError: If ``config`` lacks ``"gFun"`` or one of its keys.
        """
        if isinstance(config, str):
            raise TypeError(
                "config must be a mapping with a 'gFun' section, "
                f"got the string {config!r}"
            )
        gfun_config = config["gFun"]

        df = pd.DataFrame.from_dict(results, orient="columns")
        # (CSV column, key in the gFun configuration)
        tagged_columns = (
            ("config", "simple_id"),
            ("aggfunc", "aggfunc"),
            ("dataset", "dataset"),
            ("id", "id"),
            ("optimc", "optimc"),
            ("timing", "timing"),
        )
        for column, key in tagged_columns:
            df[column] = gfun_config[key]

        # Create the target directory (e.g. "results/") on first use.
        parent = os.path.dirname(self.outfile)
        if parent:
            os.makedirs(parent, exist_ok=True)

        # newline="" leaves line endings to the CSV writer; the header is
        # written only when the file is still empty.
        with open(self.outfile, "a", newline="") as f:
            df.to_csv(f, header=f.tell() == 0)

View File

@ -40,9 +40,11 @@ def load_unprocessed_cls(reduce_target_space=False):
if reduce_target_space:
rating = np.zeros(3, dtype=int)
original_rating = int(float(child.find("rating").text))
if original_rating < 3:
# if original_rating < 3:
if original_rating < 2:
new_rating = 1
elif original_rating > 3:
# elif original_rating > 3:
elif original_rating > 4:
new_rating = 3
else:
new_rating = 2
@ -73,7 +75,8 @@ def load_unprocessed_cls(reduce_target_space=False):
# "rating": child.find("rating").text
# if child.find("rating") is not None
# else None,
"rating": rating,
"original_rating": int(float(child.find("rating").text)),
"rating": rating.argmax(),
"title": child.find("title").text
if child.find("title") is not None
else None,
@ -171,8 +174,8 @@ if __name__ == "__main__":
tr_ids=None,
te_ids=None,
)
multilingualDataset.save(
os.path.expanduser(
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
)
)
# multilingualDataset.save(
# os.path.expanduser(
# "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
# )
# )

View File

@ -1,10 +1,134 @@
import sys
import os
sys.path.append(os.path.expanduser("~/devel/gfun_multimodal"))
from collections import defaultdict, Counter
import numpy as np
import re
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDataset import MultilingualDataset
class SimpleGfunDataset:
    """Multilingual text-classification dataset backed by two CSV files.

    ``datadir`` must contain ``train.balanced.csv`` and ``test.balanced.csv``.
    The code reads the columns ``lang``, ``text`` and ``label`` from them.
    Label values are used directly as column indices of the one-hot targets,
    so they are presumably integers in ``range(num_labels())`` -- verify
    against the CSV files.
    """

    def __init__(
        self,
        datadir="~/datasets/rai/csv/",
        textual=True,
        visual=False,
        multilabel=False,
        set_tr_langs=None,
        set_te_langs=None,
        sample_size=100,
    ):
        """Load the CSV files found in ``datadir`` and print the split sizes.

        Args:
            datadir: Directory holding the CSV files (``~`` is expanded).
            textual: Stored on the instance; not used by this class.
            visual: Stored on the instance; not used by this class.
            multilabel: Stored on the instance; not used by this class.
            set_tr_langs: Optional iterable restricting the training languages.
            set_te_langs: Optional iterable restricting the test languages.
            sample_size: Number of rows randomly drawn (seed 42) from each CSV
                file, capped at the file size. ``None`` keeps every row.
        """
        self.datadir = os.path.expanduser(datadir)
        self.textual = textual
        self.visual = visual
        self.multilabel = multilabel
        self.sample_size = sample_size
        self.load_csv(set_tr_langs, set_te_langs)
        self.print_stats()

    def print_stats(self):
        """Print the number of train/validation/test rows per language and in total."""
        print(f"Dataset statistics {'-' * 15}")
        tr = va = te = 0
        for lang in self.all_langs:
            # Test-only languages contribute no training/validation rows.
            is_train_lang = lang in self.tr_langs
            n_tr = len(self.train_data[lang]) if is_train_lang else 0
            n_va = len(self.val_data[lang]) if is_train_lang else 0
            n_te = len(self.test_data[lang])
            tr += n_tr
            va += n_va
            te += n_te
            print(f"{lang} - tr: {n_tr} - va: {n_va} - te: {n_te}")
        print(f"Total {'-' * 15}")
        print(f"tr: {tr} - va: {va} - te: {te}")

    def _read_split(self, filename):
        """Read ``filename`` from ``datadir``, sub-sampled to ``sample_size`` rows."""
        data = pd.read_csv(os.path.join(self.datadir, filename))
        if self.sample_size is None:
            return data
        # Cap the request so files shorter than sample_size do not raise.
        return data.sample(min(self.sample_size, len(data)), random_state=42)

    def load_csv(self, set_tr_langs, set_te_langs):
        """Read both CSV files and build the per-language train/val/test frames.

        The training file is split 80/20 into train and validation, stratified
        by language with a fixed seed.
        """
        _data_tr = self._read_split("train.balanced.csv")
        # TODO stratify on lang or label?
        train, val = train_test_split(
            _data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang
        )
        test = self._read_split("test.balanced.csv")

        self._set_langs(train, test, set_tr_langs, set_te_langs)
        self._set_labels(_data_tr)
        self.full_train = _data_tr
        # Keep the test frame itself (not the bound ``test`` method).
        self.full_test = test
        self.train_data = self._set_datalang(train)
        self.val_data = self._set_datalang(val)
        self.test_data = self._set_datalang(test)
        return

    def _set_labels(self, data):
        """Store the sorted unique values of ``data.label`` as ``self.labels``."""
        self.labels = sorted(list(data.label.unique()))

    def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
        """Set ``tr_langs``, ``te_langs`` and ``all_langs`` and return them.

        The languages found in the data are intersected with the optional
        ``set_tr_langs`` / ``set_te_langs`` restrictions.
        """
        self.tr_langs = set(train.lang.unique().tolist())
        self.te_langs = set(test.lang.unique().tolist())
        if set_tr_langs is not None:
            print(f"-- [SETTING TRAINING LANGS TO: {list(set_tr_langs)}]")
            self.tr_langs = self.tr_langs.intersection(set(set_tr_langs))
        if set_te_langs is not None:
            print(f"-- [SETTING TESTING LANGS TO: {list(set_te_langs)}]")
            self.te_langs = self.te_langs.intersection(set(set_te_langs))
        self.all_langs = self.tr_langs.union(self.te_langs)
        return self.tr_langs, self.te_langs, self.all_langs

    def _set_datalang(self, data: pd.DataFrame):
        """Return ``{lang: rows of data in that language}`` for every known language."""
        return {lang: data[data.lang == lang] for lang in self.all_langs}

    @staticmethod
    def _maybe_mask(texts, mask_number):
        """Return ``texts`` with numbers masked when ``mask_number`` is true."""
        return _mask_numbers(texts) if mask_number else texts

    def training(self, merge_validation=False, mask_number=False, target_as_csr=False):
        """Return the training documents and one-hot targets of each training language.

        Args:
            merge_validation: Append the validation rows to the training rows.
            mask_number: Replace numbers in the texts via ``_mask_numbers``.
            target_as_csr: Currently ignored.

        Returns:
            ``(lXtr, lYtr)`` where ``lXtr[lang]`` is ``{"text": list of str}``
            and ``lYtr[lang]`` is a ``(n_docs, num_labels())`` one-hot array.
        """
        # TODO some additional pre-processing on the textual data?
        lXtr, lYtr = {}, {}
        for lang in self.tr_langs:
            texts = self.train_data[lang].text.tolist()
            labels = self.train_data[lang].label.tolist()
            if merge_validation:
                texts += self.val_data[lang].text.tolist()
                labels += self.val_data[lang].label.tolist()
            # TODO dict holds only textual data - visual input is still to be managed
            lXtr[lang] = {"text": self._maybe_mask(texts, mask_number)}
            lYtr[lang] = self.indices_to_one_hot(
                indices=labels, n_labels=self.num_labels()
            )
        return lXtr, lYtr

    def test(self, mask_number=False, target_as_csr=False):
        """Return the test documents and one-hot targets of each test language.

        Args:
            mask_number: Replace numbers in the texts via ``_mask_numbers``.
            target_as_csr: Currently ignored.

        Returns:
            ``(lXte, lYte)`` with the same layout as the output of ``training``.
        """
        # TODO some additional pre-processing on the textual data?
        lXte, lYte = {}, {}
        for lang in self.te_langs:
            frame = self.test_data[lang]
            lXte[lang] = {"text": self._maybe_mask(frame.text.tolist(), mask_number)}
            lYte[lang] = self.indices_to_one_hot(
                indices=frame.label.tolist(), n_labels=self.num_labels()
            )
        return lXte, lYte

    def langs(self):
        """Return every language (training or test) as a list."""
        return list(self.all_langs)

    def num_labels(self):
        """Return the number of distinct labels seen in the training file."""
        return len(self.labels)

    def indices_to_one_hot(self, indices, n_labels):
        """Return a ``(len(indices), n_labels)`` float array with a 1 at each label index."""
        one_hot_matrix = np.zeros((len(indices), n_labels))
        # Explicit integer dtype keeps an empty ``indices`` list a valid index.
        columns = np.asarray(indices, dtype=int)
        one_hot_matrix[np.arange(len(indices)), columns] = 1
        return one_hot_matrix
class gFunDataset:
def __init__(
self,
@ -85,7 +209,7 @@ class gFunDataset:
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
elif "rai" in self.dataset_dir.lower():
print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}")
self.dataset_name = "rai"
@ -111,8 +235,10 @@ class gFunDataset:
def _load_multilingual(self, dataset_name, dataset_dir, nrows):
if "csv" in dataset_dir:
old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv(
path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv")
# path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
#path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv")
path_tr="~/datasets/rai/csv/train-split-rai.csv",
path_te="~/datasets/rai/csv/test-split-rai.csv")
else:
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
if nrows is not None:
@ -218,28 +344,48 @@ class gFunDataset:
print(f"- saving dataset in {filepath}")
pickle.dump(self, f)
# (pattern, replacement) pairs, applied in this order: longest numbers first,
# so that each number is replaced by the mask matching its digit count.
_NUMBER_MASKS = (
    (re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b"), " MoreDigitMask"),
    (re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b"), " FourDigitMask"),
    (re.compile(r"\s[\+-]?\d{3}([\.,]\d*)*\b"), " ThreeDigitMask"),
    (re.compile(r"\s[\+-]?\d{2}([\.,]\d*)*\b"), " TwoDigitMask"),
    (re.compile(r"\s[\+-]?\d{1}([\.,]\d*)*\b"), " OneDigitMask"),
)
# Deletes every "." and "," in a single pass.
_DROP_DOTS_AND_COMMAS = str.maketrans("", "", ".,")


def _mask_numbers(data):
    """Replace whitespace-preceded numbers with a placeholder naming their digit count.

    After masking, every "." and "," is removed from the text and the result
    is stripped of leading/trailing whitespace.

    Args:
        data: Iterable of strings.

    Returns:
        A list with one masked string per input text, in the same order.
    """
    masked = []
    for text in data:
        # Every pattern requires a leading whitespace: the extra space lets a
        # number at the very beginning of the text match as well.
        text = " " + text
        for pattern, placeholder in _NUMBER_MASKS:
            text = pattern.sub(placeholder, text)
        masked.append(text.translate(_DROP_DOTS_AND_COMMAS).strip())
    return masked
if __name__ == "__main__":
import os
data_rai = SimpleGfunDataset()
lXtr, lYtr = data_rai.training(mask_number=False)
lXte, lYte = data_rai.test(mask_number=False)
exit()
# import os
GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
RCV_DATAPATH = os.path.expanduser(
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
)
JRC_DATAPATH = os.path.expanduser(
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
)
# GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
# RCV_DATAPATH = os.path.expanduser(
# "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
# )
# JRC_DATAPATH = os.path.expanduser(
# "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
# )
print("Hello gFunDataset")
dataset = gFunDataset(
dataset_dir=JRC_DATAPATH,
data_langs=None,
is_textual=True,
is_visual=True,
is_multilabel=False,
labels=None,
nrows=13,
)
lXtr, lYtr = dataset.training()
lXte, lYte = dataset.test()
exit(0)
# print("Hello gFunDataset")
# dataset = gFunDataset(
# dataset_dir=JRC_DATAPATH,
# data_langs=None,
# is_textual=True,
# is_visual=True,
# is_multilabel=False,
# labels=None,
# nrows=13,
# )
# lXtr, lYtr = dataset.training()
# lXte, lYte = dataset.test()
# exit(0)

View File

@ -1,5 +1,5 @@
from os.path import expanduser, join
from dataManager.gFunDataset import gFunDataset
from dataManager.gFunDataset import gFunDataset, SimpleGfunDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset
@ -40,7 +40,8 @@ def get_dataset(dataset_name, args):
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
WEBIS_CLS = expanduser(
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
# "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
"~/datasets/cls-acl10-unprocessed/csv"
)
RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl")
@ -99,21 +100,35 @@ def get_dataset(dataset_name, args):
nrows=args.nrows,
)
elif dataset_name == "webis":
dataset = gFunDataset(
dataset_dir=WEBIS_CLS,
is_textual=True,
is_visual=False,
is_multilabel=False,
nrows=args.nrows,
dataset = SimpleGfunDataset(
datadir=WEBIS_CLS,
textual=True,
visual=False,
multilabel=False,
set_tr_langs=args.tr_langs,
set_te_langs=args.te_langs
)
# dataset = gFunDataset(
# dataset_dir=WEBIS_CLS,
# is_textual=True,
# is_visual=False,
# is_multilabel=False,
# nrows=args.nrows,
# )
elif dataset_name == "rai":
dataset = gFunDataset(
dataset_dir=RAI_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=False,
nrows=args.nrows
dataset = SimpleGfunDataset(
datadir="~/datasets/rai/csv",
textual=True,
visual=False,
multilabel=False
)
# dataset = gFunDataset(
# dataset_dir=RAI_DATAPATH,
# is_textual=True,
# is_visual=False,
# is_multilabel=False,
# nrows=args.nrows
# )
else:
raise NotImplementedError
return dataset

View File

@ -52,7 +52,7 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
metrics = []
if clf_type == "multilabel":
for lang in l_eval.keys():
for lang in sorted(l_eval.keys()):
# macrof1, microf1, macrok, microk = l_eval[lang]
# metrics.append([macrof1, microf1, macrok, microk])
macrof1, microf1, precision, recall = l_eval[lang]
@ -79,7 +79,7 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
"precision",
"recall"
]
for lang in l_eval.keys():
for lang in sorted(l_eval.keys()):
# acc, top5, top10, macrof1, microf1 = l_eval[lang]
acc, macrof1, microf1, precision, recall= l_eval[lang]
# metrics.append([acc, top5, top10, macrof1, microf1])

View File

@ -251,7 +251,7 @@ class GeneralizedFunnelling:
self.metaclassifier.fit(agg, lY)
return self
self.vectorizer.fit(lX)
self.vectorizer.fit(lX) # TODO this should fit also out-of-voc languages (for muses)
self.init_vgfs_vectorizers()
projections = []
@ -324,16 +324,19 @@ class GeneralizedFunnelling:
def get_config(self):
c = {}
simple_config = ""
for vgf in self.first_tier_learners:
vgf_config = vgf.get_config()
c.update({vgf_config["name"]: vgf_config})
simple_config += vgf_config["simple_id"]
gfun_config = {
"id": self._model_id,
"aggfunc": self.aggfunc,
"optimc": self.optimc,
"dataset": self.dataset_name,
"simple_id": "".join(sorted(simple_config))
}
c["gFun"] = gfun_config

View File

@ -103,6 +103,11 @@ def predict(logits, clf_type="multilabel"):
class TfidfVectorizerMultilingual:
def __init__(self, **kwargs):
self.kwargs = kwargs
def update(self, X, lang):
    """Fit and register a TF-IDF vectorizer for one additional language.

    Args:
        X: Object indexed with ``"text"`` to obtain the documents used for
            fitting (presumably the per-language dict passed to ``transform``
            -- verify against callers).
        lang: Key under which the new vectorizer is stored.

    Returns:
        ``self``.
    """
    # NOTE(review): ``fit`` stores ``self.langs`` sorted; appending here does
    # not keep that order -- confirm no caller relies on it.
    self.langs.append(lang)
    # The new vectorizer is built with the keyword arguments given to __init__.
    self.vectorizer[lang] = TfidfVectorizer(**self.kwargs).fit(X["text"])
    return self
def fit(self, lX, ly=None):
self.langs = sorted(lX.keys())
@ -112,7 +117,12 @@ class TfidfVectorizerMultilingual:
return self
def transform(self, lX):
return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}
in_langs = lX.keys()
for in_l in in_langs:
if in_l not in self.langs:
print(f"[NB: found unvectorized language! Updatding vectorizer for {in_l=}]")
self.update(X=lX[in_l], lang=in_l)
return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs} # TODO we can update the vectorizer with new languages here!
def fit_transform(self, lX, ly=None):
return self.fit(lX, ly).transform(lX)

View File

@ -56,6 +56,13 @@ class MultilingualGen(ViewGen):
def transform(self, lX):
lX = self.vectorizer.transform(lX)
if self.langs != sorted(self.vectorizer.vectorizer.keys()):
# new_langs = set(self.vectorizer.vectorizer.keys()) - set(self.langs)
old_langs = self.langs
self.langs = sorted(self.vectorizer.vectorizer.keys())
new_load, _ = self._load_embeddings(embed_dir=self.embed_dir, cached=self.cached, exclude=old_langs)
for k, v in new_load.items():
self.multi_embeddings[k] = v
XdotMulti = Parallel(n_jobs=self.n_jobs)(
delayed(XdotM)(lX[lang], self.multi_embeddings[lang], sif=self.sif)
@ -70,10 +77,12 @@ class MultilingualGen(ViewGen):
def fit_transform(self, lX, lY):
return self.fit(lX, lY).transform(lX)
def _load_embeddings(self, embed_dir, cached):
def _load_embeddings(self, embed_dir, cached, exclude=None):
if "muse" in self.embed_dir.lower():
if exclude is not None:
langs = set(self.langs) - set(exclude)
multi_embeddings = load_MUSEs(
langs=self.langs,
langs=self.langs if exclude is None else langs,
l_vocab=self.vectorizer.vocabulary(),
dir_path=embed_dir,
cached=cached,
@ -89,6 +98,7 @@ class MultilingualGen(ViewGen):
"cached": self.cached,
"sif": self.sif,
"probabilistic": self.probabilistic,
"simple_id": "m"
}
def save_vgf(self, model_id):
@ -164,6 +174,8 @@ def extract(l_voc, l_embeddings):
"""
l_extracted = {}
for lang, words in l_voc.items():
if lang not in l_embeddings:
continue
source_id, target_id = reindex(words, l_embeddings[lang].stoi)
extraction = torch.zeros((len(words), l_embeddings[lang].vectors.shape[-1]))
extraction[source_id] = l_embeddings[lang].vectors[target_id]

View File

@ -19,6 +19,7 @@ from dataManager.torchDataset import MultilingualDatasetTorch
transformers.logging.set_verbosity_error()
# TODO should pass also attention_mask to transformer model!
class MT5ForSequenceClassification(nn.Module):
def __init__(self, model_name, num_labels, output_hidden_states):
@ -115,7 +116,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
)
else:
# model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-traiend mbert
model_name = "mbert-rai-multi-2000/checkpoint-1500" # TODO hardcoded to pre-traiend mbert
# model_name = "hf_models/mbert-rai-fewshot-second/checkpoint-19000" # TODO hardcoded to pre-traiend mbert
# model_name = "hf_models/mbert-sentiment/checkpoint-1150" # TODO hardcoded to pre-traiend mbert
model_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
return AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels, output_hidden_states=True
)
@ -229,6 +232,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
self.model.eval()
with torch.no_grad():
# TODO should pass also attention_mask !
for input_ids, lang in dataloader:
input_ids = input_ids.to(self.device)
if isinstance(self.model, MT5ForSequenceClassification):
@ -283,4 +287,4 @@ class TextualTransformerGen(ViewGen, TransformerGen):
def get_config(self):
c = super().get_config()
return {"name": "textual-trasnformer VGF", "textual_trf": c}
return {"name": "textual-transformer VGF", "textual_trf": c, "simple_id": "t"}

View File

@ -67,4 +67,4 @@ class VanillaFunGen(ViewGen):
return self
def get_config(self):
return {"name": "Vanilla Funnelling VGF"}
return {"name": "Vanilla Funnelling VGF", "simple_id": "p"}

View File

@ -38,6 +38,7 @@ class WceGen(ViewGen):
"name": "Word-Class Embeddings VGF",
"n_jobs": self.n_jobs,
"sif": self.sif,
"simple_id": "w"
}
def save_vgf(self, model_id):

View File

@ -11,14 +11,21 @@ from gfun.vgfs.commons import Trainer
from datasets import load_dataset, DatasetDict
from transformers import Trainer
from pprint import pprint
import transformers
import evaluate
import pandas as pd
transformers.logging.set_verbosity_error()
IWSLT_D_COLUMNS = ["text", "category", "rating", "summary", "title"]
RAI_D_COLUMNS = ["id", "lang", "provider", "date", "title", "text", "label"]
RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"] # "lang"
WEBIS_D_COLUMNS = ['Unnamed: 0', 'asin', 'category', 'original_rating', 'label', 'title', 'text', 'summary'] # "lang"
MAX_LEN = 128
# DATASET_NAME = "rai"
# DATASET_NAME = "rai-multilingual-2000"
# DATASET_NAME = "webis-cls"
def init_callbacks(patience=-1, nosave=False):
@ -30,8 +37,9 @@ def init_callbacks(patience=-1, nosave=False):
def init_model(model_name, nlabels):
if model_name == "mbert":
hf_name = "bert-base-multilingual-cased"
# hf_name = "mbert-rai-multi-2000/checkpoint-1500"
# hf_name = "bert-base-multilingual-cased"
hf_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
# hf_name = "hf_models/mbert-rai-fewshot-second/checkpoint-9000"
elif model_name == "xlm-roberta":
hf_name = "xlm-roberta-base"
else:
@ -47,42 +55,38 @@ def main(args):
data = load_dataset(
"csv",
data_files = {
"train": expanduser("~/datasets/rai/csv/train-split-rai.csv"),
"test": expanduser("~/datasets/rai/csv/test-split-rai.csv")
"train": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/train.balanced.csv"),
"test": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/test.balanced.csv")
# "train": expanduser(f"~/datasets/rai/csv/train-{DATASET_NAME}.csv"),
# "test": expanduser(f"~/datasets/rai/csv/test-{DATASET_NAME}.csv")
# "train": expanduser(f"~/datasets/rai/csv/train.small.csv"),
# "test": expanduser(f"~/datasets/rai/csv/test.small.csv")
}
)
def process_sample_iwslt(sample):
inputs = sample["text"]
ratings = [r - 1 for r in sample["rating"]]
targets = torch.zeros((len(inputs), 3), dtype=float)
lang_mapper = {
lang: lang_id for lang_id, lang in enumerate(set(sample["lang"]))
}
lang_ids = [lang_mapper[l] for l in sample["lang"]]
for i, r in enumerate(ratings):
targets[i][r - 1] = 1
model_inputs = tokenizer(inputs, max_length=128, truncation=True)
model_inputs["labels"] = targets
model_inputs["lang_ids"] = torch.tensor(lang_ids)
return model_inputs
def process_sample_rai(sample):
inputs = [f"{title}. {text}" for title, text in zip(sample["title"], sample["text"])]
labels = sample["label"]
model_inputs = tokenizer(inputs, max_length=512, truncation=True) # TODO pre-process text cause there's a lot of noise in there...
model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) # TODO pre-process text cause there's a lot of noise in there...
model_inputs["labels"] = labels
return model_inputs
def process_sample_webis(sample):
inputs = sample["text"]
labels = sample["label"]
model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) # TODO pre-process text cause there's a lot of noise in there...
model_inputs["labels"] = labels
return model_inputs
data = data.map(
process_sample_rai,
# process_sample_rai,
process_sample_webis,
batched=True,
num_proc=4,
load_from_cache_file=True,
remove_columns=RAI_D_COLUMNS,
# remove_columns=RAI_D_COLUMNS,
remove_columns=WEBIS_D_COLUMNS,
)
train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
data.set_format("torch")
@ -103,7 +107,8 @@ def main(args):
recall_metric = evaluate.load("recall")
training_args = TrainingArguments(
output_dir=f"hf_models/{args.model}-rai-fewshot",
# output_dir=f"hf_models/{args.model}-rai",
output_dir=f"hf_models/{args.model}-sentiment-balanced",
do_train=True,
evaluation_strategy="steps",
per_device_train_batch_size=args.batch,
@ -115,7 +120,7 @@ def main(args):
max_grad_norm=5.0,
num_train_epochs=args.epochs,
lr_scheduler_type=args.scheduler,
warmup_ratio=0.1,
warmup_ratio=0.01,
logging_strategy="steps",
logging_first_step=True,
logging_steps=args.steplog,
@ -125,7 +130,8 @@ def main(args):
save_strategy="no" if args.nosave else "steps",
save_total_limit=2,
eval_steps=args.stepeval,
run_name=f"{args.model}-rai-stratified",
# run_name=f"{args.model}-rai-stratified",
run_name=f"{args.model}-sentiment",
disable_tqdm=False,
log_level="warning",
report_to=["wandb"] if args.wandb else "none",
@ -177,22 +183,32 @@ def main(args):
callbacks=callbacks,
)
print("- Training:")
trainer.train()
if not args.onlytest:
print("- Training:")
trainer.train()
print("- Testing:")
test_results = trainer.evaluate(eval_dataset=data["test"], metric_key_prefix="test")
print(test_results)
test_results = trainer.predict(test_dataset=data["test"], metric_key_prefix="test")
pprint(test_results.metrics)
save_preds(data["test"], test_results.predictions)
exit()
def save_preds(dataset, predictions, outfile="results/lang-specific.mbert.webis.csv"):
    """Write languages, gold labels and predicted labels of a test set to a CSV file.

    Args:
        dataset: Object indexed with ``"lang"`` and ``"labels"`` to obtain one
            value per test document.
        predictions: Array of per-class scores, one row per document; the
            predicted label is the argmax over axis 1.
        outfile: Destination CSV path; its parent directory is created when
            missing.
    """
    import os  # local import: this file's module-level imports are not fully visible here

    df = pd.DataFrame()
    df["langs"] = dataset["lang"]
    df["labels"] = dataset["labels"]
    df["preds"] = predictions.argmax(axis=1)

    parent = os.path.dirname(outfile)
    if parent:
        os.makedirs(parent, exist_ok=True)
    df.to_csv(outfile, index=False)
    return
if __name__ == "__main__":
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, metavar="", default="mbert")
parser.add_argument("--nlabels", type=int, metavar="", default=28)
parser.add_argument("--lr", type=float, metavar="", default=1e-4, help="Set learning rate",)
parser.add_argument("--lr", type=float, metavar="", default=5e-5, help="Set learning rate",)
parser.add_argument("--scheduler", type=str, metavar="", default="cosine", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
@ -203,7 +219,7 @@ if __name__ == "__main__":
parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
parser.add_argument("--wandb", action="store_true", help="Log to wandb")
parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
# parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
# parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
args = parser.parse_args()
main(args)

27
main.py
View File

@ -1,10 +1,13 @@
from argparse import ArgumentParser
from time import time
from csvlogger import CsvLogger
from dataManager.utils import get_dataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling
import pandas as pd
"""
TODO:
- General:
@ -31,7 +34,7 @@ def get_config_name(args):
def main(args):
dataset = get_dataset(args.dataset, args)
lX, lY = dataset.training()
lX, lY = dataset.training(merge_validation=True)
lX_te, lY_te = dataset.test()
tinit = time()
@ -141,6 +144,26 @@ def main(args):
if args.wandb:
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
config["gFun"]["timing"] = f"{timeval - tinit:.2f}"
csvlogger = CsvLogger(outfile="results/random.log.csv").log_lang_results(lang_metrics_gfun, config)
save_preds(gfun_preds, lY_te, config=config["gFun"]["simple_id"], dataset=config["gFun"]["dataset"])
def save_preds(preds, targets, config="unk", dataset="unk"):
    """Dump per-document gold and predicted labels of every language to a CSV file.

    The file is written to ``results/lang-specific.gfun.{config}.{dataset}.csv``
    with the columns ``langs``, ``labels`` and ``preds``; languages appear in
    sorted order.

    Args:
        preds: Mapping ``lang -> 2-D array`` of scores; the predicted label is
            the argmax over axis 1.
        targets: Mapping ``lang -> 2-D array`` of targets, reduced with the
            same argmax.
        config: Identifier of the gFun configuration, used in the file name.
        dataset: Dataset name, used in the file name.
    """
    import os  # local import: os is not among this file's module-level imports

    _preds = []
    _targets = []
    _langs = []
    for lang in sorted(preds.keys()):
        _preds.extend(preds[lang].argmax(axis=1).tolist())
        _targets.extend(targets[lang].argmax(axis=1).tolist())
        _langs.extend([lang] * len(preds[lang]))

    df = pd.DataFrame()
    df["langs"] = _langs
    df["labels"] = _targets
    df["preds"] = _preds

    # Make sure the output directory exists before writing.
    os.makedirs("results", exist_ok=True)
    df.to_csv(f"results/lang-specific.gfun.{config}.{dataset}.csv", index=False)
if __name__ == "__main__":
parser = ArgumentParser()
@ -148,6 +171,8 @@ if __name__ == "__main__":
parser.add_argument("--meta", action="store_true")
parser.add_argument("--nosave", action="store_true")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--tr_langs", nargs="+", default=None)
parser.add_argument("--te_langs", nargs="+", default=None)
# Dataset parameters -------------------
parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
parser.add_argument("--domains", type=str, default="all")

137
run-rai.sh Normal file
View File

@ -0,0 +1,137 @@
#!/bin/bash
# Runs main.py on the RAI dataset once per gFun configuration, in sequence.
# Every run shares the hyper-parameters below and skips model saving (--nosave).

njobs=-1
clf=singlelabel
patience=5
eval_every=5
text_len=512
text_lr=1e-4
bsize=2
txt_model=mbert
dataset=rai

# View-generator combinations to evaluate, in execution order.
# "-p" alone is deliberately left out: it was disabled in the original script.
configs=("-m" "-w" "-t" "-pm" "-pmw" "-pt" "-pmt" "-pmwt")

for config in "${configs[@]}"; do
    echo "[Running gFun config: $config]"
    python main.py "$config" \
        -d "$dataset" \
        --nosave \
        --n_jobs "$njobs" \
        --clf_type "$clf" \
        --patience "$patience" \
        --evaluate_step "$eval_every" \
        --batch_size "$bsize" \
        --max_length "$text_len" \
        --textual_lr "$text_lr" \
        --textual_trf_name "$txt_model"
done

View File

@ -1,22 +1,20 @@
#!bin/bash
config="-m"
echo "[Running gFun config: $config]"
epochs=100
njobs=-1
clf=singlelabel
patience=5
eval_every=5
text_len=256
text_len=512
text_lr=1e-4
bsize=64
bsize=2
txt_model=mbert
dataset=webis
config="-p"
echo "[Running gFun config: $config]"
python main.py $config \
-d webis \
--epochs $epochs \
-d $dataset\
--nosave \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
@ -24,5 +22,116 @@ python main.py $config \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model \
--load_trained webis_pmwt_mean_230621
--textual_trf_name $txt_model\
config="-m"
echo "[Running gFun config: $config]"
python main.py $config \
-d $dataset\
--nosave \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model\
config="-w"
echo "[Running gFun config: $config]"
python main.py $config \
-d $dataset\
--nosave \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model\
config="-t"
echo "[Running gFun config: $config]"
python main.py $config \
-d $dataset\
--nosave \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model\
config="-pm"
echo "[Running gFun config: $config]"
python main.py $config \
-d $dataset\
--nosave \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model\
config="-pmw"
echo "[Running gFun config: $config]"
python main.py $config \
-d $dataset\
--nosave \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model\
config="-pt"
echo "[Running gFun config: $config]"
python main.py $config \
-d $dataset\
--nosave \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model\
config="-pmt"
echo "[Running gFun config: $config]"
python main.py $config \
-d $dataset\
--nosave \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model\
config="-pmwt"
echo "[Running gFun config: $config]"
python main.py $config \
-d $dataset\
--nosave \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model\