updates
This commit is contained in:
parent
86fbd90bd4
commit
317fb93da6
|
|
@ -1,3 +1,5 @@
|
||||||
|
from os.path import expanduser
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
|
|
@ -15,6 +17,9 @@ import evaluate
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
IWSLT_D_COLUMNS = ["text", "category", "rating", "summary", "title"]
|
||||||
|
RAI_D_COLUMNS = ["id", "lang", "provider", "date", "title", "text", "str_label", "label"]
|
||||||
|
|
||||||
|
|
||||||
def init_callbacks(patience=-1, nosave=False):
|
def init_callbacks(patience=-1, nosave=False):
|
||||||
callbacks = []
|
callbacks = []
|
||||||
|
|
@ -23,7 +28,7 @@ def init_callbacks(patience=-1, nosave=False):
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
|
||||||
def init_model(model_name):
|
def init_model(model_name, nlabels):
|
||||||
if model_name == "mbert":
|
if model_name == "mbert":
|
||||||
hf_name = "bert-base-multilingual-cased"
|
hf_name = "bert-base-multilingual-cased"
|
||||||
elif model_name == "xlm-roberta":
|
elif model_name == "xlm-roberta":
|
||||||
|
|
@ -31,21 +36,35 @@ def init_model(model_name):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
tokenizer = AutoTokenizer.from_pretrained(hf_name)
|
tokenizer = AutoTokenizer.from_pretrained(hf_name)
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(hf_name, num_labels=3)
|
model = AutoModelForSequenceClassification.from_pretrained(hf_name, num_labels=nlabels)
|
||||||
return tokenizer, model
|
return tokenizer, model
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
tokenizer, model = init_model(args.model)
|
tokenizer, model = init_model(args.model, args.nlabels)
|
||||||
|
|
||||||
|
# data = load_dataset(
|
||||||
|
# "json",
|
||||||
|
# data_files={
|
||||||
|
# "train": "local_datasets/webis-cls/all-domains/train.json",
|
||||||
|
# "test": "local_datasets/webis-cls/all-domains/test.json",
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
|
||||||
data = load_dataset(
|
data = load_dataset(
|
||||||
"json",
|
"csv",
|
||||||
data_files={
|
data_files = {
|
||||||
"train": "local_datasets/webis-cls/all-domains/train.json",
|
# "train": expanduser("~/datasets/rai/csv/rai-no-it-train.csv"),
|
||||||
"test": "local_datasets/webis-cls/all-domains/test.json",
|
# "test": expanduser("~/datasets/rai/csv/rai-no-it-test.csv")
|
||||||
},
|
# "train": expanduser("~/datasets/rai/csv/rai-train.csv"),
|
||||||
|
# "test": expanduser("~/datasets/rai/csv/rai-test-ita-labeled.csv")
|
||||||
|
"train": expanduser("~/datasets/rai/csv/train-split-rai.csv"),
|
||||||
|
"test": expanduser("~/datasets/rai/csv/test-split-rai-labeled.csv")
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_sample(sample):
|
|
||||||
|
def process_sample_iwslt(sample):
|
||||||
inputs = sample["text"]
|
inputs = sample["text"]
|
||||||
ratings = [r - 1 for r in sample["rating"]]
|
ratings = [r - 1 for r in sample["rating"]]
|
||||||
targets = torch.zeros((len(inputs), 3), dtype=float)
|
targets = torch.zeros((len(inputs), 3), dtype=float)
|
||||||
|
|
@ -56,17 +75,26 @@ def main(args):
|
||||||
for i, r in enumerate(ratings):
|
for i, r in enumerate(ratings):
|
||||||
targets[i][r - 1] = 1
|
targets[i][r - 1] = 1
|
||||||
|
|
||||||
model_inputs = tokenizer(inputs, max_length=512, truncation=True)
|
model_inputs = tokenizer(inputs, max_length=128, truncation=True)
|
||||||
model_inputs["labels"] = targets
|
model_inputs["labels"] = targets
|
||||||
model_inputs["lang_ids"] = torch.tensor(lang_ids)
|
model_inputs["lang_ids"] = torch.tensor(lang_ids)
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
|
def process_sample_rai(sample):
|
||||||
|
inputs = [f"{title}. {text}" for title, text in zip(sample["title"], sample["text"])]
|
||||||
|
labels = sample["label"]
|
||||||
|
model_inputs = tokenizer(inputs, max_length=512, truncation=True) # TODO pre-process text cause there's a lot of noise in there...
|
||||||
|
model_inputs["labels"] = labels
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
data = data.map(
|
data = data.map(
|
||||||
process_sample,
|
process_sample_rai,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=4,
|
num_proc=4,
|
||||||
load_from_cache_file=True,
|
load_from_cache_file=True,
|
||||||
remove_columns=["text", "category", "rating", "summary", "title"],
|
remove_columns=RAI_D_COLUMNS,
|
||||||
)
|
)
|
||||||
train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
|
train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
|
||||||
data.set_format("torch")
|
data.set_format("torch")
|
||||||
|
|
@ -87,7 +115,7 @@ def main(args):
|
||||||
recall_metric = evaluate.load("recall")
|
recall_metric = evaluate.load("recall")
|
||||||
|
|
||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
output_dir=f"{args.model}-sentiment",
|
output_dir=f"{args.model}-rai-final",
|
||||||
do_train=True,
|
do_train=True,
|
||||||
evaluation_strategy="steps",
|
evaluation_strategy="steps",
|
||||||
per_device_train_batch_size=args.batch,
|
per_device_train_batch_size=args.batch,
|
||||||
|
|
@ -99,7 +127,8 @@ def main(args):
|
||||||
max_grad_norm=5.0,
|
max_grad_norm=5.0,
|
||||||
num_train_epochs=args.epochs,
|
num_train_epochs=args.epochs,
|
||||||
lr_scheduler_type=args.scheduler,
|
lr_scheduler_type=args.scheduler,
|
||||||
warmup_steps=1000,
|
# warmup_ratio=0.1,
|
||||||
|
warmup_ratio=1500,
|
||||||
logging_strategy="steps",
|
logging_strategy="steps",
|
||||||
logging_first_step=True,
|
logging_first_step=True,
|
||||||
logging_steps=args.steplog,
|
logging_steps=args.steplog,
|
||||||
|
|
@ -109,7 +138,7 @@ def main(args):
|
||||||
save_strategy="no" if args.nosave else "steps",
|
save_strategy="no" if args.nosave else "steps",
|
||||||
save_total_limit=3,
|
save_total_limit=3,
|
||||||
eval_steps=args.stepeval,
|
eval_steps=args.stepeval,
|
||||||
run_name=f"{args.model}-sentiment-run",
|
run_name=f"{args.model}-rai-stratified",
|
||||||
disable_tqdm=False,
|
disable_tqdm=False,
|
||||||
log_level="warning",
|
log_level="warning",
|
||||||
report_to=["wandb"] if args.wandb else "none",
|
report_to=["wandb"] if args.wandb else "none",
|
||||||
|
|
@ -119,7 +148,8 @@ def main(args):
|
||||||
|
|
||||||
def compute_metrics(eval_preds):
|
def compute_metrics(eval_preds):
|
||||||
preds = eval_preds.predictions.argmax(-1)
|
preds = eval_preds.predictions.argmax(-1)
|
||||||
targets = eval_preds.label_ids.argmax(-1)
|
# targets = eval_preds.label_ids.argmax(-1)
|
||||||
|
targets = eval_preds.label_ids
|
||||||
setting = "macro"
|
setting = "macro"
|
||||||
f1_score_macro = f1_metric.compute(
|
f1_score_macro = f1_metric.compute(
|
||||||
predictions=preds, references=targets, average="macro"
|
predictions=preds, references=targets, average="macro"
|
||||||
|
|
@ -146,7 +176,7 @@ def main(args):
|
||||||
|
|
||||||
if args.wandb:
|
if args.wandb:
|
||||||
import wandb
|
import wandb
|
||||||
wandb.init(entity="andreapdr", project=f"gfun-senti-hf", name="mbert-sent", config=vars(args))
|
wandb.init(entity="andreapdr", project=f"gfun-rai-hf", name="mbert-sent", config=vars(args))
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -162,7 +192,6 @@ def main(args):
|
||||||
print("- Training:")
|
print("- Training:")
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
print("- Testing:")
|
print("- Testing:")
|
||||||
test_results = trainer.evaluate(eval_dataset=data["test"])
|
test_results = trainer.evaluate(eval_dataset=data["test"])
|
||||||
print(test_results)
|
print(test_results)
|
||||||
|
|
@ -174,13 +203,14 @@ if __name__ == "__main__":
|
||||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||||
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument("--model", type=str, metavar="", default="mbert")
|
parser.add_argument("--model", type=str, metavar="", default="mbert")
|
||||||
|
parser.add_argument("--nlabels", type=int, metavar="", default=3)
|
||||||
parser.add_argument("--lr", type=float, metavar="", default=1e-5, help="Set learning rate",)
|
parser.add_argument("--lr", type=float, metavar="", default=1e-5, help="Set learning rate",)
|
||||||
parser.add_argument("--scheduler", type=str, metavar="", default="linear", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
|
parser.add_argument("--scheduler", type=str, metavar="", default="linear", help="Accepted: [\"cosine\", \"cosine-reset\", \"cosine-warmup\", \"cosine-warmup-reset\", \"constant\"]")
|
||||||
parser.add_argument("--batch", type=int, metavar="", default=16, help="Set batch size")
|
parser.add_argument("--batch", type=int, metavar="", default=16, help="Set batch size")
|
||||||
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
|
parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
|
||||||
parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
|
parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
|
||||||
parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
|
parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
|
||||||
parser.add_argument("--steplog", type=int, metavar="", default=100, help="Log training every n steps")
|
parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps")
|
||||||
parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
|
parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
|
||||||
parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
|
parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
|
||||||
parser.add_argument("--wandb", action="store_true", help="Log to wandb")
|
parser.add_argument("--wandb", action="store_true", help="Log to wandb")
|
||||||
|
|
|
||||||
8
main.py
8
main.py
|
|
@ -1,5 +1,3 @@
|
||||||
import wandb
|
|
||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
|
|
@ -11,7 +9,6 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||||
TODO:
|
TODO:
|
||||||
- General:
|
- General:
|
||||||
[!] zero-shot setup
|
[!] zero-shot setup
|
||||||
- CLS dataset is loading only "books" domain data
|
|
||||||
- Docs:
|
- Docs:
|
||||||
- add documentations sphinx
|
- add documentations sphinx
|
||||||
"""
|
"""
|
||||||
|
|
@ -96,6 +93,8 @@ def main(args):
|
||||||
|
|
||||||
config = gfun.get_config()
|
config = gfun.get_config()
|
||||||
|
|
||||||
|
if args.wandb:
|
||||||
|
import wandb
|
||||||
wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
|
wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
|
||||||
|
|
||||||
gfun.fit(lX, lY)
|
gfun.fit(lX, lY)
|
||||||
|
|
@ -139,6 +138,7 @@ def main(args):
|
||||||
)
|
)
|
||||||
wandb.log(gfun_res)
|
wandb.log(gfun_res)
|
||||||
|
|
||||||
|
if args.wandb:
|
||||||
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
|
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -178,6 +178,8 @@ if __name__ == "__main__":
|
||||||
# Visual Transformer parameters --------------
|
# Visual Transformer parameters --------------
|
||||||
parser.add_argument("--visual_trf_name", type=str, default="vit")
|
parser.add_argument("--visual_trf_name", type=str, default="vit")
|
||||||
parser.add_argument("--visual_lr", type=float, default=1e-4)
|
parser.add_argument("--visual_lr", type=float, default=1e-4)
|
||||||
|
# logging
|
||||||
|
parser.add_argument("--wandb", action="store_true")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue