renamed arguments
This commit is contained in:
parent
59bf921bf3
commit
31fc225d0e
18
infer.py
18
infer.py
|
|
@ -36,10 +36,11 @@ def main(args):
|
||||||
doc_ids=ids,
|
doc_ids=ids,
|
||||||
dataset_name=dataset.get_name(),
|
dataset_name=dataset.get_name(),
|
||||||
targets=None,
|
targets=None,
|
||||||
category_mapper="models/category_mappers/rai-mapping.csv"
|
category_mapper=args.category_map,
|
||||||
|
outdir=args.outdir,
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, output_dir="results/inference-preds"):
|
def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, outdir="results/inference-preds"):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|
@ -61,7 +62,7 @@ def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_ma
|
||||||
Dir where to store output csv file.
|
Dir where to store output csv file.
|
||||||
(default = results/inference-preds)
|
(default = results/inference-preds)
|
||||||
"""
|
"""
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(outdir, exist_ok=True)
|
||||||
df = pd.DataFrame()
|
df = pd.DataFrame()
|
||||||
langs = sorted(preds.keys())
|
langs = sorted(preds.keys())
|
||||||
_ids = []
|
_ids = []
|
||||||
|
|
@ -86,18 +87,23 @@ def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_ma
|
||||||
mapper = pd.read_csv(category_mapper).to_dict()["category"]
|
mapper = pd.read_csv(category_mapper).to_dict()["category"]
|
||||||
df["gfun_string_prediction"] = [mapper[p] for p in _preds]
|
df["gfun_string_prediction"] = [mapper[p] for p in _preds]
|
||||||
|
|
||||||
output_file = f"{output_dir}/{dataset_name}.csv"
|
output_file = f"{outdir}/{dataset_name}.csv"
|
||||||
print(f"Storing predicitons in: {output_file}")
|
print(f"Storing predicitons in: {output_file}")
|
||||||
df.to_csv(output_file, index=False)
|
df.to_csv(output_file, index=False)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument("--datapath", type=str, default="~/datasets/rai/csv", help="Path to csv file containing the documents to be classified")
|
parser.add_argument("--datapath", required=True, type=str, help="path to csv file containing the documents to be classified")
|
||||||
|
parser.add_argument("--outdir", type=str, default="results/inference-preds", help="path to store csv file containing gfun predictions")
|
||||||
|
parser.add_argument("--category_map", type=str, default="models/category_mappers/rai-mapping.csv", help="path to csv file containing the mapping from label name to label id [str: id]")
|
||||||
parser.add_argument("--nlabels", type=int, default=28)
|
parser.add_argument("--nlabels", type=int, default=28)
|
||||||
parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to muse embeddings")
|
parser.add_argument("--muse_dir", type=str, default="embeddings", help="path to muse embeddings")
|
||||||
parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")
|
parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|
||||||
|
|
|
||||||
25
run-rai.sh
25
run-rai.sh
|
|
@ -1,25 +0,0 @@
|
||||||
#!bin/bash
|
|
||||||
|
|
||||||
njobs=-1
|
|
||||||
clf=singlelabel
|
|
||||||
patience=5
|
|
||||||
eval_every=5
|
|
||||||
text_len=512
|
|
||||||
text_lr=1e-4
|
|
||||||
bsize=2
|
|
||||||
txt_model=mbert
|
|
||||||
dataset=rai
|
|
||||||
|
|
||||||
config="-pmt"
|
|
||||||
echo "[Running gFun config: $config]"
|
|
||||||
python main.py $config \
|
|
||||||
-d $dataset\
|
|
||||||
--nosave \
|
|
||||||
--n_jobs $njobs \
|
|
||||||
--clf_type $clf \
|
|
||||||
--patience $patience \
|
|
||||||
--evaluate_step $eval_every \
|
|
||||||
--batch_size $bsize \
|
|
||||||
--max_length $text_len \
|
|
||||||
--textual_lr $text_lr \
|
|
||||||
--textual_trf_name $txt_model \
|
|
||||||
Loading…
Reference in New Issue