set test key_prefix in test phase for wandb
This commit is contained in:
parent
8354d76513
commit
b6b1d33fdb
|
@ -47,8 +47,8 @@ def main(args):
|
||||||
data = load_dataset(
|
data = load_dataset(
|
||||||
"csv",
|
"csv",
|
||||||
data_files = {
|
data_files = {
|
||||||
"train": expanduser("~/datasets/rai/csv/train-rai-multilingual-2000.csv"),
|
"train": expanduser("~/datasets/rai/csv/train-split-rai.csv"),
|
||||||
"test": expanduser("~/datasets/rai/csv/test-rai-multilingual-2000.csv")
|
"test": expanduser("~/datasets/rai/csv/test-split-rai.csv")
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ def main(args):
|
||||||
process_sample_rai,
|
process_sample_rai,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=4,
|
num_proc=4,
|
||||||
load_from_cache_file=False,
|
load_from_cache_file=True,
|
||||||
remove_columns=RAI_D_COLUMNS,
|
remove_columns=RAI_D_COLUMNS,
|
||||||
)
|
)
|
||||||
train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
|
train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
|
||||||
|
@ -103,7 +103,7 @@ def main(args):
|
||||||
recall_metric = evaluate.load("recall")
|
recall_metric = evaluate.load("recall")
|
||||||
|
|
||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
output_dir=f"{args.model}-rai-multi-2000",
|
output_dir=f"hf_models/{args.model}-rai-fewshot",
|
||||||
do_train=True,
|
do_train=True,
|
||||||
evaluation_strategy="steps",
|
evaluation_strategy="steps",
|
||||||
per_device_train_batch_size=args.batch,
|
per_device_train_batch_size=args.batch,
|
||||||
|
@ -123,13 +123,14 @@ def main(args):
|
||||||
fp16=args.fp16,
|
fp16=args.fp16,
|
||||||
load_best_model_at_end=False if args.nosave else True,
|
load_best_model_at_end=False if args.nosave else True,
|
||||||
save_strategy="no" if args.nosave else "steps",
|
save_strategy="no" if args.nosave else "steps",
|
||||||
save_total_limit=3,
|
save_total_limit=2,
|
||||||
eval_steps=args.stepeval,
|
eval_steps=args.stepeval,
|
||||||
run_name=f"{args.model}-rai-stratified",
|
run_name=f"{args.model}-rai-stratified",
|
||||||
disable_tqdm=False,
|
disable_tqdm=False,
|
||||||
log_level="warning",
|
log_level="warning",
|
||||||
report_to=["wandb"] if args.wandb else "none",
|
report_to=["wandb"] if args.wandb else "none",
|
||||||
optim="adamw_torch",
|
optim="adamw_torch",
|
||||||
|
save_steps=args.stepeval
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -163,7 +164,7 @@ def main(args):
|
||||||
|
|
||||||
if args.wandb:
|
if args.wandb:
|
||||||
import wandb
|
import wandb
|
||||||
wandb.init(entity="andreapdr", project=f"gfun-rai-hf", name="mbert-sent", config=vars(args))
|
wandb.init(entity="andreapdr", project=f"gfun-rai-hf", name="mbert-rai", config=vars(args))
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -176,11 +177,11 @@ def main(args):
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
# print("- Training:")
|
print("- Training:")
|
||||||
# trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
print("- Testing:")
|
print("- Testing:")
|
||||||
test_results = trainer.evaluate(eval_dataset=data["test"])
|
test_results = trainer.evaluate(eval_dataset=data["test"], metric_key_prefix="test")
|
||||||
print(test_results)
|
print(test_results)
|
||||||
|
|
||||||
exit()
|
exit()
|
||||||
|
@ -191,8 +192,8 @@ if __name__ == "__main__":
|
||||||
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument("--model", type=str, metavar="", default="mbert")
|
parser.add_argument("--model", type=str, metavar="", default="mbert")
|
||||||
parser.add_argument("--nlabels", type=int, metavar="", default=28)
|
parser.add_argument("--nlabels", type=int, metavar="", default=28)
|
||||||
parser.add_argument("--lr", type=float, metavar="", default=1e-5, help="Set learning rate",)
|
parser.add_argument("--lr", type=float, metavar="", default=1e-4, help="Set learning rate",)
|
||||||
parser.add_argument("--scheduler", type=str, metavar="", default="linear", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
|
parser.add_argument("--scheduler", type=str, metavar="", default="cosine", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
|
||||||
parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
|
parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
|
||||||
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
|
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
|
||||||
parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
|
parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
|
||||||
|
|
Loading…
Reference in New Issue