transformer-trainer via huggingface

This commit is contained in:
Andrea Pedrotti 2023-06-22 11:33:06 +02:00
parent 60171c1b5e
commit e3e6f061d8
1 changed files with 191 additions and 0 deletions

191
hf_trainer.py Normal file
View File

@ -0,0 +1,191 @@
import torch
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
TrainingArguments,
)
from gfun.vgfs.commons import Trainer
from datasets import load_dataset, DatasetDict
from transformers import Trainer
import transformers
import evaluate
transformers.logging.set_verbosity_error()
def init_callbacks(patience=-1, nosave=False):
callbacks = []
if patience != -1 and not nosave:
callbacks.append(transformers.EarlyStoppingCallback(early_stopping_patience=patience))
return callbacks
def init_model(model_name):
if model_name == "mbert":
hf_name = "bert-base-multilingual-cased"
elif model_name == "xlm-roberta":
hf_name = "xlm-roberta-base"
else:
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained(hf_name)
model = AutoModelForSequenceClassification.from_pretrained(hf_name, num_labels=3)
return tokenizer, model
def main(args):
tokenizer, model = init_model(args.model)
data = load_dataset(
"json",
data_files={
"train": "local_datasets/webis-cls/all-domains/train.json",
"test": "local_datasets/webis-cls/all-domains/test.json",
},
)
def process_sample(sample):
inputs = sample["text"]
ratings = [r - 1 for r in sample["rating"]]
targets = torch.zeros((len(inputs), 3), dtype=float)
lang_mapper = {
lang: lang_id for lang_id, lang in enumerate(set(sample["lang"]))
}
lang_ids = [lang_mapper[l] for l in sample["lang"]]
for i, r in enumerate(ratings):
targets[i][r - 1] = 1
model_inputs = tokenizer(inputs, max_length=512, truncation=True)
model_inputs["labels"] = targets
model_inputs["lang_ids"] = torch.tensor(lang_ids)
return model_inputs
data = data.map(
process_sample,
batched=True,
num_proc=4,
load_from_cache_file=True,
remove_columns=["text", "category", "rating", "summary", "title"],
)
train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
data.set_format("torch")
data = DatasetDict(
{
"train": train_val_splits["train"],
"validation": train_val_splits["test"],
"test": data["test"],
}
)
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
callbacks = init_callbacks(args.patience, args.nosave)
f1_metric = evaluate.load("f1")
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
training_args = TrainingArguments(
output_dir=f"{args.model}-sentiment",
do_train=True,
evaluation_strategy="steps",
per_device_train_batch_size=args.batch,
per_device_eval_batch_size=args.batch,
gradient_accumulation_steps=args.gradacc,
eval_accumulation_steps=10,
learning_rate=args.lr,
weight_decay=0.1,
max_grad_norm=5.0,
num_train_epochs=args.epochs,
lr_scheduler_type=args.scheduler,
warmup_steps=1000,
logging_strategy="steps",
logging_first_step=True,
logging_steps=args.steplog,
seed=42,
fp16=args.fp16,
load_best_model_at_end=False if args.nosave else True,
save_strategy="no" if args.nosave else "steps",
save_total_limit=3,
eval_steps=args.stepeval,
run_name=f"{args.model}-sentiment-run",
disable_tqdm=False,
log_level="warning",
report_to=["wandb"] if args.wandb else "none",
optim="adamw_torch",
)
def compute_metrics(eval_preds):
preds = eval_preds.predictions.argmax(-1)
targets = eval_preds.label_ids.argmax(-1)
setting = "macro"
f1_score_macro = f1_metric.compute(
predictions=preds, references=targets, average="macro"
)
f1_score_micro = f1_metric.compute(
predictions=preds, references=targets, average="micro"
)
accuracy_score = accuracy_metric.compute(predictions=preds, references=targets)
precision_score = precision_metric.compute(
predictions=preds, references=targets, average=setting, zero_division=1
)
recall_score = recall_metric.compute(
predictions=preds, references=targets, average=setting, zero_division=1
)
results = {
"macro_f1score": f1_score_macro["f1"],
"micro_f1score": f1_score_micro["f1"],
"accuracy": accuracy_score["accuracy"],
"precision": precision_score["precision"],
"recall": recall_score["recall"],
}
results = {k: round(v, 4) for k, v in results.items()}
return results
if args.wandb:
import wandb
wandb.init(entity="andreapdr", project=f"gfun-senti-hf", name="mbert-sent", config=vars(args))
trainer = Trainer(
model=model,
args=training_args,
train_dataset=data["train"],
eval_dataset=data["validation"],
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
)
print("- Training:")
trainer.train()
print("- Testing:")
test_results = trainer.evaluate(eval_dataset=data["test"])
print(test_results)
exit()
if __name__ == "__main__":
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, metavar="", default="mbert")
parser.add_argument("--lr", type=float, metavar="", default=1e-5, help="Set learning rate",)
parser.add_argument("--scheduler", type=str, metavar="", default="linear", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
parser.add_argument("--batch", type=int, metavar="", default=16, help="Set batch size")
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
parser.add_argument("--steplog", type=int, metavar="", default=100, help="Log training every n steps")
parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
parser.add_argument("--wandb", action="store_true", help="Log to wandb")
parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
# parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
# parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
args = parser.parse_args()
main(args)