from os.path import expanduser
from pprint import pprint

import evaluate
import pandas as pd
import torch
import transformers
from datasets import DatasetDict, load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

# from gfun.vgfs.commons import Trainer  # removed: was shadowed by transformers.Trainer

transformers.logging.set_verbosity_error()

IWSLT_D_COLUMNS = ["text", "category", "rating", "summary", "title"]
RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"]  # "lang"
WEBIS_D_COLUMNS = [
    "Unnamed: 0",
    "asin",
    "category",
    "original_rating",
    "label",
    "title",
    "text",
    "summary",
]  # "lang"

MAX_LEN = 128

# DATASET_NAME = "rai"
# DATASET_NAME = "rai-multilingual-2000"
# DATASET_NAME = "webis-cls"


def init_callbacks(patience=-1, nosave=False):
    callbacks = []
    if patience != -1 and not nosave:
        callbacks.append(
            transformers.EarlyStoppingCallback(early_stopping_patience=patience)
        )
    return callbacks


def init_model(model_name, nlabels):
    if model_name == "mbert":
        # hf_name = "bert-base-multilingual-cased"
        hf_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
        # hf_name = "hf_models/mbert-rai-fewshot-second/checkpoint-9000"
    elif model_name == "xlm-roberta":
        hf_name = "xlm-roberta-base"
    else:
        raise NotImplementedError
    tokenizer = AutoTokenizer.from_pretrained(hf_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        hf_name, num_labels=nlabels
    )
    return tokenizer, model


def main(args):
    tokenizer, model = init_model(args.model, args.nlabels)

    data = load_dataset(
        "csv",
        data_files={
            "train": expanduser("~/datasets/cls-acl10-unprocessed/csv/train.balanced.csv"),
            "test": expanduser("~/datasets/cls-acl10-unprocessed/csv/test.balanced.csv"),
            # "train": expanduser(f"~/datasets/rai/csv/train-{DATASET_NAME}.csv"),
            # "test": expanduser(f"~/datasets/rai/csv/test-{DATASET_NAME}.csv"),
            # "train": expanduser("~/datasets/rai/csv/train.small.csv"),
            # "test": expanduser("~/datasets/rai/csv/test.small.csv"),
        },
    )

    def process_sample_rai(sample):
        # Prepend the title to the body text before tokenization.
        inputs = [f"{title}. {text}" for title, text in zip(sample["title"], sample["text"])]
        labels = sample["label"]
        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)
        # TODO: pre-process the text, since there's a lot of noise in it...
        model_inputs["labels"] = labels
        return model_inputs

    def process_sample_webis(sample):
        inputs = sample["text"]
        labels = sample["label"]
        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)
        # TODO: pre-process the text, since there's a lot of noise in it...
model_inputs["labels"] = labels return model_inputs data = data.map( # process_sample_rai, process_sample_webis, batched=True, num_proc=4, load_from_cache_file=True, # remove_columns=RAI_D_COLUMNS, remove_columns=WEBIS_D_COLUMNS, ) train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42) data.set_format("torch") data = DatasetDict( { "train": train_val_splits["train"], "validation": train_val_splits["test"], "test": data["test"], } ) data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) callbacks = init_callbacks(args.patience, args.nosave) f1_metric = evaluate.load("f1") accuracy_metric = evaluate.load("accuracy") precision_metric = evaluate.load("precision") recall_metric = evaluate.load("recall") training_args = TrainingArguments( # output_dir=f"hf_models/{args.model}-rai", output_dir=f"hf_models/{args.model}-sentiment-balanced", do_train=True, evaluation_strategy="steps", per_device_train_batch_size=args.batch, per_device_eval_batch_size=args.batch, gradient_accumulation_steps=args.gradacc, eval_accumulation_steps=10, learning_rate=args.lr, weight_decay=0.1, max_grad_norm=5.0, num_train_epochs=args.epochs, lr_scheduler_type=args.scheduler, warmup_ratio=0.01, logging_strategy="steps", logging_first_step=True, logging_steps=args.steplog, seed=42, fp16=args.fp16, load_best_model_at_end=False if args.nosave else True, save_strategy="no" if args.nosave else "steps", save_total_limit=2, eval_steps=args.stepeval, # run_name=f"{args.model}-rai-stratified", run_name=f"{args.model}-sentiment", disable_tqdm=False, log_level="warning", report_to=["wandb"] if args.wandb else "none", optim="adamw_torch", save_steps=args.stepeval ) def compute_metrics(eval_preds): preds = eval_preds.predictions.argmax(-1) # targets = eval_preds.label_ids.argmax(-1) targets = eval_preds.label_ids setting = "macro" f1_score_macro = f1_metric.compute( predictions=preds, references=targets, average="macro" ) f1_score_micro = f1_metric.compute( predictions=preds, references=targets, average="micro" ) accuracy_score = accuracy_metric.compute(predictions=preds, references=targets) precision_score = precision_metric.compute( predictions=preds, references=targets, average=setting, zero_division=1 ) recall_score = recall_metric.compute( predictions=preds, references=targets, average=setting, zero_division=1 ) results = { "macro_f1score": f1_score_macro["f1"], "micro_f1score": f1_score_micro["f1"], "accuracy": accuracy_score["accuracy"], "precision": precision_score["precision"], "recall": recall_score["recall"], } results = {k: round(v, 4) for k, v in results.items()} return results if args.wandb: import wandb wandb.init(entity="andreapdr", project=f"gfun-rai-hf", name="mbert-rai", config=vars(args)) trainer = Trainer( model=model, args=training_args, train_dataset=data["train"], eval_dataset=data["validation"], compute_metrics=compute_metrics, tokenizer=tokenizer, data_collator=data_collator, callbacks=callbacks, ) if not args.onlytest: print("- Training:") trainer.train() print("- Testing:") test_results = trainer.predict(test_dataset=data["test"], metric_key_prefix="test") pprint(test_results.metrics) save_preds(data["test"], test_results.predictions) exit() def save_preds(dataset, predictions): df = pd.DataFrame() df["langs"] = dataset["lang"] df["labels"] = dataset["labels"] df["preds"] = predictions.argmax(axis=1) df.to_csv("results/lang-specific.mbert.webis.csv", index=False) return if __name__ == "__main__": from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser parser = 
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, metavar="", default="mbert")
    parser.add_argument("--nlabels", type=int, metavar="", default=28)
    parser.add_argument("--lr", type=float, metavar="", default=5e-5, help="Set the learning rate")
    parser.add_argument(
        "--scheduler",
        type=str,
        metavar="",
        default="cosine",
        help='Accepted: ["cosine", "cosine-reset", "cosine-warmup", "cosine-warmup-reset", "constant"]',
    )
    parser.add_argument("--batch", type=int, metavar="", default=8, help="Set the batch size")
    parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
    parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set the number of training epochs")
    parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
    parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps")
    parser.add_argument("--patience", type=int, metavar="", default=10, help="Early-stopping patience")
    parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
    parser.add_argument("--wandb", action="store_true", help="Log to wandb")
    parser.add_argument("--nosave", action="store_true", help="Avoid saving the model")
    parser.add_argument("--onlytest", action="store_true", help="Skip training and only evaluate on the test set")
    # parser.add_argument("--sanity", action="store_true", help="Train and evaluate on the same reduced (1000) data")
    args = parser.parse_args()
    main(args)
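
# Example invocation, as a sketch only: the script filename and the flag values
# below are hypothetical (the flags themselves come from the parser above).
#
#   python finetune_hf.py --model mbert --nlabels 2 --batch 16 --epochs 10 \
#       --patience 5 --stepeval 100 --steplog 100 --fp16
#
# Note that --patience only takes effect when checkpoints are saved (i.e. without
# --nosave), since init_callbacks() skips the EarlyStoppingCallback otherwise.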