gfun_multimodal/infer.py

49 lines
1.6 KiB
Python

from dataManager.gFunDataset import SimpleGfunDataset
from gfun.generalizedFunnelling import GeneralizedFunnelling
from main import save_preds
# Name handed to SimpleGfunDataset as `dataset_name` in main(); the value
# itself is arbitrary (can be anything).
DATASET_NAME = "inference-dataset"
def main(args):
    """Run a previously trained gFun model on an inference-only dataset.

    Builds a ``SimpleGfunDataset`` from ``args.datadir``, fetches its test
    split, instantiates ``GeneralizedFunnelling`` with ``load_trained`` set to
    ``args.trained_gfun``, transforms the test documents and passes the
    resulting predictions to ``save_preds`` (with no targets).

    Args:
        args: Namespace exposing ``datadir``, ``nlabels``, ``muse_dir`` and
            ``trained_gfun`` (as produced by the script's argument parser).
    """
    # Function-scope import so this block stays self-contained.
    from os.path import expanduser

    # The command-line defaults start with "~". The shell only expands a tilde
    # typed on the command line, never a default baked into the parser, so the
    # home directory is resolved here before the paths are used.
    # expanduser() leaves paths without a leading "~" untouched.
    datadir = expanduser(args.datadir)
    muse_dir = expanduser(args.muse_dir)

    dataset = SimpleGfunDataset(
        dataset_name=DATASET_NAME,
        datadir=datadir,
        multilabel=False,
        only_inference=True,
        labels=args.nlabels,
    )
    # Second element of the test split is discarded (no targets are used here).
    lX, _ = dataset.test()
    print("Ok")

    gfun = GeneralizedFunnelling(
        dataset_name=dataset.get_name(),
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        classification_type="singlelabel",
        embed_dir=muse_dir,
        posterior=True,
        multilingual=True,
        textual_transformer=True,
        load_trained=args.trained_gfun,
        load_meta=True,
    )

    predictions = gfun.transform(lX)
    save_preds(
        preds=predictions,
        targets=None,
        config=f"Inference dataset: {dataset.get_name()}",
    )
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line interface for a stand-alone inference run.
    cli = ArgumentParser()
    cli.add_argument(
        "--datadir",
        type=str,
        default="~/datasets/rai/csv",
        help="directory should contain train.csv and test.csv. Train.csv is not required at inferece time!",
    )
    cli.add_argument(
        "--nlabels",
        type=int,
        default=28,
    )
    cli.add_argument(
        "--muse_dir",
        type=str,
        default="~/resources/muse_embeddings",
        help="Path to muse embeddings",
    )
    cli.add_argument(
        "--trained_gfun",
        type=str,
        default="rai_pmt_mean_231029",
        help="name of the trained gfun instance",
    )

    # Parse the options and hand them straight to the entry point.
    main(cli.parse_args())