185 lines
6.7 KiB
Python
185 lines
6.7 KiB
Python
import wandb
|
|
|
|
from argparse import ArgumentParser
|
|
from time import time
|
|
|
|
from dataManager.utils import get_dataset
|
|
from evaluation.evaluate import evaluate, log_eval
|
|
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
|
|
|
"""
|
|
TODO:
|
|
- General:
|
|
[!] zero-shot setup
|
|
- CLS dataset is loading only "books" domain data
|
|
- Docs:
|
|
- add Sphinx documentation
|
|
"""
|
|
|
|
|
|
def get_config_name(args):
    """Return a "+"-separated tag string naming the enabled VGFs.

    Tags are emitted in a fixed order: ``P`` (posteriors), ``W`` (wce),
    ``M`` (multilingual), ``TT_<textual_trf_name>`` and
    ``VT_<visual_trf_name>``. An empty string is returned when no flag is
    set.

    Args:
        args: Namespace exposing the boolean flags ``posteriors``, ``wce``,
            ``multilingual``, ``textual_transformer`` and
            ``visual_transformer``; ``textual_trf_name`` /
            ``visual_trf_name`` are read only when the matching flag is set.
    """
    tags = [
        tag
        for enabled, tag in (
            (args.posteriors, "P"),
            (args.wce, "W"),
            (args.multilingual, "M"),
        )
        if enabled
    ]

    # The transformer tags embed the model name, so they are formatted only
    # when the corresponding flag is set.
    if args.textual_transformer:
        tags.append(f"TT_{args.textual_trf_name}")
    if args.visual_transformer:
        tags.append(f"VT_{args.visual_trf_name}")

    # Trailing "+" characters are stripped from the final string, including
    # any contributed by the last model name.
    return "+".join(tags).rstrip("+")
|
|
|
|
|
|
def main(args):
    """Fit a gFun model, evaluate it on the test split and log results to wandb.

    Steps, in order: load the dataset, build ``GeneralizedFunnelling`` from
    the CLI options, start a wandb run, call ``fit`` on the training split,
    optionally save the model, ``transform`` the test split, evaluate the
    predictions and log per-language bar plots.

    Args:
        args: Parsed command-line namespace (see the parser built in the
            ``__main__`` section of this file for the attributes it carries).

    Raises:
        AssertionError: If no pre-trained model is requested
            (``args.load_trained is None``) and no VGF flag is enabled.
    """
    dataset = get_dataset(args.dataset, args)
    lX, lY = dataset.training()
    lX_te, lY_te = dataset.test()

    # The timer starts after dataset loading, so the reported training time
    # covers model construction, wandb setup, fitting and saving.
    tinit = time()

    if args.load_trained is None:
        # Training from scratch needs at least one view-generating function.
        assert any(
            (
                args.posteriors,
                args.wce,
                args.multilingual,
                args.textual_transformer,
                args.visual_transformer,
            )
        ), "At least one of VGF must be True"

    gfun = GeneralizedFunnelling(
        # dataset params ----------------------
        dataset_name=args.dataset,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        classification_type=args.clf_type,
        # Posterior VGF params ----------------
        posterior=args.posteriors,
        # Multilingual VGF params -------------
        multilingual=args.multilingual,
        embed_dir="~/resources/muse_embeddings",
        # WCE VGF params ----------------------
        wce=args.wce,
        # Transformer VGF params --------------
        textual_transformer=args.textual_transformer,
        textual_transformer_name=args.textual_trf_name,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        epochs=args.epochs,
        textual_lr=args.textual_lr,
        visual_lr=args.visual_lr,
        max_length=args.max_length,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device=args.device,
        # Visual Transformer VGF params --------------
        # NOTE(review): batch_size, epochs, patience, evaluate_step and device
        # are passed only once above; presumably the visual transformer
        # reuses them - confirm in GeneralizedFunnelling.
        visual_transformer=args.visual_transformer,
        visual_transformer_name=args.visual_trf_name,
        # General params ---------------------
        probabilistic=args.features,
        aggfunc=args.aggfunc,
        optimc=args.optimc,
        load_trained=args.load_trained,
        load_meta=args.meta,
        n_jobs=args.n_jobs,
    )

    config = gfun.get_config()

    wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)

    gfun.fit(lX, lY)

    # Persist only freshly trained models, unless saving is disabled.
    if args.load_trained is None and not args.nosave:
        gfun.save(save_first_tier=True, save_meta=True)

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    gfun_preds = gfun.transform(lX_te)
    test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
    # Only the per-language metrics are plotted below; the averaged metrics
    # are returned by log_eval but not used further in this function.
    avg_metrics_gfun, lang_metrics_gfun = log_eval(
        test_eval, phase="test", clf_type=args.clf_type
    )

    timeval = time()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")

    def log_barplot_wandb(gfun_res, title_affix="per language"):
        """Log ``gfun_res`` to wandb as bar plots, then as raw values.

        With ``title_affix == "per language"`` the input is iterated as
        ``metric -> {lang: value}`` and one bar plot is logged per metric;
        otherwise it is iterated as ``metric -> value`` and a single bar
        plot is logged.
        """
        if title_affix == "per language":
            for metric, lang_values in gfun_res.items():
                data = [[lang, v] for lang, v in lang_values.items()]
                table = wandb.Table(data=data, columns=["lang", f"{metric}"])
                wandb.log(
                    {
                        f"gFun/language {metric}": wandb.plot.bar(
                            table, "lang", metric, title=f"{metric} {title_affix}"
                        )
                    }
                )
        else:
            data = [[metric, value] for metric, value in gfun_res.items()]
            table = wandb.Table(data=data, columns=["metric", "value"])
            wandb.log(
                {
                    "gFun/average metric": wandb.plot.bar(
                        table, "metric", "value", title=f"metric {title_affix}"
                    )
                }
            )
        # NOTE(review): the raw results are logged on both paths here;
        # confirm this matches the intended placement of this call.
        wandb.log(gfun_res)

    log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: declare every option in a single ordered
    # table, register them on the parser in that order, then run main().
    parser = ArgumentParser()

    cli_options = [
        (("-l", "--load_trained"), {"type": str, "default": None}),
        (("--meta",), {"action": "store_true"}),
        (("--nosave",), {"action": "store_true"}),
        (("--device",), {"type": str, "default": "cuda"}),
        # Dataset parameters -------------------
        (("-d", "--dataset"), {"type": str, "default": "rcv1-2"}),
        (("--domains",), {"type": str, "default": "all"}),
        (("--nrows",), {"type": int, "default": None}),
        (("--min_count",), {"type": int, "default": 10}),
        (("--max_labels",), {"type": int, "default": 50}),
        (("--clf_type",), {"type": str, "default": "multilabel"}),
        (("--save_dataset",), {"action": "store_true"}),
        # gFUN parameters ----------------------
        (("-p", "--posteriors"), {"action": "store_true"}),
        (("-m", "--multilingual"), {"action": "store_true"}),
        (("-w", "--wce"), {"action": "store_true"}),
        (("-t", "--textual_transformer"), {"action": "store_true"}),
        (("-v", "--visual_transformer"), {"action": "store_true"}),
        (("--n_jobs",), {"type": int, "default": -1}),
        (("--optimc",), {"action": "store_true"}),
        # store_false: args.features is True unless the flag is given.
        (("--features",), {"action": "store_false"}),
        (("--aggfunc",), {"type": str, "default": "mean"}),
        # transformer parameters ---------------
        (("--epochs",), {"type": int, "default": 5}),
        (("--textual_trf_name",), {"type": str, "default": "mbert"}),
        (("--batch_size",), {"type": int, "default": 32}),
        (("--eval_batch_size",), {"type": int, "default": 128}),
        (("--textual_lr",), {"type": float, "default": 1e-4}),
        (("--max_length",), {"type": int, "default": 128}),
        (("--patience",), {"type": int, "default": 5}),
        (("--evaluate_step",), {"type": int, "default": 10}),
        # Visual Transformer parameters --------------
        (("--visual_trf_name",), {"type": str, "default": "vit"}),
        (("--visual_lr",), {"type": float, "default": 1e-4}),
    ]
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()

    main(args)
|