Compare commits
13 Commits
| Author | SHA1 | Date |
|---|---|---|
| | c63c35269a | |
| | 2800694672 | |
| | e8b6396366 | |
| | e3e6f061d8 | |
| | 60171c1b5e | |
| | 2554c58fac | |
| | 9437ccc837 | |
| | de98926d00 | |
| | bef086ab50 | |
| | 732ffbefb1 | |
| | 9ce0001047 | |
| | b3b7c69263 | |
| | 770e8e62be | |
@@ -183,3 +183,4 @@ logger/*
 explore_data.ipynb
 run.sh
 wandb
+local_datasets

@@ -1,5 +1,6 @@
 import sys
 import os
+import xml.etree.ElementTree as ET

 sys.path.append(os.getcwd())

@@ -8,13 +9,87 @@ import re
 from dataManager.multilingualDataset import MultilingualDataset

 CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
-LANGS = ["de", "en", "fr", "jp"]
+CLS_UNPROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-unprocessed/")
+# LANGS = ["de", "en", "fr", "jp"]
+LANGS = ["de", "en", "fr"]
 DOMAINS = ["books", "dvd", "music"]

 regex = r":\d+"
 subst = ""


+def load_unprocessed_cls(reduce_target_space=False):
+    data = {}
+    data_tr = []
+    data_te = []
+    c_tr = 0
+    c_te = 0
+    for lang in LANGS:
+        data[lang] = {}
+        for domain in DOMAINS:
+            data[lang][domain] = {}
+            print(f"lang: {lang}, domain: {domain}")
+            for split in ["train", "test"]:
+                domain_data = []
+                fdir = os.path.join(
+                    CLS_UNPROCESSED_DATA_DIR, lang, domain, f"{split}.review"
+                )
+                tree = ET.parse(fdir)
+                root = tree.getroot()
+                for child in root:
+                    if reduce_target_space:
+                        rating = np.zeros(3, dtype=int)
+                        original_rating = int(float(child.find("rating").text))
+                        if original_rating < 3:
+                            new_rating = 1
+                        elif original_rating > 3:
+                            new_rating = 3
+                        else:
+                            new_rating = 2
+                        rating[new_rating - 1] = 1
+                        # rating = new_rating
+                    else:
+                        rating = np.zeros(5, dtype=int)
+                        rating[int(float(child.find("rating").text)) - 1] = 1
+                        # rating = new_rating
+                    # if split == "train":
+                    #     target_data = data_tr
+                    #     current_count = len(target_data)
+                    #     c_tr = +1
+                    # else:
+                    #     target_data = data_te
+                    #     current_count = len(target_data)
+                    #     c_te = +1
+                    domain_data.append(
+                        # target_data.append(
+                        {
+                            "asin": child.find("asin").text
+                            if child.find("asin") is not None
+                            else None,
+                            # "category": child.find("category").text
+                            # if child.find("category") is not None
+                            # else None,
+                            "category": domain,
+                            # "rating": child.find("rating").text
+                            # if child.find("rating") is not None
+                            # else None,
+                            "rating": rating,
+                            "title": child.find("title").text
+                            if child.find("title") is not None
+                            else None,
+                            "text": child.find("text").text
+                            if child.find("text") is not None
+                            else None,
+                            "summary": child.find("summary").text
+                            if child.find("summary") is not None
+                            else None,
+                            "lang": lang,
+                        }
+                    )
+                data[lang][domain].update({split: domain_data})
+    return data
+
+
 def load_cls():
     data = {}
     for lang in LANGS:
@@ -24,7 +99,7 @@ def load_cls():
             train = (
                 open(
                     os.path.join(
-                        CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed"
+                        CLS_UNPROCESSED_DATA_DIR, lang, domain, "train.processed"
                     ),
                     "r",
                 )
@@ -34,7 +109,7 @@ def load_cls():
             test = (
                 open(
                     os.path.join(
-                        CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed"
+                        CLS_UNPROCESSED_DATA_DIR, lang, domain, "test.processed"
                     ),
                     "r",
                 )
@@ -59,18 +134,33 @@ def process_data(line):


 if __name__ == "__main__":
-    print(f"datapath: {CLS_PROCESSED_DATA_DIR}")
-    data = load_cls()
-    multilingualDataset = MultilingualDataset(dataset_name="cls")
-    for lang in LANGS:
-        # TODO: just using book domain atm
-        Xtr = [text[0] for text in data[lang]["books"]["train"]]
-        # Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
-        Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
-
-        Xte = [text[0] for text in data[lang]["books"]["test"]]
-        # Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
-        Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]])
+    print(f"datapath: {CLS_UNPROCESSED_DATA_DIR}")
+    # data = load_cls()
+    data = load_unprocessed_cls(reduce_target_space=True)
+    multilingualDataset = MultilingualDataset(dataset_name="webis-cls-unprocessed")
+
+    for lang in LANGS:
+        # Xtr = [text["summary"] for text in data[lang]["books"]["train"]]
+        Xtr = [text["text"] for text in data[lang]["books"]["train"]]
+        Xtr += [text["text"] for text in data[lang]["dvd"]["train"]]
+        Xtr += [text["text"] for text in data[lang]["music"]["train"]]
+
+        Ytr =[text["rating"] for text in data[lang]["books"]["train"]]
+        Ytr += [text["rating"] for text in data[lang]["dvd"]["train"]]
+        Ytr += [text["rating"] for text in data[lang]["music"]["train"]]
+
+        Ytr = np.vstack(Ytr)
+
+        Xte = [text["text"] for text in data[lang]["books"]["test"]]
+        Xte += [text["text"] for text in data[lang]["dvd"]["test"]]
+        Xte += [text["text"] for text in data[lang]["music"]["test"]]
+
+        Yte = [text["rating"] for text in data[lang]["books"]["test"]]
+        Yte += [text["rating"] for text in data[lang]["dvd"]["test"]]
+        Yte += [text["rating"] for text in data[lang]["music"]["test"]]
+
+        Yte = np.vstack(Yte)
+
         multilingualDataset.add(
             lang=lang,
@@ -82,5 +172,7 @@ if __name__ == "__main__":
             te_ids=None,
         )
     multilingualDataset.save(
-        os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
+        os.path.expanduser(
+            "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
+        )
     )

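The new `load_unprocessed_cls` above collapses the original 1–5 star ratings into three one-hot classes (negative/neutral/positive) when `reduce_target_space=True`. A minimal, standalone sketch of that mapping (the helper name is illustrative, not part of the repository):

```python
import numpy as np

def reduce_rating(original_rating: int) -> np.ndarray:
    """Map a 1-5 star rating to a one-hot vector over {negative, neutral, positive}."""
    rating = np.zeros(3, dtype=int)
    if original_rating < 3:
        new_rating = 1  # negative
    elif original_rating > 3:
        new_rating = 3  # positive
    else:
        new_rating = 2  # neutral
    rating[new_rating - 1] = 1
    return rating

assert reduce_rating(1).tolist() == [1, 0, 0]
assert reduce_rating(3).tolist() == [0, 1, 0]
assert reduce_rating(5).tolist() == [0, 0, 1]
```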
@@ -62,14 +62,29 @@ class gFunDataset:
             )
             self.mlb = self.get_label_binarizer(self.labels)

-        elif "cls" in self.dataset_dir.lower():
-            print(f"- Loading CLS dataset from {self.dataset_dir}")
+        # WEBIS-CLS (processed)
+        elif (
+            "cls" in self.dataset_dir.lower()
+            and "unprocessed" not in self.dataset_dir.lower()
+        ):
+            print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}")
             self.dataset_name = "cls"
             self.dataset, self.labels, self.data_langs = self._load_multilingual(
                 self.dataset_name, self.dataset_dir, self.nrows
             )
             self.mlb = self.get_label_binarizer(self.labels)

+        # WEBIS-CLS (unprocessed)
+        elif (
+            "cls" in self.dataset_dir.lower()
+            and "unprocessed" in self.dataset_dir.lower()
+        ):
+            print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}")
+            self.dataset_name = "cls"
+            self.dataset, self.labels, self.data_langs = self._load_multilingual(
+                self.dataset_name, self.dataset_dir, self.nrows
+            )
+            self.mlb = self.get_label_binarizer(self.labels)
+
         self.show_dimension()

         return

@@ -23,6 +23,7 @@ def get_dataset(dataset_name, args):
         "rcv1-2",
         "glami",
         "cls",
+        "webis",
     ], "dataset not supported"

     RCV_DATAPATH = expanduser(
@@ -37,6 +38,10 @@ def get_dataset(dataset_name, args):

     GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

+    WEBIS_CLS = expanduser(
+        "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
+    )
+
     if dataset_name == "multinews":
         # TODO: convert to gFunDataset
         raise NotImplementedError
@@ -91,6 +96,15 @@ def get_dataset(dataset_name, args):
             is_multilabel=False,
             nrows=args.nrows,
         )
+
+    elif dataset_name == "webis":
+        dataset = gFunDataset(
+            dataset_dir=WEBIS_CLS,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=False,
+            nrows=args.nrows,
+        )
     else:
         raise NotImplementedError
     return dataset

@@ -1,8 +1,9 @@
 from joblib import Parallel, delayed
 from collections import defaultdict

-from evaluation.metrics import *
-from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score
+# from evaluation.metrics import *
+import numpy as np
+from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score, precision_score, recall_score


 def evaluation_metrics(y, y_, clf_type):
@@ -13,13 +14,17 @@ def evaluation_metrics(y, y_, clf_type):
             # TODO: we need logits top_k_accuracy_score(y, y_, k=10),
             f1_score(y, y_, average="macro", zero_division=1),
             f1_score(y, y_, average="micro"),
+            precision_score(y, y_, zero_division=1, average="macro"),
+            recall_score(y, y_, zero_division=1, average="macro"),
         )
     elif clf_type == "multilabel":
         return (
-            macroF1(y, y_),
-            microF1(y, y_),
-            macroK(y, y_),
-            microK(y, y_),
+            f1_score(y, y_, average="macro", zero_division=1),
+            f1_score(y, y_, average="micro"),
+            0,
+            0,
+            # macroK(y, y_),
+            # microK(y, y_),
         )
     else:
         raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
@@ -48,8 +53,10 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):

     if clf_type == "multilabel":
         for lang in l_eval.keys():
-            macrof1, microf1, macrok, microk = l_eval[lang]
-            metrics.append([macrof1, microf1, macrok, microk])
+            # macrof1, microf1, macrok, microk = l_eval[lang]
+            # metrics.append([macrof1, microf1, macrok, microk])
+            macrof1, microf1, precision, recall = l_eval[lang]
+            metrics.append([macrof1, microf1, precision, recall])
             if phase != "validation":
                 print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
         averages = np.mean(np.array(metrics), axis=0)
@@ -69,12 +76,15 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
             # "acc10", # "accuracy-at-10",
             "MF1", # "macro-F1",
             "mF1", # "micro-F1",
+            "precision",
+            "recall"
         ]
         for lang in l_eval.keys():
             # acc, top5, top10, macrof1, microf1 = l_eval[lang]
-            acc, macrof1, microf1 = l_eval[lang]
+            acc, macrof1, microf1, precision, recall = l_eval[lang]
             # metrics.append([acc, top5, top10, macrof1, microf1])
-            metrics.append([acc, macrof1, microf1])
+            # metrics.append([acc, macrof1, microf1])
+            metrics.append([acc, macrof1, microf1, precision, recall])

             for m, v in zip(_metrics, l_eval[lang]):
                 lang_metrics[m][lang] = v
@@ -82,7 +92,8 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
             if phase != "validation":
                 print(
                     # f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
-                    f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
+                    # f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
+                    f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f} pr = {precision:.3f} re = {recall:.3f}"
                 )
         averages = np.mean(np.array(metrics), axis=0)
         if verbose:
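With the change above, the singlelabel branch of `evaluation_metrics` now returns macro precision and recall alongside accuracy and the F1 scores. A small sketch of the resulting tuple on toy labels, assuming only scikit-learn:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

y = np.array([0, 1, 2, 2, 1])   # toy gold labels
y_ = np.array([0, 1, 1, 2, 1])  # toy predictions

scores = (
    accuracy_score(y, y_),
    f1_score(y, y_, average="macro", zero_division=1),
    f1_score(y, y_, average="micro"),
    precision_score(y, y_, zero_division=1, average="macro"),
    recall_score(y, y_, zero_division=1, average="macro"),
)
print(scores)  # acc, macro-F1, micro-F1, macro-precision, macro-recall
```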
@@ -124,6 +124,16 @@ class GeneralizedFunnelling:
                 epochs=self.epochs,
                 attn_stacking_type=attn_stacking,
             )
+
+            self._model_id = get_unique_id(
+                self.dataset_name,
+                self.posteriors_vgf,
+                self.multilingual_vgf,
+                self.wce_vgf,
+                self.textual_trf_vgf,
+                self.visual_trf_vgf,
+                self.aggfunc,
+            )
             return self

         if self.posteriors_vgf:
@@ -317,7 +327,7 @@ class GeneralizedFunnelling:

         for vgf in self.first_tier_learners:
             vgf_config = vgf.get_config()
-            c.update(vgf_config)
+            c.update({vgf_config["name"]: vgf_config})

         gfun_config = {
             "id": self._model_id,
@@ -372,6 +382,7 @@ class GeneralizedFunnelling:
                 "rb",
             ) as vgf:
                 first_tier_learners.append(pickle.load(vgf))
+                print(f"- loaded trained VanillaFun VGF")
         if self.multilingual_vgf:
             with open(
                 os.path.join(
@@ -380,6 +391,7 @@ class GeneralizedFunnelling:
                 "rb",
             ) as vgf:
                 first_tier_learners.append(pickle.load(vgf))
+                print(f"- loaded trained Multilingual VGF")
         if self.wce_vgf:
             with open(
                 os.path.join(
@@ -388,20 +400,38 @@ class GeneralizedFunnelling:
                 "rb",
             ) as vgf:
                 first_tier_learners.append(pickle.load(vgf))
+                print(f"- loaded trained WCE VGF")
         if self.textual_trf_vgf:
             with open(
                 os.path.join(
-                    "models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
+                    "models",
+                    "vgfs",
+                    "textual_transformer",
+                    f"textualTransformerGen_{model_id}.pkl",
                 ),
                 "rb",
             ) as vgf:
                 first_tier_learners.append(pickle.load(vgf))
+                print(f"- loaded trained Textual Transformer VGF")
+        if self.visual_trf_vgf:
+            with open(
+                os.path.join(
+                    "models",
+                    "vgfs",
+                    "visual_transformer",
+                    f"visualTransformerGen_{model_id}.pkl",
+                ),
+                "rb",
+                print(f"- loaded trained Visual Transformer VGF"),
+            ) as vgf:
+                first_tier_learners.append(pickle.load(vgf))

         if load_meta:
             with open(
                 os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
             ) as f:
                 metaclassifier = pickle.load(f)
+                print(f"- loaded trained metaclassifier")
         else:
             metaclassifier = None
         return first_tier_learners, metaclassifier, vectorizer

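The `get_config` aggregation above now keys each first-tier learner's configuration by its "name" field instead of merging the dictionaries flat, so configs from different VGFs no longer overwrite each other. A toy illustration of the difference (dummy configs, not the real VGF objects):

```python
vgf_configs = [
    {"name": "vgf-A", "n_jobs": -1},
    {"name": "vgf-B", "n_jobs": 1},
]

flat, keyed = {}, {}
for cfg in vgf_configs:
    flat.update(cfg)                  # old behaviour: shared keys collide, last one wins
    keyed.update({cfg["name"]: cfg})  # new behaviour: one entry per VGF

print(flat)   # {'name': 'vgf-B', 'n_jobs': 1}
print(keyed)  # {'vgf-A': {...}, 'vgf-B': {...}}
```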
@@ -45,11 +45,12 @@ class MT5ForSequenceClassification(nn.Module):

     def save_pretrained(self, checkpoint_dir):
         torch.save(self.state_dict(), checkpoint_dir + ".pt")
-        return
+        return self

     def from_pretrained(self, checkpoint_dir):
         checkpoint_dir += ".pt"
-        return self.load_state_dict(torch.load(checkpoint_dir))
+        self.load_state_dict(torch.load(checkpoint_dir))
+        return self


 class TextualTransformerGen(ViewGen, TransformerGen):
@@ -113,6 +114,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
                 model_name, num_labels=num_labels, output_hidden_states=True
             )
         else:
+            model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500"  # TODO hardcoded to pre-traiend mbert
             return AutoModelForSequenceClassification.from_pretrained(
                 model_name, num_labels=num_labels, output_hidden_states=True
             )
@@ -144,58 +146,60 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             self.model_name, num_labels=self.num_labels
         )

-        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
-            lX, lY, split=0.2, seed=42, modality="text"
-        )
-
-        tra_dataloader = self.build_dataloader(
-            tr_lX,
-            tr_lY,
-            processor_fn=self._tokenize,
-            torchDataset=MultilingualDatasetTorch,
-            batch_size=self.batch_size,
-            split="train",
-            shuffle=True,
-        )
-
-        val_dataloader = self.build_dataloader(
-            val_lX,
-            val_lY,
-            processor_fn=self._tokenize,
-            torchDataset=MultilingualDatasetTorch,
-            batch_size=self.batch_size_eval,
-            split="val",
-            shuffle=False,
-        )
-
-        experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-
-        trainer = Trainer(
-            model=self.model,
-            optimizer_name="adamW",
-            lr=self.lr,
-            device=self.device,
-            loss_fn=torch.nn.CrossEntropyLoss(),
-            print_steps=self.print_steps,
-            evaluate_step=self.evaluate_step,
-            patience=self.patience,
-            experiment_name=experiment_name,
-            checkpoint_path=os.path.join(
-                "models",
-                "vgfs",
-                "transformer",
-                self._format_model_name(self.model_name),
-            ),
-            vgf_name="textual_trf",
-            classification_type=self.clf_type,
-            n_jobs=self.n_jobs,
-            scheduler_name=self.scheduler,
-        )
-        trainer.train(
-            train_dataloader=tra_dataloader,
-            eval_dataloader=val_dataloader,
-            epochs=self.epochs,
-        )
+        self.model.to("cuda")
+
+        # tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
+        #     lX, lY, split=0.2, seed=42, modality="text"
+        # )
+        #
+        # tra_dataloader = self.build_dataloader(
+        #     tr_lX,
+        #     tr_lY,
+        #     processor_fn=self._tokenize,
+        #     torchDataset=MultilingualDatasetTorch,
+        #     batch_size=self.batch_size,
+        #     split="train",
+        #     shuffle=True,
+        # )
+        #
+        # val_dataloader = self.build_dataloader(
+        #     val_lX,
+        #     val_lY,
+        #     processor_fn=self._tokenize,
+        #     torchDataset=MultilingualDatasetTorch,
+        #     batch_size=self.batch_size_eval,
+        #     split="val",
+        #     shuffle=False,
+        # )
+        #
+        # experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
+        #
+        # trainer = Trainer(
+        #     model=self.model,
+        #     optimizer_name="adamW",
+        #     lr=self.lr,
+        #     device=self.device,
+        #     loss_fn=torch.nn.CrossEntropyLoss(),
+        #     print_steps=self.print_steps,
+        #     evaluate_step=self.evaluate_step,
+        #     patience=self.patience,
+        #     experiment_name=experiment_name,
+        #     checkpoint_path=os.path.join(
+        #         "models",
+        #         "vgfs",
+        #         "trained_transformer",
+        #         self._format_model_name(self.model_name),
+        #     ),
+        #     vgf_name="textual_trf",
+        #     classification_type=self.clf_type,
+        #     n_jobs=self.n_jobs,
+        #     scheduler_name=self.scheduler,
+        # )
+        # trainer.train(
+        #     train_dataloader=tra_dataloader,
+        #     eval_dataloader=val_dataloader,
+        #     epochs=self.epochs,
+        # )

         if self.probabilistic:
             self.feature2posterior_projector.fit(self.transform(lX), lY)
@@ -224,7 +228,6 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         with torch.no_grad():
             for input_ids, lang in dataloader:
                 input_ids = input_ids.to(self.device)
-                # TODO: check this
                 if isinstance(self.model, MT5ForSequenceClassification):
                     batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
                 else:
@@ -277,4 +280,4 @@ class TextualTransformerGen(ViewGen, TransformerGen):

     def get_config(self):
         c = super().get_config()
-        return {"textual_trf": c}
+        return {"name": "textual-trasnformer VGF", "textual_trf": c}

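In the hunk above, the custom `save_pretrained`/`from_pretrained` helpers now return `self` (rather than nothing or the result of `load_state_dict`), which makes chained construction-and-loading possible. A self-contained toy module showing the same pattern (this is a stand-in, not the repository's `MT5ForSequenceClassification`):

```python
import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    """Toy stand-in mirroring the save/load helpers changed above."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 3)

    def save_pretrained(self, checkpoint_dir):
        torch.save(self.state_dict(), checkpoint_dir + ".pt")
        return self

    def from_pretrained(self, checkpoint_dir):
        self.load_state_dict(torch.load(checkpoint_dir + ".pt"))
        return self  # returning the module itself allows chaining

model = TinyClassifier().save_pretrained("/tmp/tiny-clf")
reloaded = TinyClassifier().from_pretrained("/tmp/tiny-clf")
```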
@@ -65,3 +65,6 @@ class VanillaFunGen(ViewGen):
         with open(_path, "wb") as f:
             pickle.dump(self, f)
         return self
+
+    def get_config(self):
+        return {"name": "Vanilla Funnelling VGF"}

@@ -186,4 +186,4 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         return self

     def get_config(self):
-        return {"visual_trf": super().get_config()}
+        return {"name": "visual-transformer VGF", "visual_trf": super().get_config()}

@@ -0,0 +1,191 @@
+import torch
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    DataCollatorWithPadding,
+    TrainingArguments,
+)
+from gfun.vgfs.commons import Trainer
+from datasets import load_dataset, DatasetDict
+
+from transformers import Trainer
+
+import transformers
+import evaluate
+
+transformers.logging.set_verbosity_error()
+
+
+def init_callbacks(patience=-1, nosave=False):
+    callbacks = []
+    if patience != -1 and not nosave:
+        callbacks.append(transformers.EarlyStoppingCallback(early_stopping_patience=patience))
+    return callbacks
+
+
+def init_model(model_name):
+    if model_name == "mbert":
+        hf_name = "bert-base-multilingual-cased"
+    elif model_name == "xlm-roberta":
+        hf_name = "xlm-roberta-base"
+    else:
+        raise NotImplementedError
+    tokenizer = AutoTokenizer.from_pretrained(hf_name)
+    model = AutoModelForSequenceClassification.from_pretrained(hf_name, num_labels=3)
+    return tokenizer, model
+
+
+def main(args):
+    tokenizer, model = init_model(args.model)
+
+    data = load_dataset(
+        "json",
+        data_files={
+            "train": "local_datasets/webis-cls/all-domains/train.json",
+            "test": "local_datasets/webis-cls/all-domains/test.json",
+        },
+    )
+
+    def process_sample(sample):
+        inputs = sample["text"]
+        ratings = [r - 1 for r in sample["rating"]]
+        targets = torch.zeros((len(inputs), 3), dtype=float)
+        lang_mapper = {
+            lang: lang_id for lang_id, lang in enumerate(set(sample["lang"]))
+        }
+        lang_ids = [lang_mapper[l] for l in sample["lang"]]
+        for i, r in enumerate(ratings):
+            targets[i][r - 1] = 1
+
+        model_inputs = tokenizer(inputs, max_length=512, truncation=True)
+        model_inputs["labels"] = targets
+        model_inputs["lang_ids"] = torch.tensor(lang_ids)
+        return model_inputs
+
+    data = data.map(
+        process_sample,
+        batched=True,
+        num_proc=4,
+        load_from_cache_file=True,
+        remove_columns=["text", "category", "rating", "summary", "title"],
+    )
+    train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
+    data.set_format("torch")
+    data = DatasetDict(
+        {
+            "train": train_val_splits["train"],
+            "validation": train_val_splits["test"],
+            "test": data["test"],
+        }
+    )
+
+    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+    callbacks = init_callbacks(args.patience, args.nosave)
+
+    f1_metric = evaluate.load("f1")
+    accuracy_metric = evaluate.load("accuracy")
+    precision_metric = evaluate.load("precision")
+    recall_metric = evaluate.load("recall")
+
+    training_args = TrainingArguments(
+        output_dir=f"{args.model}-sentiment",
+        do_train=True,
+        evaluation_strategy="steps",
+        per_device_train_batch_size=args.batch,
+        per_device_eval_batch_size=args.batch,
+        gradient_accumulation_steps=args.gradacc,
+        eval_accumulation_steps=10,
+        learning_rate=args.lr,
+        weight_decay=0.1,
+        max_grad_norm=5.0,
+        num_train_epochs=args.epochs,
+        lr_scheduler_type=args.scheduler,
+        warmup_steps=1000,
+        logging_strategy="steps",
+        logging_first_step=True,
+        logging_steps=args.steplog,
+        seed=42,
+        fp16=args.fp16,
+        load_best_model_at_end=False if args.nosave else True,
+        save_strategy="no" if args.nosave else "steps",
+        save_total_limit=3,
+        eval_steps=args.stepeval,
+        run_name=f"{args.model}-sentiment-run",
+        disable_tqdm=False,
+        log_level="warning",
+        report_to=["wandb"] if args.wandb else "none",
+        optim="adamw_torch",
+    )
+
+
+    def compute_metrics(eval_preds):
+        preds = eval_preds.predictions.argmax(-1)
+        targets = eval_preds.label_ids.argmax(-1)
+        setting = "macro"
+        f1_score_macro = f1_metric.compute(
+            predictions=preds, references=targets, average="macro"
+        )
+        f1_score_micro = f1_metric.compute(
+            predictions=preds, references=targets, average="micro"
+        )
+        accuracy_score = accuracy_metric.compute(predictions=preds, references=targets)
+        precision_score = precision_metric.compute(
+            predictions=preds, references=targets, average=setting, zero_division=1
+        )
+        recall_score = recall_metric.compute(
+            predictions=preds, references=targets, average=setting, zero_division=1
+        )
+        results = {
+            "macro_f1score": f1_score_macro["f1"],
+            "micro_f1score": f1_score_micro["f1"],
+            "accuracy": accuracy_score["accuracy"],
+            "precision": precision_score["precision"],
+            "recall": recall_score["recall"],
+        }
+        results = {k: round(v, 4) for k, v in results.items()}
+        return results
+
+    if args.wandb:
+        import wandb
+        wandb.init(entity="andreapdr", project=f"gfun-senti-hf", name="mbert-sent", config=vars(args))
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=data["train"],
+        eval_dataset=data["validation"],
+        compute_metrics=compute_metrics,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=callbacks,
+    )
+
+    print("- Training:")
+    trainer.train()
+
+
+    print("- Testing:")
+    test_results = trainer.evaluate(eval_dataset=data["test"])
+    print(test_results)
+
+    exit()
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument("--model", type=str, metavar="", default="mbert")
+    parser.add_argument("--lr", type=float, metavar="", default=1e-5, help="Set learning rate",)
+    parser.add_argument("--scheduler", type=str, metavar="", default="linear", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
+    parser.add_argument("--batch", type=int, metavar="", default=16, help="Set batch size")
+    parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
+    parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
+    parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
+    parser.add_argument("--steplog", type=int, metavar="", default=100, help="Log training every n steps")
+    parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
+    parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
+    parser.add_argument("--wandb", action="store_true", help="Log to wandb")
+    parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
+    # parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
+    # parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
+    args = parser.parse_args()
+    main(args)
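In the new fine-tuning script, `compute_metrics` recovers class indices from the logits and from the one-hot `label_ids` with `argmax(-1)` before computing the metrics. A self-contained sketch of that step using NumPy and scikit-learn instead of the HF `evaluate` wrappers (toy values, same logic):

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Toy eval output: 3-class logits and one-hot labels, shaped like the script's eval_preds.
predictions = np.array(
    [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.3, 1.2, 0.8], [0.1, 0.2, 3.0]]
)
label_ids = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])

preds = predictions.argmax(-1)   # predicted class per sample -> [0, 1, 1, 2]
targets = label_ids.argmax(-1)   # one-hot -> class index     -> [0, 1, 2, 2]

print(accuracy_score(targets, preds))             # 0.75
print(f1_score(targets, preds, average="macro"))  # ~0.78
```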
main.py
@@ -1,8 +1,5 @@
-import os
 import wandb

-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-
 from argparse import ArgumentParser
 from time import time

@@ -12,16 +9,9 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling

 """
 TODO:
-    - Transformers VGFs:
-        - scheduler with warmup and cosine
-        - freeze params method
     - General:
         [!] zero-shot setup
         - CLS dataset is loading only "books" domain data
-        - documents should be trimmed to the same length (for SVMs we are using way too long tokens)
-    - Attention Aggregator:
-        - experiment with weight init of Attention-aggregator
-        - FFNN posterior-probabilities' dependent
     - Docs:
         - add documentations sphinx
 """
@@ -150,7 +140,6 @@ def main(args):
         wandb.log(gfun_res)

     log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
-    log_barplot_wandb(avg_metrics_gfun, title_affix="averages")


 if __name__ == "__main__":
@@ -178,7 +167,7 @@ if __name__ == "__main__":
     parser.add_argument("--features", action="store_false")
     parser.add_argument("--aggfunc", type=str, default="mean")
     # transformer parameters ---------------
-    parser.add_argument("--epochs", type=int, default=100)
+    parser.add_argument("--epochs", type=int, default=5)
     parser.add_argument("--textual_trf_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
     parser.add_argument("--eval_batch_size", type=int, default=128)

@@ -0,0 +1,28 @@
+#!bin/bash
+
+config="-m"
+
+echo "[Running gFun config: $config]"
+
+epochs=100
+njobs=-1
+clf=singlelabel
+patience=5
+eval_every=5
+text_len=256
+text_lr=1e-4
+bsize=64
+txt_model=mbert
+
+python main.py $config \
+    -d webis \
+    --epochs $epochs \
+    --n_jobs $njobs \
+    --clf_type $clf \
+    --patience $patience \
+    --evaluate_step $eval_every \
+    --batch_size $bsize \
+    --max_length $text_len \
+    --textual_lr $text_lr \
+    --textual_trf_name $txt_model \
+    --load_trained webis_pmwt_mean_230621