renamed arguments

This commit is contained in:
Andrea Pedrotti 2024-03-12 13:26:37 +01:00
parent 59bf921bf3
commit 31fc225d0e
2 changed files with 12 additions and 31 deletions

View File

@ -36,10 +36,11 @@ def main(args):
doc_ids=ids, doc_ids=ids,
dataset_name=dataset.get_name(), dataset_name=dataset.get_name(),
targets=None, targets=None,
category_mapper="models/category_mappers/rai-mapping.csv" category_mapper=args.category_map,
outdir=args.outdir,
) )
def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, output_dir="results/inference-preds"): def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, outdir="results/inference-preds"):
""" """
Parameters Parameters
---------- ----------
@ -61,7 +62,7 @@ def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_ma
Dir where to store output csv file. Dir where to store output csv file.
(default = results/inference-preds) (default = results/inference-preds)
""" """
os.makedirs(output_dir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
df = pd.DataFrame() df = pd.DataFrame()
langs = sorted(preds.keys()) langs = sorted(preds.keys())
_ids = [] _ids = []
@ -86,18 +87,23 @@ def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_ma
mapper = pd.read_csv(category_mapper).to_dict()["category"] mapper = pd.read_csv(category_mapper).to_dict()["category"]
df["gfun_string_prediction"] = [mapper[p] for p in _preds] df["gfun_string_prediction"] = [mapper[p] for p in _preds]
output_file = f"{output_dir}/{dataset_name}.csv" output_file = f"{outdir}/{dataset_name}.csv"
print(f"Storing predicitons in: {output_file}") print(f"Storing predicitons in: {output_file}")
df.to_csv(output_file, index=False) df.to_csv(output_file, index=False)
return
if __name__ == "__main__": if __name__ == "__main__":
from argparse import ArgumentParser from argparse import ArgumentParser
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--datapath", type=str, default="~/datasets/rai/csv", help="Path to csv file containing the documents to be classified") parser.add_argument("--datapath", required=True, type=str, help="path to csv file containing the documents to be classified")
parser.add_argument("--outdir", type=str, default="results/inference-preds", help="path to store csv file containing gfun predictions")
parser.add_argument("--category_map", type=str, default="models/category_mappers/rai-mapping.csv", help="path to csv file containing the mapping from label name to label id [str: id]")
parser.add_argument("--nlabels", type=int, default=28) parser.add_argument("--nlabels", type=int, default=28)
parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to muse embeddings") parser.add_argument("--muse_dir", type=str, default="embeddings", help="path to muse embeddings")
parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance") parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -1,25 +0,0 @@
#!bin/bash
njobs=-1
clf=singlelabel
patience=5
eval_every=5
text_len=512
text_lr=1e-4
bsize=2
txt_model=mbert
dataset=rai
config="-pmt"
echo "[Running gFun config: $config]"
python main.py $config \
-d $dataset\
--nosave \
--n_jobs $njobs \
--clf_type $clf \
--patience $patience \
--evaluate_step $eval_every \
--batch_size $bsize \
--max_length $text_len \
--textual_lr $text_lr \
--textual_trf_name $txt_model \