"""gfun_multimodal/infer.py — run inference with a trained gFun model and store predictions."""
import os
import pandas as pd
from pathlib import Path
from datetime import datetime
from dataManager.gFunDataset import SimpleGfunDataset
from gfun.generalizedFunnelling import GeneralizedFunnelling
def main(args):
    """Classify the documents in ``args.datapath`` with a trained gFun instance.

    Loads an inference-only dataset from the input csv, restores the trained
    generalized-funnelling model named by ``args.trained_gfun``, and writes the
    per-document predictions to a csv under ``args.outdir``.
    """
    # Inference-only dataset: only the test split is populated.
    dataset = SimpleGfunDataset(
        dataset_name=Path(args.datapath).stem,
        datapath=args.datapath,
        multilabel=False,
        only_inference=True,
        labels=args.nlabels,
    )
    docs, _ = dataset.test()

    # Restore the trained gFun model with its three first-tier views
    # (posterior probabilities, MUSE embeddings, textual transformer).
    gfun = GeneralizedFunnelling(
        dataset_name=dataset.get_name(),
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        classification_type="singlelabel",
        embed_dir=args.muse_dir,
        posterior=True,
        multilingual=True,
        textual_transformer=True,
        load_trained=args.trained_gfun,
        load_meta=True,
    )

    predictions, doc_ids = gfun.transform(docs, output_ids=True)

    save_inference_preds(
        preds=predictions,
        doc_ids=doc_ids,
        dataset_name=dataset.get_name(),
        targets=None,
        category_mapper=args.category_map,
        outdir=args.outdir,
    )
def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, outdir="results/inference-preds"):
"""
Parameters
----------
preds : Dict[str: np.array]
Predictions produced by generalized-funnelling.
dataset_name: str
Dataset name used as output file name. File is stored in directory defined by `output_dir`
argument e.g. "<output_dir>/<dataset_name>.csv"
doc_ids: Dict[str: List[str]]
Dictionary storing list of document ids (as defined in the csv input file)
targets: Dict[str: np.array]
If availabel, target true labels will be written to output file to ease performance evaluation.
(default = None)
category_mapper: Path
Path to the 'category_mapper' csv file storing the category names (str) for each target class (integer).
If not None, gFun predictions will be converetd and stored as target string classes.
(default=None)
output_dir: Path
Dir where to store output csv file.
(default = results/inference-preds)
"""
os.makedirs(outdir, exist_ok=True)
df = pd.DataFrame()
langs = sorted(preds.keys())
_ids = []
_preds = []
_targets = []
_langs = []
for lang in langs:
_preds.extend(preds[lang].argmax(axis=1).tolist())
_langs.extend([lang for i in range(len(preds[lang]))])
_ids.extend(doc_ids[lang])
df["doc_id"] = _ids
df["document_language"] = _langs
df["gfun_prediction"] = _preds
if targets is not None:
for lang in langs:
_targets.extend(targets[lang].argmax(axis=1).tolist())
df["document_true_label"] = _targets
if category_mapper is not None:
mapper = pd.read_csv(category_mapper).to_dict()["category"]
df["gfun_string_prediction"] = [mapper[p] for p in _preds]
timestamp = datetime.now()
formatted_timestamp = timestamp.strftime("%y%m%d_%H%M%S")
output_file = f"{outdir}/{dataset_name}_{formatted_timestamp}.csv"
print(f"Storing predicitons in: {output_file}")
df.to_csv(output_file, index=False)
return
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line interface for the inference script.
    parser = ArgumentParser()
    parser.add_argument("--datapath", required=True, type=str, help="path to csv file containing the documents to be classified")
    parser.add_argument("--outdir", type=str, default="results/inference-preds", help="path to store csv file containing gfun predictions")
    parser.add_argument("--category_map", type=str, default="models/category_mappers/rai-mapping.csv", help="path to csv file containing the mapping from label name to label id [str: id]")
    parser.add_argument("--nlabels", type=int, default=28)
    parser.add_argument("--muse_dir", type=str, default="embeddings", help="path to muse embeddings")
    parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")
    main(parser.parse_args())