114 lines
4.1 KiB
Python
114 lines
4.1 KiB
Python
import os
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
|
|
from dataManager.gFunDataset import SimpleGfunDataset
|
|
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
|
|
|
|
|
def main(args):
    """
    Classify the documents of a csv file with a pre-trained gFun model and
    store the resulting predictions.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options; this function reads ``datapath``,
        ``nlabels``, ``muse_dir``, ``trained_gfun``, ``category_map`` and
        ``outdir``.
    """
    source_csv = args.datapath

    # Documents are loaded in inference-only, single-label mode.
    inference_set = SimpleGfunDataset(
        dataset_name=Path(source_csv).stem,
        datapath=source_csv,
        multilabel=False,
        only_inference=True,
        labels=args.nlabels,
    )
    lang_docs, _ = inference_set.test()

    # NOTE(review): load_trained / load_meta presumably restore the saved
    # model named by `trained_gfun` -- confirm in GeneralizedFunnelling.
    gfun_config = dict(
        dataset_name=inference_set.get_name(),
        langs=inference_set.langs(),
        num_labels=inference_set.num_labels(),
        classification_type="singlelabel",
        embed_dir=args.muse_dir,
        posterior=True,
        multilingual=True,
        textual_transformer=True,
        load_trained=args.trained_gfun,
        load_meta=True,
    )
    model = GeneralizedFunnelling(**gfun_config)

    lang_scores, lang_doc_ids = model.transform(lang_docs, output_ids=True)

    # No gold labels are available at inference time, hence targets=None.
    save_inference_preds(
        preds=lang_scores,
        doc_ids=lang_doc_ids,
        dataset_name=inference_set.get_name(),
        targets=None,
        category_mapper=args.category_map,
        outdir=args.outdir,
    )
|
|
|
|
def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, outdir="results/inference-preds"):
    """
    Write gFun predictions (and optionally the true labels) to a timestamped csv file.

    Languages are processed in sorted order; for each one the predicted class is
    the argmax over axis 1 of its prediction array, and the rows of all
    languages are concatenated into a single table.

    Parameters
    ----------
    preds : Dict[str, np.ndarray]
        Predictions produced by generalized-funnelling, keyed by language.
        Presumably one row of per-class scores per document (the argmax over
        axis 1 is taken) -- TODO confirm against GeneralizedFunnelling.transform.
    dataset_name : str
        Dataset name used as prefix of the output file name. The file is stored
        as "<outdir>/<dataset_name>_<yymmdd_HHMMSS>.csv".
    doc_ids : Dict[str, List[str]]
        Document ids (as defined in the csv input file) keyed by language.
        ``doc_ids[lang]`` must be aligned with the rows of ``preds[lang]``.
    targets : Dict[str, np.ndarray], optional
        If available, true labels are written to the column
        "document_true_label" to ease performance evaluation. The argmax over
        axis 1 is taken, so a 2-D (e.g. one-hot) array per language is expected.
        (default = None)
    category_mapper : str or Path, optional
        Path to a csv file with a "category" column whose i-th row holds the
        category name (str) of target class i (int). If not None, gFun
        predictions are also stored as category names in the column
        "gfun_string_prediction". (default = None)
    outdir : str or Path
        Directory where the output csv file is stored; created if missing.
        (default = "results/inference-preds")

    Returns
    -------
    str
        Path of the written csv file.

    Raises
    ------
    ValueError
        If ``doc_ids`` and ``preds`` do not describe the same number of documents.
    KeyError
        If a predicted class id is missing from the category mapper, or the
        mapper csv has no "category" column.
    """
    os.makedirs(outdir, exist_ok=True)

    langs = sorted(preds.keys())

    # Flatten the per-language outputs into aligned, concatenated columns.
    _ids, _langs, _preds = [], [], []
    for lang in langs:
        lang_preds = preds[lang]
        _preds.extend(lang_preds.argmax(axis=1).tolist())
        _langs.extend([lang] * len(lang_preds))
        _ids.extend(doc_ids[lang])

    # Building the frame in one go raises ValueError on misaligned columns.
    df = pd.DataFrame(
        {
            "doc_id": _ids,
            "document_language": _langs,
            "gfun_prediction": _preds,
        }
    )

    if targets is not None:
        _targets = []
        for lang in langs:
            _targets.extend(targets[lang].argmax(axis=1).tolist())
        df["document_true_label"] = _targets

    if category_mapper is not None:
        # read_csv assigns a default 0..n-1 RangeIndex, so this dict maps the
        # integer class id (row position) to its category name.
        mapper = pd.read_csv(category_mapper)["category"].to_dict()
        df["gfun_string_prediction"] = [mapper[p] for p in _preds]

    formatted_timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    output_file = os.path.join(outdir, f"{dataset_name}_{formatted_timestamp}.csv")
    print(f"Storing predictions in: {output_file}")
    df.to_csv(output_file, index=False)

    return output_file
|
|
|
|
|
|
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line entry point: parse the inference options and run `main`.
    parser = ArgumentParser()
    parser.add_argument("--datapath", required=True, type=str, help="path to csv file containing the documents to be classified")
    parser.add_argument("--outdir", type=str, default="results/inference-preds", help="path to store csv file containing gfun predictions")
    # The mapper csv is read by position: row i of its "category" column is the
    # name of class id i (see save_inference_preds), i.e. an [id: str] mapping.
    parser.add_argument(
        "--category_map",
        type=str,
        default="models/category_mappers/rai-mapping.csv",
        help="path to csv file with a 'category' column whose i-th row is the label name of label id i [id: str]",
    )
    parser.add_argument("--nlabels", type=int, default=28, help="number of labels")
    parser.add_argument("--muse_dir", type=str, default="embeddings", help="path to muse embeddings")
    parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")

    args = parser.parse_args()
    main(args)
|
|
|