import os
from pathlib import Path

import pandas as pd

from dataManager.gFunDataset import SimpleGfunDataset
from gfun.generalizedFunnelling import GeneralizedFunnelling


def main(args):
    """Run inference with a pre-trained gFun instance on the documents in
    `args.datapath` and store the predictions to a csv file.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (datapath, nlabels, muse_dir, trained_gfun).
    """
    dataset = SimpleGfunDataset(
        dataset_name=Path(args.datapath).stem,
        datapath=args.datapath,
        multilabel=False,
        only_inference=True,  # inference-only: no training split is loaded
        labels=args.nlabels,
    )
    lX, _ = dataset.test()
    gfun = GeneralizedFunnelling(
        dataset_name=dataset.get_name(),
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        classification_type="singlelabel",
        embed_dir=args.muse_dir,
        posterior=True,
        multilingual=True,
        textual_transformer=True,
        load_trained=args.trained_gfun,  # name of a previously trained gFun instance
        load_meta=True,
    )
    predictions, ids = gfun.transform(lX, output_ids=True)
    save_inference_preds(
        preds=predictions,
        doc_ids=ids,
        dataset_name=dataset.get_name(),
        targets=None,
        category_mapper="models/category_mappers/rai-mapping.csv",
    )


def save_inference_preds(
    preds,
    dataset_name,
    doc_ids,
    targets=None,
    category_mapper=None,
    output_dir="results/inference-preds",
):
    """Store gFun predictions (one row per document) to a csv file.

    Parameters
    ----------
    preds : Dict[str, np.ndarray]
        Predictions produced by generalized-funnelling, keyed by language.
    dataset_name : str
        Dataset name used as output file name. The file is stored in the
        directory defined by the `output_dir` argument,
        e.g. "<output_dir>/<dataset_name>.csv".
    doc_ids : Dict[str, List[str]]
        Dictionary storing the list of document ids (as defined in the csv
        input file) for each language.
    targets : Dict[str, np.ndarray], optional
        If available, target true labels will be written to the output file
        to ease performance evaluation. (default = None)
    category_mapper : Path, optional
        Path to the 'category_mapper' csv file storing the category names
        (str) for each target class (integer). If not None, gFun predictions
        will be converted and stored as target string classes. (default = None)
    output_dir : Path
        Dir where to store the output csv file.
        (default = "results/inference-preds")
    """
    os.makedirs(output_dir, exist_ok=True)
    df = pd.DataFrame()
    # Iterate languages in a fixed order so the per-language lists stay aligned.
    langs = sorted(preds.keys())
    _ids = []
    _preds = []
    _targets = []
    _langs = []
    for lang in langs:
        # argmax over the class axis turns score vectors into integer labels
        _preds.extend(preds[lang].argmax(axis=1).tolist())
        _langs.extend([lang] * len(preds[lang]))
        _ids.extend(doc_ids[lang])
    df["doc_id"] = _ids
    df["document_language"] = _langs
    df["gfun_prediction"] = _preds
    if targets is not None:
        # Must follow the same language order used for the predictions above.
        for lang in langs:
            _targets.extend(targets[lang].argmax(axis=1).tolist())
        df["document_true_label"] = _targets
    if category_mapper is not None:
        # Row index -> category-name mapping; translate integer predictions
        # into human-readable string classes.
        mapper = pd.read_csv(category_mapper).to_dict()["category"]
        df["gfun_string_prediction"] = [mapper[p] for p in _preds]
    output_file = f"{output_dir}/{dataset_name}.csv"
    print(f"Storing predictions in: {output_file}")
    df.to_csv(output_file, index=False)


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        "--datapath",
        type=str,
        default="~/datasets/rai/csv",
        help="Path to csv file containing the documents to be classified",
    )
    parser.add_argument("--nlabels", type=int, default=28)
    parser.add_argument(
        "--muse_dir",
        type=str,
        default="~/resources/muse_embeddings",
        help="Path to muse embeddings",
    )
    parser.add_argument(
        "--trained_gfun",
        type=str,
        default="rai_pmt_mean_231029",
        help="name of the trained gfun instance",
    )
    args = parser.parse_args()
    main(args)