From 31fc225d0eee81691cbc266edeabfb65cece5a2b Mon Sep 17 00:00:00 2001 From: andreapdr Date: Tue, 12 Mar 2024 13:26:37 +0100 Subject: [PATCH] renamed arguments --- infer.py | 18 ++++++++++++------ run-rai.sh | 25 ------------------------- 2 files changed, 12 insertions(+), 31 deletions(-) delete mode 100644 run-rai.sh diff --git a/infer.py b/infer.py index 65dd0c3..58075cc 100644 --- a/infer.py +++ b/infer.py @@ -36,10 +36,11 @@ def main(args): doc_ids=ids, dataset_name=dataset.get_name(), targets=None, - category_mapper="models/category_mappers/rai-mapping.csv" + category_mapper=args.category_map, + outdir=args.outdir, ) -def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, output_dir="results/inference-preds"): +def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, outdir="results/inference-preds"): """ Parameters ---------- @@ -61,7 +62,7 @@ def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_ma Dir where to store output csv file. (default = results/inference-preds) """ - os.makedirs(output_dir, exist_ok=True) + os.makedirs(outdir, exist_ok=True) df = pd.DataFrame() langs = sorted(preds.keys()) _ids = [] @@ -86,18 +87,23 @@ def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_ma mapper = pd.read_csv(category_mapper).to_dict()["category"] df["gfun_string_prediction"] = [mapper[p] for p in _preds] - output_file = f"{output_dir}/{dataset_name}.csv" + output_file = f"{outdir}/{dataset_name}.csv" print(f"Storing predicitons in: {output_file}") df.to_csv(output_file, index=False) + + return if __name__ == "__main__": from argparse import ArgumentParser parser = ArgumentParser() - parser.add_argument("--datapath", type=str, default="~/datasets/rai/csv", help="Path to csv file containing the documents to be classified") + parser.add_argument("--datapath", required=True, type=str, help="path to csv file containing the documents to be classified") + parser.add_argument("--outdir", type=str, default="results/inference-preds", help="path to store csv file containing gfun predictions") + parser.add_argument("--category_map", type=str, default="models/category_mappers/rai-mapping.csv", help="path to csv file containing the mapping from label name to label id [str: id]") parser.add_argument("--nlabels", type=int, default=28) - parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to muse embeddings") + parser.add_argument("--muse_dir", type=str, default="embeddings", help="path to muse embeddings") parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance") + args = parser.parse_args() main(args) diff --git a/run-rai.sh b/run-rai.sh deleted file mode 100644 index 1d69e68..0000000 --- a/run-rai.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!bin/bash - -njobs=-1 -clf=singlelabel -patience=5 -eval_every=5 -text_len=512 -text_lr=1e-4 -bsize=2 -txt_model=mbert -dataset=rai - -config="-pmt" -echo "[Running gFun config: $config]" -python main.py $config \ - -d $dataset\ - --nosave \ - --n_jobs $njobs \ - --clf_type $clf \ - --patience $patience \ - --evaluate_step $eval_every \ - --batch_size $bsize \ - --max_length $text_len \ - --textual_lr $text_lr \ - --textual_trf_name $txt_model \ \ No newline at end of file