"""Run inference with a previously trained gFun model on an unlabeled dataset."""
import os

from dataManager.gFunDataset import SimpleGfunDataset
from gfun.generalizedFunnelling import GeneralizedFunnelling
from main import save_preds


def main(args):
    """Load the inference data, restore a trained gFun and save its predictions.

    Args:
        args: parsed command-line namespace providing ``datadir``,
            ``nlabels``, ``muse_dir``, ``trained_gfun`` and
            ``trained_transformer``.
    """
    # argparse does not expand "~" in default values, so expand it here.
    # os.path.expanduser leaves paths without a leading "~" unchanged.
    datadir = os.path.expanduser(args.datadir)
    muse_dir = os.path.expanduser(args.muse_dir)
    trained_transformer = os.path.expanduser(args.trained_transformer)

    dataset = SimpleGfunDataset(
        dataset_name="inference-dataset",
        datadir=datadir,
        multilabel=False,
        only_inference=True,
        labels=args.nlabels,
    )
    # Only the test split is used; its second element (targets) is discarded.
    lX, _ = dataset.test()
    print("Ok")

    # Restore a trained model (load_trained/load_meta) instead of training;
    # classification_type matches multilabel=False above.
    gfun = GeneralizedFunnelling(
        dataset_name="inference",
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        classification_type="singlelabel",
        embed_dir=muse_dir,
        posterior=True,
        multilingual=True,
        textual_transformer=True,
        load_trained=args.trained_gfun,
        load_meta=True,
        trained_text_trf=trained_transformer,
    )

    predictions = gfun.transform(lX)
    # No ground-truth targets are available at inference time.
    save_preds(preds=predictions, targets=None, config="inference")


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        "--datadir",
        type=str,
        default="~/datasets/rai/csv",
        help="directory should contain train.csv and test.csv. Train.csv is not required at inference time!",
    )
    parser.add_argument("--nlabels", type=int, default=28)
    parser.add_argument(
        "--muse_dir",
        type=str,
        default="~/resources/muse_embeddings",
        help="Path to muse embeddings",
    )
    parser.add_argument(
        "--trained_gfun",
        type=str,
        default="rai_pmt_mean_231029",
        help="name of the trained gfun instance",
    )
    parser.add_argument(
        "--trained_transformer",
        type=str,
        default="hf_models/mbert-fewshot-rai-full/checkpoint-5150-small",
        help="path to fine-tuned transformer",
    )
    args = parser.parse_args()
    main(args)