branching for rai

parent fbd740fabd
commit 234b6031b1
@@ -5,12 +5,6 @@ from sklearn.metrics import mean_absolute_error
from os.path import join

"""
MAE for classification is meaningful only in "ordinal" tasks, e.g., sentiment classification.
Otherwise the distance between the categories has no semantics!

- NB: we want to get the macro-averaged class-specific MAE!
"""

def main():
# SETTINGS = ["p", "m", "w", "t", "mp", "mpw", "mpt", "mptw"]
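
A minimal sketch (not part of this commit) of the macro-averaged class-specific MAE the note above asks for: compute one MAE per true class and then take the unweighted mean over the classes, so frequent classes do not dominate the score. Function and variable names below are illustrative.

import numpy as np
from sklearn.metrics import mean_absolute_error

def macro_averaged_mae(y_true, y_pred):
    # one MAE per true class, then an unweighted (macro) average over classes
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    per_class = [
        mean_absolute_error(y_true[y_true == c], y_pred[y_true == c])
        for c in np.unique(y_true)
    ]
    return float(np.mean(per_class))
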
@@ -24,11 +18,9 @@ def main():
print(df)

def evalaute(setting):
def evalaute(setting, result_file="preds.gfun.csv"):
result_dir = "results"
# result_file = f"lang-specific.gfun.{setting}.webis.csv"
result_file = f"lang-specific.mbert.webis.csv"
# print(f"- reading from: {result_file}")
print(f"- reading from: {result_file}")
df = pd.read_csv(join(result_dir, result_file))
langs = df.langs.unique()
res = []
@@ -14,7 +14,7 @@ class CsvLogger:
# os.makedirs(self.outfile.replace(".csv", ".lang.csv"), exist_ok=True)
# return

def log_lang_results(self, results: dict, config="gfun-lello"):
def log_lang_results(self, results: dict, config="gfun-default", notes=None):
df = pd.DataFrame.from_dict(results, orient="columns")
df["config"] = config["gFun"]["simple_id"]
df["aggfunc"] = config["gFun"]["aggfunc"]
@@ -22,6 +22,7 @@ class CsvLogger:
df["id"] = config["gFun"]["id"]
df["optimc"] = config["gFun"]["optimc"]
df["timing"] = config["gFun"]["timing"]
df["notes"] = notes
with open(self.outfile, 'a') as f:
df.to_csv(f, mode='a', header=f.tell()==0)
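
A side note on the append pattern above (illustrative, not part of the commit): opening the log file in append mode and passing header=f.tell() == 0 makes pandas write the CSV header only when the file is still empty, so repeated runs keep appending rows under a single header. A minimal standalone sketch, with a hypothetical results row and file name:

import pandas as pd

def append_log_row(row: dict, outfile: str = "gfun.log.csv"):
    df = pd.DataFrame([row])
    with open(outfile, "a") as f:
        # f.tell() == 0 only for a new/empty file, so the header is written exactly once
        df.to_csv(f, mode="a", header=f.tell() == 0, index=False)

append_log_row({"config": "gfun-default", "macro-F1": 0.71})
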
@@ -16,7 +16,17 @@ from dataManager.multilingualDataset import MultilingualDataset

class SimpleGfunDataset:
def __init__(self, datadir="~/datasets/rai/csv/", textual=True, visual=False, multilabel=False, set_tr_langs=None, set_te_langs=None):
def __init__(
self,
dataset_name=None,
datadir="~/datasets/rai/csv/",
textual=True,
visual=False,
multilabel=False,
set_tr_langs=None,
set_te_langs=None
):
self.name = dataset_name
self.datadir = os.path.expanduser(datadir)
self.textual = textual
self.visual = visual
@@ -41,11 +51,15 @@ class SimpleGfunDataset:
print(f"tr: {tr} - va: {va} - te: {te}")

def load_csv(self, set_tr_langs, set_te_langs):
# _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
_data_tr = pd.read_csv(os.path.join(self.datadir, "train.balanced.csv")).sample(100, random_state=42)
train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang) # TODO stratify on lang or label?
# test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
test = pd.read_csv(os.path.join(self.datadir, "test.balanced.csv")).sample(100, random_state=42)
_data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
try:
stratified = "class"
train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.label)
except:
stratified = "lang"
train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang)
print(f"- dataset stratified by {stratified}")
test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
self._set_langs (train, test, set_tr_langs, set_te_langs)
self._set_labels(_data_tr)
self.full_train = _data_tr
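
The try/except above falls back from label-stratified to language-stratified splitting: sklearn's train_test_split raises a ValueError when stratify is requested and some class has fewer than two samples. A small self-contained sketch of that behaviour (toy data, illustrative names):

import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.DataFrame({
    "text": ["a", "b", "c", "d"],
    "label": [0, 0, 0, 1],        # label 1 occurs once: stratifying on it fails
    "lang": ["en", "en", "it", "it"],
})
try:
    tr, va = train_test_split(df, test_size=0.5, random_state=42, stratify=df.label)
    stratified = "class"
except ValueError:
    # fall back to stratifying on language, which has >= 2 samples per group
    tr, va = train_test_split(df, test_size=0.5, random_state=42, stratify=df.lang)
    stratified = "lang"
print(f"- dataset stratified by {stratified}")
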
@@ -56,8 +70,6 @@ class SimpleGfunDataset:
return

def _set_labels(self, data):
# self.labels = [i for i in range(28)] # todo hard-coded for rai
# self.labels = [i for i in range(3)] # TODO hard-coded for sentiment
self.labels = sorted(list(data.label.unique()))

def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
@@ -70,17 +82,17 @@ class SimpleGfunDataset:
print(f"-- [SETTING TESTING LANGS TO: {list(set_tr_langs)}]")
self.te_langs = self.te_langs.intersection(set(set_te_langs))
self.all_langs = self.tr_langs.union(self.te_langs)

return self.tr_langs, self.te_langs, self.all_langs

def _set_datalang(self, data: pd.DataFrame):
return {lang: data[data.lang == lang] for lang in self.all_langs}

def training(self, merge_validation=False, mask_number=False, target_as_csr=False):
# TODO some additional pre-processing on the textual data?
apply_mask = lambda x: _mask_numbers(x) if _mask_numbers else x
lXtr = {
lang: {"text": apply_mask(self.train_data[lang].text.tolist())} # TODO inserting dict for textual data - we still have to manage visual
lang: {"text": apply_mask(self.train_data[lang].text.tolist())}
for lang in self.tr_langs
}
if merge_validation:
@@ -103,7 +115,6 @@ class SimpleGfunDataset:
return lXtr, lYtr

def test(self, mask_number=False, target_as_csr=False):
# TODO some additional pre-processing on the textual data?
apply_mask = lambda x: _mask_numbers(x) if _mask_numbers else x
lXte = {
lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
@@ -162,62 +173,12 @@ class gFunDataset:
return mlb

def _load_dataset(self):
if "glami" in self.dataset_dir.lower():
print(f"- Loading GLAMI dataset from {self.dataset_dir}")
self.dataset_name = "glami"
self.dataset, self.labels, self.data_langs = self._load_glami(
self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)

elif "rcv" in self.dataset_dir.lower():
print(f"- Loading RCV1-2 dataset from {self.dataset_dir}")
self.dataset_name = "rcv1-2"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)

elif "jrc" in self.dataset_dir.lower():
print(f"- Loading JRC dataset from {self.dataset_dir}")
self.dataset_name = "jrc"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)

# WEBIS-CLS (processed)
elif (
"cls" in self.dataset_dir.lower()
and "unprocessed" not in self.dataset_dir.lower()
):
print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}")
self.dataset_name = "cls"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)

# WEBIS-CLS (unprocessed)
elif (
"cls" in self.dataset_dir.lower()
and "unprocessed" in self.dataset_dir.lower()
):
print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}")
self.dataset_name = "cls"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)

elif "rai" in self.dataset_dir.lower():
print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}")
self.dataset_name = "rai"
self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
dataset_dir="~/datasets/rai/csv/train-split-rai.csv",
nrows=self.nrows)
self.mlb = self.get_label_binarizer(self.labels)

print(f"- Loading dataset from {self.dataset_dir}")
self.dataset_name = "rai"
self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
dataset_dir=self.dataset_dir,
nrows=self.nrows)
self.mlb = self.get_label_binarizer(self.labels)
self.show_dimension()
return
@@ -232,15 +193,12 @@ class gFunDataset:
else:
print(f"-- Labels: {len(self.labels)}")

def _load_multilingual(self, dataset_name, dataset_dir, nrows):
def _load_multilingual(self, dataset_dir, nrows, dataset_name="rai"):
if "csv" in dataset_dir:
old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv(
# path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
# path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv")
path_tr="~/datasets/rai/csv/train-split-rai.csv",
path_te="~/datasets/rai/csv/test-split-rai.csv")
else:
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
old_dataset = MultilingualDataset(dataset_name="rai").from_csv(
path_tr=os.path.expanduser(os.path.join(dataset_dir, "train.small.csv")),
path_te=os.path.expanduser(os.path.join(dataset_dir, "test.small.csv"))
)
if nrows is not None:
if dataset_name == "cls":
old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
@@ -365,27 +323,4 @@ if __name__ == "__main__":
data_rai = SimpleGfunDataset()
lXtr, lYtr = data_rai.training(mask_number=False)
lXte, lYte = data_rai.test(mask_number=False)
exit()
# import os

# GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
# RCV_DATAPATH = os.path.expanduser(
# "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
# )
# JRC_DATAPATH = os.path.expanduser(
# "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
# )

# print("Hello gFunDataset")
# dataset = gFunDataset(
# dataset_dir=JRC_DATAPATH,
# data_langs=None,
# is_textual=True,
# is_visual=True,
# is_multilabel=False,
# labels=None,
# nrows=13,
# )
# lXtr, lYtr = dataset.training()
# lXte, lYte = dataset.test()
# exit(0)
exit()
@@ -1,7 +1,5 @@
from os.path import expanduser, join
from dataManager.gFunDataset import gFunDataset, SimpleGfunDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset
from dataManager.gFunDataset import SimpleGfunDataset

def load_from_pickle(path, dataset_name, nrows):
@@ -16,119 +14,14 @@ def load_from_pickle(path, dataset_name, nrows):
return loaded

def get_dataset(dataset_name, args):
assert dataset_name in [
"multinews",
"amazon",
"rcv1-2",
"glami",
"cls",
"webis",
"rai",
], "dataset not supported"

RCV_DATAPATH = expanduser(
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
def get_dataset(datasetp_path, args):
dataset = SimpleGfunDataset(
dataset_name="rai",
datadir=datasetp_path,
textual=True,
visual=False,
multilabel=False,
set_tr_langs=args.tr_langs,
set_te_langs=args.te_langs
)
JRC_DATAPATH = expanduser(
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
)
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")

MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")

GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

WEBIS_CLS = expanduser(
# "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
"~/datasets/cls-acl10-unprocessed/csv"
)

RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl")

if dataset_name == "multinews":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = MultiNewsDataset(
expanduser(MULTINEWS_DATAPATH),
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
)
elif dataset_name == "amazon":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = AmazonDataset(
domains=args.domains,
nrows=args.nrows,
min_count=args.min_count,
max_labels=args.max_labels,
)
elif dataset_name == "jrc":
dataset = gFunDataset(
dataset_dir=JRC_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=True,
nrows=args.nrows,
)
elif dataset_name == "rcv1-2":
dataset = gFunDataset(
dataset_dir=RCV_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=True,
nrows=args.nrows,
)
elif dataset_name == "glami":
if args.save_dataset is False:
dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
else:
dataset = gFunDataset(
dataset_dir=GLAMI_DATAPATH,
is_textual=True,
is_visual=True,
is_multilabel=False,
nrows=args.nrows,
)

dataset.save_as_pickle(GLAMI_DATAPATH)
elif dataset_name == "cls":
dataset = gFunDataset(
dataset_dir=CLS_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=False,
nrows=args.nrows,
)
elif dataset_name == "webis":
dataset = SimpleGfunDataset(
datadir=WEBIS_CLS,
textual=True,
visual=False,
multilabel=False,
set_tr_langs=args.tr_langs,
set_te_langs=args.te_langs
)
# dataset = gFunDataset(
# dataset_dir=WEBIS_CLS,
# is_textual=True,
# is_visual=False,
# is_multilabel=False,
# nrows=args.nrows,
# )
elif dataset_name == "rai":
dataset = SimpleGfunDataset(
datadir="~/datasets/rai/csv",
textual=True,
visual=False,
multilabel=False
)
# dataset = gFunDataset(
# dataset_dir=RAI_DATAPATH,
# is_textual=True,
# is_visual=False,
# is_multilabel=False,
# nrows=args.nrows
# )
else:
raise NotImplementedError
return dataset
@@ -8,7 +8,6 @@ from gfun.vgfs.learners.svms import MetaClassifier, get_learner
from gfun.vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from gfun.vgfs.vanillaFun import VanillaFunGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.wceGen import WceGen
@@ -19,7 +18,6 @@ class GeneralizedFunnelling:
wce,
multilingual,
textual_transformer,
visual_transformer,
langs,
num_labels,
classification_type,
@@ -29,12 +27,9 @@ class GeneralizedFunnelling:
eval_batch_size,
max_length,
textual_lr,
visual_lr,
epochs,
patience,
evaluate_step,
textual_transformer_name,
visual_transformer_name,
optimc,
device,
load_trained,
@@ -42,13 +37,14 @@ class GeneralizedFunnelling:
probabilistic,
aggfunc,
load_meta,
trained_text_trf=None,
textual_transformer_name=None,
):
# Setting VFGs -----------
self.posteriors_vgf = posterior
self.wce_vgf = wce
self.multilingual_vgf = multilingual
self.textual_trf_vgf = textual_transformer
self.visual_trf_vgf = visual_transformer
self.probabilistic = probabilistic
self.num_labels = num_labels
self.clf_type = classification_type
@@ -58,6 +54,7 @@ class GeneralizedFunnelling:
self.cached = True
# Textual Transformer VGF params ----------
self.textual_trf_name = textual_transformer_name
self.trained_text_trf = trained_text_trf
self.epochs = epochs
self.textual_trf_lr = textual_lr
self.textual_scheduler = "ReduceLROnPlateau"
@@ -68,10 +65,6 @@ class GeneralizedFunnelling:
self.patience = patience
self.evaluate_step = evaluate_step
self.device = device
# Visual Transformer VGF params ----------
self.visual_trf_name = visual_transformer_name
self.visual_trf_lr = visual_lr
self.visual_scheduler = "ReduceLROnPlateau"
# Metaclassifier params ------------
self.optimc = optimc
# -------------------
@@ -131,7 +124,6 @@ class GeneralizedFunnelling:
self.multilingual_vgf,
self.wce_vgf,
self.textual_trf_vgf,
self.visual_trf_vgf,
self.aggfunc,
)
return self
@@ -174,26 +166,10 @@ class GeneralizedFunnelling:
patience=self.patience,
device=self.device,
classification_type=self.clf_type,
saved_model=self.trained_text_trf,
)
self.first_tier_learners.append(transformer_vgf)

if self.visual_trf_vgf:
visual_trasformer_vgf = VisualTransformerGen(
dataset_name=self.dataset_name,
model_name="vit",
lr=self.visual_trf_lr,
scheduler=self.visual_scheduler,
epochs=self.epochs,
batch_size=self.batch_size_trf,
batch_size_eval=self.eval_batch_size_trf,
probabilistic=self.probabilistic,
evaluate_step=self.evaluate_step,
patience=self.patience,
device=self.device,
classification_type=self.clf_type,
)
self.first_tier_learners.append(visual_trasformer_vgf)

if "attn" in self.aggfunc:
attn_stacking = self.aggfunc.split("_")[1]
self.attn_aggregator = AttentionAggregator(
@@ -219,7 +195,6 @@ class GeneralizedFunnelling:
self.multilingual_vgf,
self.wce_vgf,
self.textual_trf_vgf,
self.visual_trf_vgf,
self.aggfunc,
)
print(f"- model id: {self._model_id}")
@@ -309,16 +284,29 @@ class GeneralizedFunnelling:
return aggregated

def _aggregate_mean(self, first_tier_projections):
aggregated = {
lang: np.zeros(data.shape)
for lang, data in first_tier_projections[0].items()
}
# aggregated = {
# lang: np.zeros(data.shape)
# for lang, data in first_tier_projections[0].items()
# }
aggregated = {}
for lang_projections in first_tier_projections:
for lang, projection in lang_projections.items():
if lang not in aggregated:
aggregated[lang] = np.zeros(projection.shape)
aggregated[lang] += projection

for lang, projection in aggregated.items():
aggregated[lang] /= len(first_tier_projections)
def get_denom(lang, projs):
den = 0
for proj in projs:
if lang in proj:
den += 1
return den

for lang, _ in aggregated.items():
# aggregated[lang] /= len(first_tier_projections)
aggregated[lang] /= get_denom(lang, first_tier_projections)

return aggregated
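
The change above makes the mean aggregation robust to views (VGFs) that do not cover every language: each language's projections are summed and then divided by the number of views that actually produced that language, rather than by the total number of views. A minimal self-contained sketch of that idea (names are illustrative, not the project's API):

import numpy as np

def aggregate_mean(first_tier_projections):
    # first_tier_projections: list of {lang: np.ndarray}, one dict per view (VGF)
    aggregated, counts = {}, {}
    for view in first_tier_projections:
        for lang, projection in view.items():
            aggregated[lang] = aggregated.get(lang, 0) + projection
            counts[lang] = counts.get(lang, 0) + 1
    # divide by the number of views that actually covered each language
    return {lang: total / counts[lang] for lang, total in aggregated.items()}

views = [{"en": np.ones((2, 3)), "it": np.zeros((2, 3))}, {"en": np.zeros((2, 3))}]
print(aggregate_mean(views)["en"])  # 0.5 everywhere; "it" is averaged over its single view
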
@@ -416,18 +404,6 @@ class GeneralizedFunnelling:
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained Textual Transformer VGF")
if self.visual_trf_vgf:
with open(
os.path.join(
"models",
"vgfs",
"visual_transformer",
f"visualTransformerGen_{model_id}.pkl",
),
"rb",
print(f"- loaded trained Visual Transformer VGF"),
) as vgf:
first_tier_learners.append(pickle.load(vgf))

if load_meta:
with open(
@@ -482,7 +458,6 @@ def get_unique_id(
multilingual,
wce,
textual_transformer,
visual_transformer,
aggfunc,
):
from datetime import datetime
@@ -493,6 +468,5 @@ def get_unique_id(
model_id += "m" if multilingual else ""
model_id += "w" if wce else ""
model_id += "t" if textual_transformer else ""
model_id += "v" if visual_transformer else ""
model_id += f"_{aggfunc}"
return f"{model_id}_{now}"
@@ -173,9 +173,18 @@ class NaivePolylingualClassifier:
:return: a dictionary of probabilities that each document belongs to each class
"""
assert self.model is not None, "predict called before fit"
assert set(lX.keys()).issubset(
set(self.model.keys())
), "unknown languages requested in decision function"
if not set(lX.keys()).issubset(set(self.model.keys())):
langs = set(lX.keys()).intersection(set(self.model.keys()))
scores = Parallel(n_jobs=self.n_jobs, max_nbytes=None)(
delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs)
# res = {lang: None for lang in lX.keys()}
# for i, lang in enumerate(langs):
# res[lang] = scores[i]
# return res
return {lang: scores[i] for i, lang in enumerate(langs)}
# assert set(lX.keys()).issubset(
# set(self.model.keys())
# ), "unknown languages requested in decision function"
langs = list(lX.keys())
scores = Parallel(n_jobs=self.n_jobs, max_nbytes=None)(
delayed(self.model[lang].predict_proba)(lX[lang]) for lang in langs
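
Instead of failing on languages that were never fitted, the new branch above scores only the intersection of requested and fitted languages. A minimal sketch of the same behaviour without joblib (illustrative names, assuming sklearn-style per-language classifiers):

def predict_proba_known_langs(models: dict, lX: dict) -> dict:
    # keep only the languages for which a per-language classifier exists
    known = set(lX) & set(models)
    return {lang: models[lang].predict_proba(lX[lang]) for lang in known}
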
@@ -73,6 +73,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
patience=5,
classification_type="multilabel",
scheduler="ReduceLROnPlateau",
saved_model = None
):
super().__init__(
self._validate_model_name(model_name),
@@ -91,6 +92,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
n_jobs=n_jobs,
verbose=verbose,
)
self.saved_model = saved_model
self.clf_type = classification_type
self.fitted = False
print(
@@ -109,26 +111,25 @@ class TextualTransformerGen(ViewGen, TransformerGen):
else:
raise NotImplementedError

def load_pretrained_model(self, model_name, num_labels):
def load_pretrained_model(self, model_name, num_labels, saved_model=None):
if model_name == "google/mt5-small":
return MT5ForSequenceClassification(
model_name, num_labels=num_labels, output_hidden_states=True
)
else:
# model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-trained mbert
# model_name = "hf_models/mbert-rai-fewshot-second/checkpoint-19000" # TODO hardcoded to pre-trained mbert
# model_name = "hf_models/mbert-sentiment/checkpoint-1150" # TODO hardcoded to pre-trained mbert
model_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
if saved_model:
model_name = saved_model
else:
model_name = "google/bert-base-multilingual-cased"
return AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels, output_hidden_states=True
)

def load_tokenizer(self, model_name):
# model_name = "mbert-rai-multi-2000/checkpoint-1500" # TODO hardcoded to pre-trained mbert
return AutoTokenizer.from_pretrained(model_name)

def init_model(self, model_name, num_labels):
return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
def init_model(self, model_name, num_labels, saved_model):
return self.load_pretrained_model(model_name, num_labels, saved_model), self.load_tokenizer(
model_name
)
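
With the saved_model hook above, the VGF can start either from a base multilingual encoder or from a locally fine-tuned checkpoint directory, instead of a hard-coded path. A hedged standalone sketch of that selection logic (the checkpoint path in the comment is a placeholder, not a path shipped with the repo):

from typing import Optional
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def load_classifier(num_labels: int, saved_model: Optional[str] = None):
    # a local checkpoint dir (e.g. "hf_models/.../checkpoint-1600") takes precedence over the base model
    name = saved_model if saved_model else "bert-base-multilingual-cased"
    model = AutoModelForSequenceClassification.from_pretrained(
        name, num_labels=num_labels, output_hidden_states=True
    )
    tokenizer = AutoTokenizer.from_pretrained(name)
    return model, tokenizer
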
@@ -148,64 +149,11 @@ class TextualTransformerGen(ViewGen, TransformerGen):
_l = list(lX.keys())[0]
self.num_labels = lY[_l].shape[-1]
self.model, self.tokenizer = self.init_model(
self.model_name, num_labels=self.num_labels
self.model_name, num_labels=self.num_labels, saved_model=self.saved_model,
)

self.model.to("cuda")

# tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
# lX, lY, split=0.2, seed=42, modality="text"
# )
#
# tra_dataloader = self.build_dataloader(
# tr_lX,
# tr_lY,
# processor_fn=self._tokenize,
# torchDataset=MultilingualDatasetTorch,
# batch_size=self.batch_size,
# split="train",
# shuffle=True,
# )
#
# val_dataloader = self.build_dataloader(
# val_lX,
# val_lY,
# processor_fn=self._tokenize,
# torchDataset=MultilingualDatasetTorch,
# batch_size=self.batch_size_eval,
# split="val",
# shuffle=False,
# )
#
# experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
#
# trainer = Trainer(
# model=self.model,
# optimizer_name="adamW",
# lr=self.lr,
# device=self.device,
# loss_fn=torch.nn.CrossEntropyLoss(),
# print_steps=self.print_steps,
# evaluate_step=self.evaluate_step,
# patience=self.patience,
# experiment_name=experiment_name,
# checkpoint_path=os.path.join(
# "models",
# "vgfs",
# "trained_transformer",
# self._format_model_name(self.model_name),
# ),
# vgf_name="textual_trf",
# classification_type=self.clf_type,
# n_jobs=self.n_jobs,
# scheduler_name=self.scheduler,
# )
# trainer.train(
# train_dataloader=tra_dataloader,
# eval_dataloader=val_dataloader,
# epochs=self.epochs,
# )

if self.probabilistic:
transformed = self.transform(lX)
self.feature2posterior_projector.fit(transformed, lY)
@@ -1,4 +1,4 @@
from os.path import expanduser
from os.path import expanduser, join

import torch
from transformers import (
@@ -19,13 +19,8 @@ import pandas as pd
transformers.logging.set_verbosity_error()

IWSLT_D_COLUMNS = ["text", "category", "rating", "summary", "title"]
RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"] # "lang"
WEBIS_D_COLUMNS = ['Unnamed: 0', 'asin', 'category', 'original_rating', 'label', 'title', 'text', 'summary'] # "lang"
RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"]
MAX_LEN = 128
# DATASET_NAME = "rai"
# DATASET_NAME = "rai-multilingual-2000"
# DATASET_NAME = "webis-cls"

def init_callbacks(patience=-1, nosave=False):
@@ -35,13 +30,17 @@ def init_callbacks(patience=-1, nosave=False):
return callbacks

def init_model(model_name, nlabels):
def init_model(model_name, nlabels, saved_model=None):
if model_name == "mbert":
# hf_name = "bert-base-multilingual-cased"
hf_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
# hf_name = "hf_models/mbert-rai-fewshot-second/checkpoint-9000"
if saved_model is None:
hf_name = "bert-base-multilingual-cased"
else:
hf_name = saved_model
elif model_name == "xlm-roberta":
hf_name = "xlm-roberta-base"
if saved_model is None:
hf_name = "xlm-roberta-base"
else:
hf_name = saved_model
else:
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained(hf_name)
@@ -50,43 +49,41 @@ def init_model(model_name, nlabels):

def main(args):
tokenizer, model = init_model(args.model, args.nlabels)
saved_model = args.savedmodel
trainlang = args.trainlangs
datapath = args.datapath

tokenizer, model = init_model(args.model, args.nlabels, saved_model=saved_model)

data = load_dataset(
"csv",
data_files = {
"train": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/train.balanced.csv"),
"test": expanduser(f"~/datasets/cls-acl10-unprocessed/csv/test.balanced.csv")
# "train": expanduser(f"~/datasets/rai/csv/train-{DATASET_NAME}.csv"),
# "test": expanduser(f"~/datasets/rai/csv/test-{DATASET_NAME}.csv")
# "train": expanduser(f"~/datasets/rai/csv/train.small.csv"),
# "test": expanduser(f"~/datasets/rai/csv/test.small.csv")
"train": expanduser(join(datapath, "train.csv")),
"test": expanduser(join(datapath, "test.small.csv"))
}
)

def filter_dataset(dataset, lang):
indices = [i for i, l in enumerate(dataset["lang"]) if l == lang]
dataset = dataset.select(indices)
return dataset

if trainlang is not None:
data["train"] = filter_dataset(data["train"], lang=trainlang)

def process_sample_rai(sample):
inputs = [f"{title}. {text}" for title, text in zip(sample["title"], sample["text"])]
labels = sample["label"]
model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) # TODO pre-process text cause there's a lot of noise in there...
model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)
model_inputs["labels"] = labels
return model_inputs

def process_sample_webis(sample):
inputs = sample["text"]
labels = sample["label"]
model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True) # TODO pre-process text cause there's a lot of noise in there...
model_inputs["labels"] = labels
return model_inputs

data = data.map(
# process_sample_rai,
process_sample_webis,
process_sample_rai,
batched=True,
num_proc=4,
load_from_cache_file=True,
# remove_columns=RAI_D_COLUMNS,
remove_columns=WEBIS_D_COLUMNS,
remove_columns=RAI_D_COLUMNS,
)
train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
data.set_format("torch")
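
The filter_dataset helper above implements the zero-shot setup: when a training language is given, the training split is restricted to that single language before fine-tuning, while testing still covers all languages. With Hugging Face datasets the same restriction can also be expressed with the built-in filter method; a hedged sketch with illustrative file names and the "lang" column used by the csv files above:

from datasets import load_dataset

data = load_dataset("csv", data_files={"train": "train.csv", "test": "test.small.csv"})
# restrict the training split to one language, e.g. German, for a zero-shot run
data["train"] = data["train"].filter(lambda example: example["lang"] == "de")
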
@@ -107,8 +104,8 @@ def main(args):
recall_metric = evaluate.load("recall")

training_args = TrainingArguments(
# output_dir=f"hf_models/{args.model}-rai",
output_dir=f"hf_models/{args.model}-sentiment-balanced",
output_dir=f"hf_models/{args.model}-fewshot-full" if trainlang is None else f"hf_models/{args.model}-zeroshot-full",
run_name="model-zeroshot" if trainlang is not None else "model-fewshot",
do_train=True,
evaluation_strategy="steps",
per_device_train_batch_size=args.batch,
@@ -130,8 +127,6 @@ def main(args):
save_strategy="no" if args.nosave else "steps",
save_total_limit=2,
eval_steps=args.stepeval,
# run_name=f"{args.model}-rai-stratified",
run_name=f"{args.model}-sentiment",
disable_tqdm=False,
log_level="warning",
report_to=["wandb"] if args.wandb else "none",
@@ -142,7 +137,6 @@ def main(args):

def compute_metrics(eval_preds):
preds = eval_preds.predictions.argmax(-1)
# targets = eval_preds.label_ids.argmax(-1)
targets = eval_preds.label_ids
setting = "macro"
f1_score_macro = f1_metric.compute(
@@ -170,7 +164,9 @@ def main(args):

if args.wandb:
import wandb
wandb.init(entity="andreapdr", project=f"gfun-rai-hf", name="mbert-rai", config=vars(args))
wandb.init(entity="andreapdr", project=f"gfun",
name="model-zeroshot-full" if trainlang is not None else "model-fewshot-full",
config=vars(args))

trainer = Trainer(
model=model,
@@ -188,17 +184,21 @@ def main(args):
trainer.train()

print("- Testing:")
test_results = trainer.evaluate(eval_dataset=data["test"], metric_key_prefix="test")
test_results = trainer.predict(test_dataset=data["test"], metric_key_prefix="test")
pprint(test_results.metrics)
save_preds(data["test"], test_results.predictions)
save_preds(data["test"], test_results.predictions, trainlang)
exit()

def save_preds(dataset, predictions):
def save_preds(dataset, predictions, trainlang=None):
df = pd.DataFrame()
df["langs"] = dataset["lang"]
df["labels"] = dataset["labels"]
df["preds"] = predictions.argmax(axis=1)
df.to_csv("results/lang-specific.mbert.webis.csv", index=False)
if trainlang is not None:
df.to_csv(f"results/zeroshot.{trainlang}.model.csv", index=False)
else:
df.to_csv("results/fewshot.model.csv", index=False)
return
@@ -210,16 +210,18 @@ if __name__ == "__main__":
parser.add_argument("--nlabels", type=int, metavar="", default=28)
parser.add_argument("--lr", type=float, metavar="", default=5e-5, help="Set learning rate",)
parser.add_argument("--scheduler", type=str, metavar="", default="cosine", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
parser.add_argument("--epochs", type=int, metavar="", default=10, help="Set epochs")
parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps")
parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps")
parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
parser.add_argument("--wandb", action="store_true", help="Log to wandb")
parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
# parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
parser.add_argument("--trainlang", default=None, type=str, help="set training language for zero-shot experiments" )
parser.add_argument("--datapath", type=str, default="data", help="path to the csv dataset. Dir should contain both a train.csv and a test.csv file")
parser.add_argument("--savedmodel", type=str, default="hf_models/mbert-rai-fewshot-second/checkpoint-9000")
args = parser.parse_args()
main(args)
main.py
@@ -10,8 +10,6 @@ import pandas as pd
"""
TODO:
- General:
[!] zero-shot setup
- Docs:
- add documentation (sphinx)
"""
@@ -27,13 +25,11 @@ def get_config_name(args):
config_name += "M+"
if args.textual_transformer:
config_name += f"TT_{args.textual_trf_name}+"
if args.visual_transformer:
config_name += f"VT_{args.visual_trf_name}+"
return config_name.rstrip("+")

def main(args):
dataset = get_dataset(args.dataset, args)
dataset = get_dataset(args.datadir, args)
lX, lY = dataset.training(merge_validation=True)
lX_te, lY_te = dataset.test()
@@ -47,13 +43,12 @@ def main(args):
args.multilingual,
args.multilingual,
args.textual_transformer,
args.visual_transformer,
]
), "At least one of VGF must be True"

gfun = GeneralizedFunnelling(
# dataset params ----------------------
dataset_name=args.dataset,
dataset_name=dataset,
langs=dataset.langs(),
num_labels=dataset.num_labels(),
classification_type=args.clf_type,
@@ -67,24 +62,15 @@ def main(args):
# Transformer VGF params --------------
textual_transformer=args.textual_transformer,
textual_transformer_name=args.textual_trf_name,
trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
batch_size=args.batch_size,
eval_batch_size=args.eval_batch_size,
epochs=args.epochs,
textual_lr=args.textual_lr,
visual_lr=args.visual_lr,
max_length=args.max_length,
patience=args.patience,
evaluate_step=args.evaluate_step,
device=args.device,
# Visual Transformer VGF params --------------
visual_transformer=args.visual_transformer,
visual_transformer_name=args.visual_trf_name,
# batch_size=args.batch_size,
# epochs=args.epochs,
# lr=args.lr,
# patience=args.patience,
# evaluate_step=args.evaluate_step,
# device="cuda",
# General params ---------------------
probabilistic=args.features,
aggfunc=args.aggfunc,
@@ -145,7 +131,7 @@ def main(args):
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")

config["gFun"]["timing"] = f"{timeval - tinit:.2f}"
csvlogger = CsvLogger(outfile="results/random.log.csv").log_lang_results(lang_metrics_gfun, config)
csvlogger = CsvLogger(outfile="results/gfun.log.csv").log_lang_results(lang_metrics_gfun, config, notes="")
save_preds(gfun_preds, lY_te, config=config["gFun"]["simple_id"], dataset=config["gFun"]["dataset"])
@@ -162,7 +148,7 @@ def save_preds(preds, targets, config="unk", dataset="unk"):
df["langs"] = _langs
df["labels"] = _targets
df["preds"] = _preds
df.to_csv(f"results/lang-specific.gfun.{config}.{dataset}.csv", index=False)
df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.correct.csv", index=False)

if __name__ == "__main__":
@@ -174,7 +160,7 @@ if __name__ == "__main__":
parser.add_argument("--tr_langs", nargs="+", default=None)
parser.add_argument("--te_langs", nargs="+", default=None)
# Dataset parameters -------------------
parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
parser.add_argument("-d", "--datadir", type=str, default=None, help="dir to dataset. It should contain both a train.csv and a test.csv file")
parser.add_argument("--domains", type=str, default="all")
parser.add_argument("--nrows", type=int, default=None)
parser.add_argument("--min_count", type=int, default=10)
@@ -186,7 +172,6 @@ if __name__ == "__main__":
parser.add_argument("-m", "--multilingual", action="store_true")
parser.add_argument("-w", "--wce", action="store_true")
parser.add_argument("-t", "--textual_transformer", action="store_true")
parser.add_argument("-v", "--visual_transformer", action="store_true")
parser.add_argument("--n_jobs", type=int, default=-1)
parser.add_argument("--optimc", action="store_true")
parser.add_argument("--features", action="store_false")
@@ -200,9 +185,6 @@ if __name__ == "__main__":
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--patience", type=int, default=5)
parser.add_argument("--evaluate_step", type=int, default=10)
# Visual Transformer parameters --------------
parser.add_argument("--visual_trf_name", type=str, default="vit")
parser.add_argument("--visual_lr", type=float, default=1e-4)
# logging
parser.add_argument("--wandb", action="store_true")