Compare commits

..

13 Commits

12 changed files with 476 additions and 99 deletions

3
.gitignore vendored
View File

@@ -182,4 +182,5 @@ scripts/
logger/*
explore_data.ipynb
run.sh
wandb
wandb
local_datasets

View File

@@ -1,5 +1,6 @@
import sys
import os
import xml.etree.ElementTree as ET
sys.path.append(os.getcwd())
@@ -8,13 +9,87 @@ import re
from dataManager.multilingualDataset import MultilingualDataset
CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
LANGS = ["de", "en", "fr", "jp"]
CLS_UNPROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-unprocessed/")
# LANGS = ["de", "en", "fr", "jp"]
LANGS = ["de", "en", "fr"]
DOMAINS = ["books", "dvd", "music"]
regex = r":\d+"
subst = ""
def load_unprocessed_cls(reduce_target_space=False):
data = {}
data_tr = []
data_te = []
c_tr = 0
c_te = 0
for lang in LANGS:
data[lang] = {}
for domain in DOMAINS:
data[lang][domain] = {}
print(f"lang: {lang}, domain: {domain}")
for split in ["train", "test"]:
domain_data = []
fdir = os.path.join(
CLS_UNPROCESSED_DATA_DIR, lang, domain, f"{split}.review"
)
tree = ET.parse(fdir)
root = tree.getroot()
for child in root:
if reduce_target_space:
rating = np.zeros(3, dtype=int)
original_rating = int(float(child.find("rating").text))
if original_rating < 3:
new_rating = 1
elif original_rating > 3:
new_rating = 3
else:
new_rating = 2
rating[new_rating - 1] = 1
# rating = new_rating
else:
rating = np.zeros(5, dtype=int)
rating[int(float(child.find("rating").text)) - 1] = 1
# rating = new_rating
# if split == "train":
# target_data = data_tr
# current_count = len(target_data)
# c_tr = +1
# else:
# target_data = data_te
# current_count = len(target_data)
# c_te = +1
domain_data.append(
# target_data.append(
{
"asin": child.find("asin").text
if child.find("asin") is not None
else None,
# "category": child.find("category").text
# if child.find("category") is not None
# else None,
"category": domain,
# "rating": child.find("rating").text
# if child.find("rating") is not None
# else None,
"rating": rating,
"title": child.find("title").text
if child.find("title") is not None
else None,
"text": child.find("text").text
if child.find("text") is not None
else None,
"summary": child.find("summary").text
if child.find("summary") is not None
else None,
"lang": lang,
}
)
data[lang][domain].update({split: domain_data})
return data
def load_cls():
data = {}
for lang in LANGS:
@@ -24,7 +99,7 @@ def load_cls():
train = (
open(
os.path.join(
CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed"
CLS_UNPROCESSED_DATA_DIR, lang, domain, "train.processed"
),
"r",
)
@@ -34,7 +109,7 @@ def load_cls():
test = (
open(
os.path.join(
CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed"
CLS_UNPROCESSED_DATA_DIR, lang, domain, "test.processed"
),
"r",
)
@@ -59,18 +134,33 @@ def process_data(line):
if __name__ == "__main__":
print(f"datapath: {CLS_PROCESSED_DATA_DIR}")
data = load_cls()
multilingualDataset = MultilingualDataset(dataset_name="cls")
for lang in LANGS:
# TODO: just using book domain atm
Xtr = [text[0] for text in data[lang]["books"]["train"]]
# Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
print(f"datapath: {CLS_UNPROCESSED_DATA_DIR}")
# data = load_cls()
data = load_unprocessed_cls(reduce_target_space=True)
multilingualDataset = MultilingualDataset(dataset_name="webis-cls-unprocessed")
Xte = [text[0] for text in data[lang]["books"]["test"]]
# Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]])
for lang in LANGS:
# Xtr = [text["summary"] for text in data[lang]["books"]["train"]]
Xtr = [text["text"] for text in data[lang]["books"]["train"]]
Xtr += [text["text"] for text in data[lang]["dvd"]["train"]]
Xtr += [text["text"] for text in data[lang]["music"]["train"]]
Ytr = [text["rating"] for text in data[lang]["books"]["train"]]
Ytr += [text["rating"] for text in data[lang]["dvd"]["train"]]
Ytr += [text["rating"] for text in data[lang]["music"]["train"]]
Ytr = np.vstack(Ytr)
Xte = [text["text"] for text in data[lang]["books"]["test"]]
Xte += [text["text"] for text in data[lang]["dvd"]["test"]]
Xte += [text["text"] for text in data[lang]["music"]["test"]]
Yte = [text["rating"] for text in data[lang]["books"]["test"]]
Yte += [text["rating"] for text in data[lang]["dvd"]["test"]]
Yte += [text["rating"] for text in data[lang]["music"]["test"]]
Yte = np.vstack(Yte)
multilingualDataset.add(
lang=lang,
@@ -82,5 +172,7 @@ if __name__ == "__main__":
te_ids=None,
)
multilingualDataset.save(
os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
os.path.expanduser(
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
)
)
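
Note: with reduce_target_space=True, load_unprocessed_cls collapses the original 1-5 star ratings into three classes (negative/neutral/positive). A minimal, self-contained sketch of just that reduction step (the function name is illustrative, not from the diff):

import numpy as np

def reduce_rating(original_rating: int) -> np.ndarray:
    # map 1-5 stars to a one-hot over {negative, neutral, positive},
    # mirroring the branch in load_unprocessed_cls above
    if original_rating < 3:
        new_rating = 1  # negative
    elif original_rating > 3:
        new_rating = 3  # positive
    else:
        new_rating = 2  # neutral
    one_hot = np.zeros(3, dtype=int)
    one_hot[new_rating - 1] = 1
    return one_hot

# reduce_rating(5) -> array([0, 0, 1])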

View File

@@ -62,14 +62,29 @@ class gFunDataset:
)
self.mlb = self.get_label_binarizer(self.labels)
elif "cls" in self.dataset_dir.lower():
print(f"- Loading CLS dataset from {self.dataset_dir}")
# WEBIS-CLS (processed)
elif (
"cls" in self.dataset_dir.lower()
and "unprocessed" not in self.dataset_dir.lower()
):
print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}")
self.dataset_name = "cls"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
# WEBIS-CLS (unprocessed)
elif (
"cls" in self.dataset_dir.lower()
and "unprocessed" in self.dataset_dir.lower()
):
print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}")
self.dataset_name = "cls"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
self.show_dimension()
return
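
Both new branches dispatch on substrings of dataset_dir; the predicate logic, pulled out as an illustrative helper (not code from this diff):

def detect_cls_variant(dataset_dir: str) -> str:
    # same tests as the elif chain in gFunDataset above
    d = dataset_dir.lower()
    if "cls" in d and "unprocessed" in d:
        return "webis-cls-unprocessed"
    if "cls" in d:
        return "webis-cls-processed"
    raise ValueError(f"not a CLS dataset dir: {dataset_dir}")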

View File

@@ -23,6 +23,7 @@ def get_dataset(dataset_name, args):
"rcv1-2",
"glami",
"cls",
"webis",
], "dataset not supported"
RCV_DATAPATH = expanduser(
@@ -37,6 +38,10 @@ def get_dataset(dataset_name, args):
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
WEBIS_CLS = expanduser(
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
)
if dataset_name == "multinews":
# TODO: convert to gFunDataset
raise NotImplementedError
@@ -91,6 +96,15 @@ def get_dataset(dataset_name, args):
is_multilabel=False,
nrows=args.nrows,
)
elif dataset_name == "webis":
dataset = gFunDataset(
dataset_dir=WEBIS_CLS,
is_textual=True,
is_visual=False,
is_multilabel=False,
nrows=args.nrows,
)
else:
raise NotImplementedError
return dataset
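
With the new "webis" branch, the dataset can be requested by name. A minimal usage sketch, assuming get_dataset is imported from this module; on this code path only args.nrows is read:

from types import SimpleNamespace

args = SimpleNamespace(nrows=None)  # stand-in for the parsed CLI args
dataset = get_dataset("webis", args)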

View File

@@ -1,8 +1,9 @@
from joblib import Parallel, delayed
from collections import defaultdict
from evaluation.metrics import *
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score
# from evaluation.metrics import *
import numpy as np
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score, precision_score, recall_score
def evaluation_metrics(y, y_, clf_type):
@@ -13,13 +14,17 @@ def evaluation_metrics(y, y_, clf_type):
# TODO: we need logits top_k_accuracy_score(y, y_, k=10),
f1_score(y, y_, average="macro", zero_division=1),
f1_score(y, y_, average="micro"),
precision_score(y, y_, zero_division=1, average="macro"),
recall_score(y, y_, zero_division=1, average="macro"),
)
elif clf_type == "multilabel":
return (
macroF1(y, y_),
microF1(y, y_),
macroK(y, y_),
microK(y, y_),
f1_score(y, y_, average="macro", zero_division=1),
f1_score(y, y_, average="micro"),
0,
0,
# macroK(y, y_),
# microK(y, y_),
)
else:
raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
@@ -48,8 +53,10 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
if clf_type == "multilabel":
for lang in l_eval.keys():
macrof1, microf1, macrok, microk = l_eval[lang]
metrics.append([macrof1, microf1, macrok, microk])
# macrof1, microf1, macrok, microk = l_eval[lang]
# metrics.append([macrof1, microf1, macrok, microk])
macrof1, microf1, precision, recall = l_eval[lang]
metrics.append([macrof1, microf1, precision, recall])
if phase != "validation":
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
averages = np.mean(np.array(metrics), axis=0)
@@ -69,12 +76,15 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
# "acc10", # "accuracy-at-10",
"MF1", # "macro-F1",
"mF1", # "micro-F1",
"precision",
"recall"
]
for lang in l_eval.keys():
# acc, top5, top10, macrof1, microf1 = l_eval[lang]
acc, macrof1, microf1 = l_eval[lang]
acc, macrof1, microf1, precision, recall = l_eval[lang]
# metrics.append([acc, top5, top10, macrof1, microf1])
metrics.append([acc, macrof1, microf1])
# metrics.append([acc, macrof1, microf1])
metrics.append([acc, macrof1, microf1, precision, recall])
for m, v in zip(_metrics, l_eval[lang]):
lang_metrics[m][lang] = v
@@ -82,7 +92,8 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
if phase != "validation":
print(
# f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
# f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f} pr = {precision:.3f} re = {recall:.3f}"
)
averages = np.mean(np.array(metrics), axis=0)
if verbose:
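
The singlelabel tuple unpacked per language is (acc, macro-F1, micro-F1, precision, recall), now computed directly with scikit-learn. A quick sanity check on dummy labels (not repository data):

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

y_true = np.array([0, 1, 2, 2, 1])
y_pred = np.array([0, 2, 2, 2, 1])
acc = accuracy_score(y_true, y_pred)
macrof1 = f1_score(y_true, y_pred, average="macro", zero_division=1)
microf1 = f1_score(y_true, y_pred, average="micro")
precision = precision_score(y_true, y_pred, zero_division=1, average="macro")
recall = recall_score(y_true, y_pred, zero_division=1, average="macro")
print(f"acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f} pr = {precision:.3f} re = {recall:.3f}")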

View File

@@ -124,6 +124,16 @@ class GeneralizedFunnelling:
epochs=self.epochs,
attn_stacking_type=attn_stacking,
)
self._model_id = get_unique_id(
self.dataset_name,
self.posteriors_vgf,
self.multilingual_vgf,
self.wce_vgf,
self.textual_trf_vgf,
self.visual_trf_vgf,
self.aggfunc,
)
return self
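
get_unique_id itself is not shown in this diff, but the id loaded by run-senti.sh below (webis_pmwt_mean_230621) suggests it encodes dataset name, active VGFs, aggregation function, and date. A hypothetical reconstruction, with the assumptions flagged in comments:

from datetime import datetime

def get_unique_id(dataset_name, posteriors, multilingual, wce, textual_trf, visual_trf, aggfunc):
    # hypothetical: one letter per active VGF ("p", "m", "w", "t", "v"),
    # matching ids like "webis_pmwt_mean_230621"; the real helper may differ
    tag = "".join(
        ch for ch, active in zip("pmwtv", (posteriors, multilingual, wce, textual_trf, visual_trf)) if active
    )
    return f"{dataset_name}_{tag}_{aggfunc}_{datetime.now().strftime('%y%m%d')}"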
if self.posteriors_vgf:
@@ -317,7 +327,7 @@ class GeneralizedFunnelling:
for vgf in self.first_tier_learners:
vgf_config = vgf.get_config()
c.update(vgf_config)
c.update({vgf_config["name"]: vgf_config})
gfun_config = {
"id": self._model_id,
@@ -372,6 +382,7 @@
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained VanillaFun VGF")
if self.multilingual_vgf:
with open(
os.path.join(
@@ -380,6 +391,7 @@
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained Multilingual VGF")
if self.wce_vgf:
with open(
os.path.join(
@@ -388,20 +400,38 @@
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained WCE VGF")
if self.textual_trf_vgf:
with open(
os.path.join(
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
"models",
"vgfs",
"textual_transformer",
f"textualTransformerGen_{model_id}.pkl",
),
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained Textual Transformer VGF")
if self.visual_trf_vgf:
with open(
os.path.join(
"models",
"vgfs",
"visual_transformer",
f"visualTransformerGen_{model_id}.pkl",
),
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print("- loaded trained Visual Transformer VGF")
if load_meta:
with open(
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
) as f:
metaclassifier = pickle.load(f)
print(f"- loaded trained metaclassifier")
else:
metaclassifier = None
return first_tier_learners, metaclassifier, vectorizer
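
The get_config change above (c.update({vgf_config["name"]: vgf_config}) instead of c.update(vgf_config)) nests each VGF's config under its own name, so configs with overlapping keys no longer clobber one another. Illustratively, using the names the VGFs now report:

a = {"name": "Vanilla Funnelling VGF", "n_jobs": 4}
b = {"name": "visual-transformer VGF", "n_jobs": 1}

flat = {}
flat.update(a)
flat.update(b)  # "n_jobs" from a is silently overwritten

nested = {}
nested.update({a["name"]: a})
nested.update({b["name"]: b})  # both configs survive, keyed by name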

View File

@@ -45,11 +45,12 @@ class MT5ForSequenceClassification(nn.Module):
def save_pretrained(self, checkpoint_dir):
torch.save(self.state_dict(), checkpoint_dir + ".pt")
return
return self
def from_pretrained(self, checkpoint_dir):
checkpoint_dir += ".pt"
return self.load_state_dict(torch.load(checkpoint_dir))
self.load_state_dict(torch.load(checkpoint_dir))
return self
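
Returning self from save_pretrained and from_pretrained makes the two calls chainable. A tiny stand-in module demonstrating the same contract (TinyClf is illustrative, not the mT5 wrapper):

import torch
import torch.nn as nn

class TinyClf(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def save_pretrained(self, checkpoint_dir):
        torch.save(self.state_dict(), checkpoint_dir + ".pt")
        return self

    def from_pretrained(self, checkpoint_dir):
        self.load_state_dict(torch.load(checkpoint_dir + ".pt"))
        return self

restored = TinyClf().save_pretrained("/tmp/tiny").from_pretrained("/tmp/tiny")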
class TextualTransformerGen(ViewGen, TransformerGen):
@@ -113,6 +114,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
model_name, num_labels=num_labels, output_hidden_states=True
)
else:
model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-traiend mbert
return AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels, output_hidden_states=True
)
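
The hardcoded path is presumably the output of hf_trainer.py below (the checkpoint step count will vary per run). Under that assumption, verifying that such a checkpoint loads standalone:

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500",  # path hardcoded above
    num_labels=3,  # matches the reduced rating space used in hf_trainer.py
    output_hidden_states=True,
)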
@@ -144,58 +146,60 @@
self.model_name, num_labels=self.num_labels
)
tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
lX, lY, split=0.2, seed=42, modality="text"
)
self.model.to("cuda")
tra_dataloader = self.build_dataloader(
tr_lX,
tr_lY,
processor_fn=self._tokenize,
torchDataset=MultilingualDatasetTorch,
batch_size=self.batch_size,
split="train",
shuffle=True,
)
val_dataloader = self.build_dataloader(
val_lX,
val_lY,
processor_fn=self._tokenize,
torchDataset=MultilingualDatasetTorch,
batch_size=self.batch_size_eval,
split="val",
shuffle=False,
)
experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
trainer = Trainer(
model=self.model,
optimizer_name="adamW",
lr=self.lr,
device=self.device,
loss_fn=torch.nn.CrossEntropyLoss(),
print_steps=self.print_steps,
evaluate_step=self.evaluate_step,
patience=self.patience,
experiment_name=experiment_name,
checkpoint_path=os.path.join(
"models",
"vgfs",
"transformer",
self._format_model_name(self.model_name),
),
vgf_name="textual_trf",
classification_type=self.clf_type,
n_jobs=self.n_jobs,
scheduler_name=self.scheduler,
)
trainer.train(
train_dataloader=tra_dataloader,
eval_dataloader=val_dataloader,
epochs=self.epochs,
)
# tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
# lX, lY, split=0.2, seed=42, modality="text"
# )
#
# tra_dataloader = self.build_dataloader(
# tr_lX,
# tr_lY,
# processor_fn=self._tokenize,
# torchDataset=MultilingualDatasetTorch,
# batch_size=self.batch_size,
# split="train",
# shuffle=True,
# )
#
# val_dataloader = self.build_dataloader(
# val_lX,
# val_lY,
# processor_fn=self._tokenize,
# torchDataset=MultilingualDatasetTorch,
# batch_size=self.batch_size_eval,
# split="val",
# shuffle=False,
# )
#
# experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
#
# trainer = Trainer(
# model=self.model,
# optimizer_name="adamW",
# lr=self.lr,
# device=self.device,
# loss_fn=torch.nn.CrossEntropyLoss(),
# print_steps=self.print_steps,
# evaluate_step=self.evaluate_step,
# patience=self.patience,
# experiment_name=experiment_name,
# checkpoint_path=os.path.join(
# "models",
# "vgfs",
# "trained_transformer",
# self._format_model_name(self.model_name),
# ),
# vgf_name="textual_trf",
# classification_type=self.clf_type,
# n_jobs=self.n_jobs,
# scheduler_name=self.scheduler,
# )
# trainer.train(
# train_dataloader=tra_dataloader,
# eval_dataloader=val_dataloader,
# epochs=self.epochs,
# )
if self.probabilistic:
self.feature2posterior_projector.fit(self.transform(lX), lY)
@@ -224,7 +228,6 @@ class TextualTransformerGen(ViewGen, TransformerGen):
with torch.no_grad():
for input_ids, lang in dataloader:
input_ids = input_ids.to(self.device)
# TODO: check this
if isinstance(self.model, MT5ForSequenceClassification):
batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
else:
@@ -277,4 +280,4 @@ class TextualTransformerGen(ViewGen, TransformerGen):
def get_config(self):
c = super().get_config()
return {"textual_trf": c}
return {"name": "textual-trasnformer VGF", "textual_trf": c}

View File

@@ -65,3 +65,6 @@ class VanillaFunGen(ViewGen):
with open(_path, "wb") as f:
pickle.dump(self, f)
return self
def get_config(self):
return {"name": "Vanilla Funnelling VGF"}

View File

@@ -186,4 +186,4 @@ class VisualTransformerGen(ViewGen, TransformerGen):
return self
def get_config(self):
return {"visual_trf": super().get_config()}
return {"name": "visual-transformer VGF", "visual_trf": super().get_config()}

191
hf_trainer.py Normal file
View File

@@ -0,0 +1,191 @@
import torch
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
TrainingArguments,
)
# from gfun.vgfs.commons import Trainer  # unused: shadowed by the HF Trainer imported below
from datasets import load_dataset, DatasetDict
from transformers import Trainer
import transformers
import evaluate
transformers.logging.set_verbosity_error()
def init_callbacks(patience=-1, nosave=False):
callbacks = []
if patience != -1 and not nosave:
callbacks.append(transformers.EarlyStoppingCallback(early_stopping_patience=patience))
return callbacks
def init_model(model_name):
if model_name == "mbert":
hf_name = "bert-base-multilingual-cased"
elif model_name == "xlm-roberta":
hf_name = "xlm-roberta-base"
else:
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained(hf_name)
model = AutoModelForSequenceClassification.from_pretrained(hf_name, num_labels=3)
return tokenizer, model
def main(args):
tokenizer, model = init_model(args.model)
data = load_dataset(
"json",
data_files={
"train": "local_datasets/webis-cls/all-domains/train.json",
"test": "local_datasets/webis-cls/all-domains/test.json",
},
)
def process_sample(sample):
inputs = sample["text"]
ratings = [r - 1 for r in sample["rating"]]  # shift 1-based star ratings to 0-based class indices
targets = torch.zeros((len(inputs), 3), dtype=float)
# NB: built per batch from a set, so language ids are not stable across batches
lang_mapper = {
lang: lang_id for lang_id, lang in enumerate(set(sample["lang"]))
}
lang_ids = [lang_mapper[l] for l in sample["lang"]]
for i, r in enumerate(ratings):
targets[i][r] = 1  # ratings are already 0-based; subtracting 1 again would mis-index
model_inputs = tokenizer(inputs, max_length=512, truncation=True)
model_inputs["labels"] = targets
model_inputs["lang_ids"] = torch.tensor(lang_ids)
return model_inputs
data = data.map(
process_sample,
batched=True,
num_proc=4,
load_from_cache_file=True,
remove_columns=["text", "category", "rating", "summary", "title"],
)
train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
data.set_format("torch")
data = DatasetDict(
{
"train": train_val_splits["train"],
"validation": train_val_splits["test"],
"test": data["test"],
}
)
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
callbacks = init_callbacks(args.patience, args.nosave)
f1_metric = evaluate.load("f1")
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
training_args = TrainingArguments(
output_dir=f"{args.model}-sentiment",
do_train=True,
evaluation_strategy="steps",
per_device_train_batch_size=args.batch,
per_device_eval_batch_size=args.batch,
gradient_accumulation_steps=args.gradacc,
eval_accumulation_steps=10,
learning_rate=args.lr,
weight_decay=0.1,
max_grad_norm=5.0,
num_train_epochs=args.epochs,
lr_scheduler_type=args.scheduler,
warmup_steps=1000,
logging_strategy="steps",
logging_first_step=True,
logging_steps=args.steplog,
seed=42,
fp16=args.fp16,
load_best_model_at_end=False if args.nosave else True,
save_strategy="no" if args.nosave else "steps",
save_total_limit=3,
eval_steps=args.stepeval,
run_name=f"{args.model}-sentiment-run",
disable_tqdm=False,
log_level="warning",
report_to=["wandb"] if args.wandb else "none",
optim="adamw_torch",
)
def compute_metrics(eval_preds):
preds = eval_preds.predictions.argmax(-1)
targets = eval_preds.label_ids.argmax(-1)
setting = "macro"
f1_score_macro = f1_metric.compute(
predictions=preds, references=targets, average="macro"
)
f1_score_micro = f1_metric.compute(
predictions=preds, references=targets, average="micro"
)
accuracy_score = accuracy_metric.compute(predictions=preds, references=targets)
precision_score = precision_metric.compute(
predictions=preds, references=targets, average=setting, zero_division=1
)
recall_score = recall_metric.compute(
predictions=preds, references=targets, average=setting, zero_division=1
)
results = {
"macro_f1score": f1_score_macro["f1"],
"micro_f1score": f1_score_micro["f1"],
"accuracy": accuracy_score["accuracy"],
"precision": precision_score["precision"],
"recall": recall_score["recall"],
}
results = {k: round(v, 4) for k, v in results.items()}
return results
if args.wandb:
import wandb
wandb.init(entity="andreapdr", project=f"gfun-senti-hf", name="mbert-sent", config=vars(args))
trainer = Trainer(
model=model,
args=training_args,
train_dataset=data["train"],
eval_dataset=data["validation"],
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
)
print("- Training:")
trainer.train()
print("- Testing:")
test_results = trainer.evaluate(eval_dataset=data["test"])
print(test_results)
exit()
if __name__ == "__main__":
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, metavar="", default="mbert")
parser.add_argument("--lr", type=float, metavar="", default=1e-5, help="Set learning rate",)
parser.add_argument("--scheduler", type=str, metavar="", default="linear", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
parser.add_argument("--batch", type=int, metavar="", default=16, help="Set batch size")
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
parser.add_argument("--steplog", type=int, metavar="", default=100, help="Log training every n steps")
parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
parser.add_argument("--wandb", action="store_true", help="Log to wandb")
parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
# parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
# parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
args = parser.parse_args()
main(args)
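
A minimal programmatic invocation covering every flag main reads; values mirror the argparse defaults above (nosave=True avoids writing checkpoints):

from argparse import Namespace

main(Namespace(
    model="mbert", lr=1e-5, scheduler="linear", batch=16, gradacc=1,
    epochs=100, stepeval=50, steplog=100, patience=10,
    fp16=False, wandb=False, nosave=True,
))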

13
main.py
View File

@@ -1,8 +1,5 @@
import os
import wandb
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from argparse import ArgumentParser
from time import time
@@ -12,16 +9,9 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- Transformers VGFs:
- scheduler with warmup and cosine
- freeze params method
- General:
[!] zero-shot setup
- CLS dataset is loading only "books" domain data
- documents should be trimmed to the same length (for SVMs we are using way too long tokens)
- Attention Aggregator:
- experiment with weight init of Attention-aggregator
- FFNN posterior-probabilities' dependent
- Docs:
- add sphinx documentation
"""
@@ -150,7 +140,6 @@ def main(args):
wandb.log(gfun_res)
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
log_barplot_wandb(avg_metrics_gfun, title_affix="averages")
if __name__ == "__main__":
@@ -178,7 +167,7 @@ if __name__ == "__main__":
parser.add_argument("--features", action="store_false")
parser.add_argument("--aggfunc", type=str, default="mean")
# transformer parameters ---------------
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--textual_trf_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=128)

28
run-senti.sh Normal file
View File

@@ -0,0 +1,28 @@
#!/bin/bash
config="-m"
echo "[Running gFun config: $config]"
epochs=100
njobs=-1
clf=singlelabel
patience=5
eval_every=5
text_len=256
text_lr=1e-4
bsize=64
txt_model=mbert
python main.py $config \
-d webis \
--epochs $epochs \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model \
--load_trained webis_pmwt_mean_230621