# Recovered from patch e3e6f061d8545d549d59e655552b15fc53de5f06
# (andreapdr, Thu, 22 Jun 2023 11:33:06 +0200): "transformer-trainer via
# huggingface" -- creates hf_trainer.py.
"""Fine-tune a multilingual transformer on Webis-CLS ratings with HF Trainer.

The script loads the train/test JSON files, tokenizes the ``text`` field,
builds 3-column one-hot targets from ``rating``, carves a validation split
out of the training data, trains with ``transformers.Trainer`` and finally
evaluates on the test split.
"""
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    TrainingArguments,
)
# NOTE(review): this name is immediately shadowed by the transformers
# ``Trainer`` imported below, so the gfun Trainer is never used here.
from gfun.vgfs.commons import Trainer
from datasets import load_dataset, DatasetDict

from transformers import Trainer

import transformers
import evaluate

transformers.logging.set_verbosity_error()

# Number of target columns / classifier outputs used throughout the script.
NUM_LABELS = 3

# Short CLI model names -> Hugging Face hub checkpoints.
_HF_CHECKPOINTS = {
    "mbert": "bert-base-multilingual-cased",
    "xlm-roberta": "xlm-roberta-base",
}


def init_callbacks(patience=-1, nosave=False):
    """Build the list of Trainer callbacks.

    Args:
        patience: early-stopping patience; ``-1`` disables early stopping.
        nosave: when True no early-stopping callback is added. ``main`` turns
            off ``load_best_model_at_end`` in that case, which
            ``EarlyStoppingCallback`` depends on.

    Returns:
        A list that is empty or holds one ``EarlyStoppingCallback``.
    """
    if patience == -1 or nosave:
        return []
    return [transformers.EarlyStoppingCallback(early_stopping_patience=patience)]


def init_model(model_name, num_labels=NUM_LABELS):
    """Load the tokenizer and sequence-classification model for *model_name*.

    Args:
        model_name: ``"mbert"`` or ``"xlm-roberta"``.
        num_labels: size of the classification head (defaults to 3).

    Returns:
        ``(tokenizer, model)``.

    Raises:
        NotImplementedError: if *model_name* is not a supported short name.
    """
    try:
        hf_name = _HF_CHECKPOINTS[model_name]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported model {model_name!r}; "
            f"expected one of {sorted(_HF_CHECKPOINTS)}"
        ) from None
    tokenizer = AutoTokenizer.from_pretrained(hf_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        hf_name, num_labels=num_labels
    )
    return tokenizer, model


def main(args):
    """Train on the train split, then evaluate on the test split.

    Args:
        args: parsed CLI namespace; reads ``model``, ``patience``, ``nosave``,
            ``batch``, ``gradacc``, ``lr``, ``epochs``, ``scheduler``,
            ``steplog``, ``stepeval``, ``fp16`` and ``wandb``.

    Returns:
        The metrics dict returned by ``trainer.evaluate`` on the test split.
    """
    tokenizer, model = init_model(args.model)

    data = load_dataset(
        "json",
        data_files={
            "train": "local_datasets/webis-cls/all-domains/train.json",
            "test": "local_datasets/webis-cls/all-domains/test.json",
        },
    )

    def process_sample(sample):
        """Tokenize one batch and attach one-hot labels and language ids.

        ``sample`` is a batch (``map`` is called with ``batched=True``), so
        ``sample["text"]``, ``sample["rating"]`` and ``sample["lang"]`` are
        parallel lists.
        """
        inputs = sample["text"]
        ratings = [r - 1 for r in sample["rating"]]
        # Python ``float`` maps to torch.float64.
        targets = torch.zeros((len(inputs), NUM_LABELS), dtype=float)
        # Sorted so the language -> id mapping does not depend on set
        # iteration order (which varies between runs and worker processes).
        # NOTE(review): ids are still assigned per batch, so a batch missing
        # a language numbers the others differently -- confirm whether a
        # global mapping is needed before relying on ``lang_ids``.
        lang_mapper = {
            lang: lang_id
            for lang_id, lang in enumerate(sorted(set(sample["lang"])))
        }
        lang_ids = [lang_mapper[lang] for lang in sample["lang"]]
        for i, r in enumerate(ratings):
            # NOTE(review): ``rating`` is shifted by one twice (above and
            # here), so the hot column is ``rating - 2`` and negative indices
            # wrap to the last column. Presumably a single shift was
            # intended -- confirm against the dataset's rating range before
            # changing it, since that permutes the label columns.
            targets[i][r - 1] = 1

        model_inputs = tokenizer(inputs, max_length=512, truncation=True)
        model_inputs["labels"] = targets
        model_inputs["lang_ids"] = torch.tensor(lang_ids)
        return model_inputs

    data = data.map(
        process_sample,
        batched=True,
        num_proc=4,
        load_from_cache_file=True,
        remove_columns=["text", "category", "rating", "summary", "title"],
    )
    train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
    # NOTE(review): the format is set after the split was taken, so it looks
    # like only the test split ends up torch-formatted -- confirm whether
    # train/validation were meant to be covered too.
    data.set_format("torch")
    data = DatasetDict(
        {
            "train": train_val_splits["train"],
            "validation": train_val_splits["test"],
            "test": data["test"],
        }
    )

    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    callbacks = init_callbacks(args.patience, args.nosave)

    f1_metric = evaluate.load("f1")
    accuracy_metric = evaluate.load("accuracy")
    precision_metric = evaluate.load("precision")
    recall_metric = evaluate.load("recall")

    training_args = TrainingArguments(
        output_dir=f"{args.model}-sentiment",
        do_train=True,
        evaluation_strategy="steps",
        per_device_train_batch_size=args.batch,
        per_device_eval_batch_size=args.batch,
        gradient_accumulation_steps=args.gradacc,
        eval_accumulation_steps=10,
        learning_rate=args.lr,
        weight_decay=0.1,
        max_grad_norm=5.0,
        num_train_epochs=args.epochs,
        lr_scheduler_type=args.scheduler,
        warmup_steps=1000,
        logging_strategy="steps",
        logging_first_step=True,
        logging_steps=args.steplog,
        seed=42,
        fp16=args.fp16,
        # Nothing is checkpointed under --nosave, so there is no "best"
        # model to reload at the end.
        load_best_model_at_end=not args.nosave,
        save_strategy="no" if args.nosave else "steps",
        save_total_limit=3,
        eval_steps=args.stepeval,
        run_name=f"{args.model}-sentiment-run",
        disable_tqdm=False,
        log_level="warning",
        report_to=["wandb"] if args.wandb else "none",
        optim="adamw_torch",
    )

    def compute_metrics(eval_preds):
        """Return macro/micro F1, accuracy, macro precision and recall.

        Predictions and one-hot labels are both reduced with ``argmax`` over
        the last axis; every score is rounded to 4 decimals.
        """
        preds = eval_preds.predictions.argmax(-1)
        targets = eval_preds.label_ids.argmax(-1)
        setting = "macro"
        f1_score_macro = f1_metric.compute(
            predictions=preds, references=targets, average="macro"
        )
        f1_score_micro = f1_metric.compute(
            predictions=preds, references=targets, average="micro"
        )
        accuracy_score = accuracy_metric.compute(
            predictions=preds, references=targets
        )
        precision_score = precision_metric.compute(
            predictions=preds, references=targets, average=setting, zero_division=1
        )
        recall_score = recall_metric.compute(
            predictions=preds, references=targets, average=setting, zero_division=1
        )
        results = {
            "macro_f1score": f1_score_macro["f1"],
            "micro_f1score": f1_score_micro["f1"],
            "accuracy": accuracy_score["accuracy"],
            "precision": precision_score["precision"],
            "recall": recall_score["recall"],
        }
        return {k: round(v, 4) for k, v in results.items()}

    if args.wandb:
        import wandb

        # The run name follows the selected model instead of always being
        # "mbert-sent" (unchanged for the default --model mbert).
        wandb.init(
            entity="andreapdr",
            project="gfun-senti-hf",
            name=f"{args.model}-sent",
            config=vars(args),
        )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=data["train"],
        eval_dataset=data["validation"],
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
    )

    print("- Training:")
    trainer.train()

    print("- Testing:")
    test_results = trainer.evaluate(eval_dataset=data["test"])
    print(test_results)

    # Return instead of calling the site-provided ``exit()`` so that callers
    # importing ``main`` are not terminated; run as a script, the process
    # still ends right after this call.
    return test_results


if __name__ == "__main__":
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, metavar="", default="mbert")
    parser.add_argument("--lr", type=float, metavar="", default=1e-5, help="Set learning rate")
    parser.add_argument(
        "--scheduler",
        type=str,
        metavar="",
        default="linear",
        # Passed straight to TrainingArguments(lr_scheduler_type=...), so the
        # help lists values that argument accepts.
        help=(
            'Hugging Face lr_scheduler_type, e.g. ["linear", "cosine", '
            '"cosine_with_restarts", "polynomial", "constant", '
            '"constant_with_warmup"]'
        ),
    )
    parser.add_argument("--batch", type=int, metavar="", default=16, help="Set batch size")
    parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
    parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
    parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
    parser.add_argument("--steplog", type=int, metavar="", default=100, help="Log training every n steps")
    parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
    parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
    parser.add_argument("--wandb", action="store_true", help="Log to wandb")
    parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
    # parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
    # parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
    args = parser.parse_args()
    main(args)