renamed arguments
This commit is contained in:
parent
59bf921bf3
commit
31fc225d0e
18
infer.py
18
infer.py
|
|
@ -36,10 +36,11 @@ def main(args):
|
|||
doc_ids=ids,
|
||||
dataset_name=dataset.get_name(),
|
||||
targets=None,
|
||||
category_mapper="models/category_mappers/rai-mapping.csv"
|
||||
category_mapper=args.category_map,
|
||||
outdir=args.outdir,
|
||||
)
|
||||
|
||||
def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, output_dir="results/inference-preds"):
|
||||
def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, outdir="results/inference-preds"):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -61,7 +62,7 @@ def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_ma
|
|||
Dir where to store output csv file.
|
||||
(default = results/inference-preds)
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
df = pd.DataFrame()
|
||||
langs = sorted(preds.keys())
|
||||
_ids = []
|
||||
|
|
@ -86,18 +87,23 @@ def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_ma
|
|||
mapper = pd.read_csv(category_mapper).to_dict()["category"]
|
||||
df["gfun_string_prediction"] = [mapper[p] for p in _preds]
|
||||
|
||||
output_file = f"{output_dir}/{dataset_name}.csv"
|
||||
output_file = f"{outdir}/{dataset_name}.csv"
|
||||
print(f"Storing predicitons in: {output_file}")
|
||||
df.to_csv(output_file, index=False)
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from argparse import ArgumentParser
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--datapath", type=str, default="~/datasets/rai/csv", help="Path to csv file containing the documents to be classified")
|
||||
parser.add_argument("--datapath", required=True, type=str, help="path to csv file containing the documents to be classified")
|
||||
parser.add_argument("--outdir", type=str, default="results/inference-preds", help="path to store csv file containing gfun predictions")
|
||||
parser.add_argument("--category_map", type=str, default="models/category_mappers/rai-mapping.csv", help="path to csv file containing the mapping from label name to label id [str: id]")
|
||||
parser.add_argument("--nlabels", type=int, default=28)
|
||||
parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to muse embeddings")
|
||||
parser.add_argument("--muse_dir", type=str, default="embeddings", help="path to muse embeddings")
|
||||
parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
|
|
|
|||
25
run-rai.sh
25
run-rai.sh
|
|
@ -1,25 +0,0 @@
|
|||
#!bin/bash
|
||||
|
||||
njobs=-1
|
||||
clf=singlelabel
|
||||
patience=5
|
||||
eval_every=5
|
||||
text_len=512
|
||||
text_lr=1e-4
|
||||
bsize=2
|
||||
txt_model=mbert
|
||||
dataset=rai
|
||||
|
||||
config="-pmt"
|
||||
echo "[Running gFun config: $config]"
|
||||
python main.py $config \
|
||||
-d $dataset\
|
||||
--nosave \
|
||||
--n_jobs $njobs \
|
||||
--clf_type $clf \
|
||||
--patience $patience \
|
||||
--evaluate_step $eval_every \
|
||||
--batch_size $bsize \
|
||||
--max_length $text_len \
|
||||
--textual_lr $text_lr \
|
||||
--textual_trf_name $txt_model \
|
||||
Loading…
Reference in New Issue