49 lines
1.7 KiB
Python
49 lines
1.7 KiB
Python
from dataManager.gFunDataset import SimpleGfunDataset
|
|
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
|
from main import save_preds
|
|
|
|
def _expand_user(path):
    """Return *path* with a leading ``~`` expanded.

    Only ``str`` values are touched; anything else (``None``, path objects,
    ...) is handed back unchanged so existing callers keep working.
    """
    import os  # local import keeps this block self-contained

    return os.path.expanduser(path) if isinstance(path, str) else path


def main(args):
    """Run inference and store the resulting predictions.

    Builds a ``SimpleGfunDataset`` in inference-only mode from
    ``args.datadir``, takes its test split, instantiates
    ``GeneralizedFunnelling`` with ``load_trained=args.trained_gfun`` and
    ``trained_text_trf=args.trained_transformer``, transforms the test
    documents and passes the output to ``save_preds`` (with no targets).

    Args:
        args: namespace-like object exposing the attributes ``datadir``,
            ``nlabels``, ``muse_dir``, ``trained_gfun`` and
            ``trained_transformer`` (see the CLI definition in this script).
    """
    # The CLI defaults start with "~". Defaults never pass through a shell, so
    # the tilde would otherwise reach the consumers verbatim. Expansion is
    # idempotent, hence harmless if the consumers expand the path as well.
    datadir = _expand_user(args.datadir)
    muse_dir = _expand_user(args.muse_dir)

    dataset = SimpleGfunDataset(
        dataset_name="inference-dataset",
        datadir=datadir,
        multilabel=False,
        only_inference=True,
        labels=args.nlabels,
    )

    # Second element of the test split is discarded: save_preds below is
    # called with targets=None.
    lX, _ = dataset.test()
    print("Ok")

    gfun = GeneralizedFunnelling(
        dataset_name="inference",
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        classification_type="singlelabel",
        embed_dir=muse_dir,
        posterior=True,
        multilingual=True,
        textual_transformer=True,
        load_trained=args.trained_gfun,
        load_meta=True,
        trained_text_trf=args.trained_transformer,
    )

    predictions = gfun.transform(lX)
    save_preds(
        preds=predictions,
        targets=None,
        config="inference",
    )
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command-line options and run inference.
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        "--datadir",
        type=str,
        default="~/datasets/rai/csv",
        help=(
            "directory should contain train.csv and test.csv. "
            "Train.csv is not required at inference time!"
        ),
    )
    parser.add_argument(
        "--nlabels",
        type=int,
        default=28,
        help="number of labels (forwarded to the dataset as `labels`)",
    )
    parser.add_argument(
        "--muse_dir",
        type=str,
        default="~/resources/muse_embeddings",
        help="Path to muse embeddings",
    )
    parser.add_argument(
        "--trained_gfun",
        type=str,
        default="rai_pmt_mean_231029",
        help="name of the trained gfun instance",
    )
    parser.add_argument(
        "--trained_transformer",
        type=str,  # explicit for consistency with the other string options
        default="hf_models/mbert-fewshot-rai-full/checkpoint-5150-small",
        help="path to fine-tuned transformer",
    )

    args = parser.parse_args()
    main(args)
|
|
|