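"""Fine-tune a multilingual Transformer (mBERT or XLM-RoBERTa) for sequence
classification on the RAI dataset (CSV splits with id, provider, date, title,
text, label and lang columns). Training can use all languages (few-shot) or a
single language selected via --trainlang (zero-shot). Evaluation reports
macro/micro F1, accuracy, precision and recall, and per-example test
predictions are written to results/*.csv."""
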
from os.path import expanduser, join

import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset, DatasetDict

from pprint import pprint

import transformers
import evaluate
import pandas as pd

transformers.logging.set_verbosity_error()

RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"]
MAX_LEN = 128


def init_callbacks(patience=-1, nosave=False):
    # Early stopping is only registered when a patience is given and checkpoints
    # are being saved (the callback relies on load_best_model_at_end).
    callbacks = []
    if patience != -1 and not nosave:
        callbacks.append(transformers.EarlyStoppingCallback(early_stopping_patience=patience))
    return callbacks


def init_model(model_name, nlabels, saved_model=None):
    if model_name == "mbert":
        if saved_model is None:
            hf_name = "bert-base-multilingual-cased"
        else:
            hf_name = saved_model
    elif model_name == "xlm-roberta":
        if saved_model is None:
            hf_name = "xlm-roberta-base"
        else:
            hf_name = saved_model
    else:
        raise NotImplementedError
    tokenizer = AutoTokenizer.from_pretrained(hf_name)
    model = AutoModelForSequenceClassification.from_pretrained(hf_name, num_labels=nlabels)
    return tokenizer, model


def main(args):
    saved_model = args.savedmodel
    trainlang = args.trainlang
    datapath = args.datapath

    tokenizer, model = init_model(args.model, args.nlabels, saved_model=saved_model)

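    # Load the CSV splits from --datapath; the loader expects a train.csv and a
    # test.small.csv file whose columns include RAI_D_COLUMNS plus a "lang" field.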
    data = load_dataset(
        "csv",
        data_files={
            "train": expanduser(join(datapath, "train.csv")),
            "test": expanduser(join(datapath, "test.small.csv")),
        },
    )

    def filter_dataset(dataset, lang):
        indices = [i for i, l in enumerate(dataset["lang"]) if l == lang]
        dataset = dataset.select(indices)
        return dataset

    if trainlang is not None:
        data["train"] = filter_dataset(data["train"], lang=trainlang)

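    # Tokenization: concatenate title and text as "<title>. <text>", truncate to
    # MAX_LEN word-pieces, and carry the integer label through as "labels".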
    def process_sample_rai(sample):
        inputs = [f"{title}. {text}" for title, text in zip(sample["title"], sample["text"])]
        labels = sample["label"]
        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)
        model_inputs["labels"] = labels
        return model_inputs

    data = data.map(
        process_sample_rai,
        batched=True,
        num_proc=4,
        load_from_cache_file=True,
        remove_columns=RAI_D_COLUMNS,
    )
    train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
    data = DatasetDict(
        {
            "train": train_val_splits["train"],
            "validation": train_val_splits["test"],
            "test": data["test"],
        }
    )
    data.set_format("torch")

    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    callbacks = init_callbacks(args.patience, args.nosave)

    f1_metric = evaluate.load("f1")
    accuracy_metric = evaluate.load("accuracy")
    precision_metric = evaluate.load("precision")
    recall_metric = evaluate.load("recall")

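    # Step-based evaluation and checkpointing; with --nosave no checkpoints are
    # written and the best model is not reloaded at the end. Output dirs and run
    # names are suffixed "zeroshot" vs "fewshot" depending on --trainlang.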
    training_args = TrainingArguments(
        output_dir=f"hf_models/{args.model}-fewshot-full" if trainlang is None else f"hf_models/{args.model}-zeroshot-full",
        run_name="model-zeroshot" if trainlang is not None else "model-fewshot",
        do_train=True,
        evaluation_strategy="steps",
        per_device_train_batch_size=args.batch,
        per_device_eval_batch_size=args.batch,
        gradient_accumulation_steps=args.gradacc,
        eval_accumulation_steps=10,
        learning_rate=args.lr,
        weight_decay=0.1,
        max_grad_norm=5.0,
        num_train_epochs=args.epochs,
        lr_scheduler_type=args.scheduler,
        warmup_ratio=0.01,
        logging_strategy="steps",
        logging_first_step=True,
        logging_steps=args.steplog,
        seed=42,
        fp16=args.fp16,
        load_best_model_at_end=not args.nosave,
        save_strategy="no" if args.nosave else "steps",
        save_total_limit=2,
        eval_steps=args.stepeval,
        disable_tqdm=False,
        log_level="warning",
        report_to=["wandb"] if args.wandb else "none",
        optim="adamw_torch",
        save_steps=args.stepeval,
    )

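    # Metrics are computed on the argmax of the logits (single-label classification);
    # zero_division=1 scores classes with no predictions or instances as 1 instead of
    # emitting a warning.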
    def compute_metrics(eval_preds):
        preds = eval_preds.predictions.argmax(-1)
        targets = eval_preds.label_ids
        setting = "macro"
        f1_score_macro = f1_metric.compute(
            predictions=preds, references=targets, average="macro"
        )
        f1_score_micro = f1_metric.compute(
            predictions=preds, references=targets, average="micro"
        )
        accuracy_score = accuracy_metric.compute(predictions=preds, references=targets)
        precision_score = precision_metric.compute(
            predictions=preds, references=targets, average=setting, zero_division=1
        )
        recall_score = recall_metric.compute(
            predictions=preds, references=targets, average=setting, zero_division=1
        )
        results = {
            "macro_f1score": f1_score_macro["f1"],
            "micro_f1score": f1_score_micro["f1"],
            "accuracy": accuracy_score["accuracy"],
            "precision": precision_score["precision"],
            "recall": recall_score["recall"],
        }
        results = {k: round(v, 4) for k, v in results.items()}
        return results

    if args.wandb:
        import wandb

        wandb.init(
            entity="andreapdr",
            project="gfun",
            name="model-zeroshot-full" if trainlang is not None else "model-fewshot-full",
            config=vars(args),
        )

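    # The Hugging Face Trainer drives the training loop, periodic evaluation on the
    # validation split, dynamic padding via the data collator, and early stopping.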
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=data["train"],
        eval_dataset=data["validation"],
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
    )

    if not args.onlytest:
        print("- Training:")
        trainer.train()

    print("- Testing:")
    test_results = trainer.predict(test_dataset=data["test"], metric_key_prefix="test")
    pprint(test_results.metrics)
    save_preds(data["test"], test_results.predictions, trainlang)
    exit()


def save_preds(dataset, predictions, trainlang=None):
    df = pd.DataFrame()
    df["langs"] = dataset["lang"]
    df["labels"] = dataset["labels"]
    df["preds"] = predictions.argmax(axis=1)
    if trainlang is not None:
        df.to_csv(f"results/zeroshot.{trainlang}.model.csv", index=False)
    else:
        df.to_csv("results/fewshot.model.csv", index=False)
    return


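# Example invocations (the script name below is illustrative; use the actual file name):
#   python train_classifier.py --model mbert --datapath data --epochs 10 --wandb
#   python train_classifier.py --model xlm-roberta --trainlang it --nosave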
if __name__ == "__main__":
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, metavar="", default="mbert", help="Model to fine-tune: either 'mbert' or 'xlm-roberta'")
    parser.add_argument("--nlabels", type=int, metavar="", default=28, help="Number of target labels")
    parser.add_argument("--lr", type=float, metavar="", default=5e-5, help="Set learning rate")
    parser.add_argument("--scheduler", type=str, metavar="", default="cosine", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
    parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
    parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
    parser.add_argument("--epochs", type=int, metavar="", default=10, help="Set epochs")
    parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
    parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps")
    parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
    parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
    parser.add_argument("--wandb", action="store_true", help="Log to wandb")
    parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
    parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
    parser.add_argument("--trainlang", default=None, type=str, help="Set training language for zero-shot experiments")
    parser.add_argument("--datapath", type=str, default="data", help="Path to the CSV dataset. The directory should contain both a train.csv and a test.small.csv file")
    parser.add_argument("--savedmodel", type=str, default="hf_models/mbert-rai-fewshot-second/checkpoint-9000", help="Checkpoint to load instead of the base pretrained model")
    args = parser.parse_args()
    main(args)