diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py index adf38fa..9374fa1 100644 --- a/dataManager/gFunDataset.py +++ b/dataManager/gFunDataset.py @@ -106,7 +106,6 @@ class SimpleGfunDataset: return self.tr_langs, self.te_langs, self.all_langs - def _set_datalang(self, data: pd.DataFrame): return {lang: data[data.lang == lang] for lang in self.all_langs} @@ -162,6 +161,9 @@ class SimpleGfunDataset: one_hot_matrix = np.zeros((len(indices), n_labels)) one_hot_matrix[np.arange(len(indices)), indices] = 1 return one_hot_matrix + + def get_name(self): + return self.name def _mask_numbers(data): diff --git a/infer.py b/infer.py index 39bffce..0a57378 100644 --- a/infer.py +++ b/infer.py @@ -2,9 +2,11 @@ from dataManager.gFunDataset import SimpleGfunDataset from gfun.generalizedFunnelling import GeneralizedFunnelling from main import save_preds +DATASET_NAME = "inference-dataset" # can be anything + def main(args): dataset = SimpleGfunDataset( - dataset_name="inference-dataset", + dataset_name=DATASET_NAME, datadir=args.datadir, multilabel=False, only_inference=True, @@ -15,7 +17,7 @@ def main(args): print("Ok") gfun = GeneralizedFunnelling( - dataset_name="inference", + dataset_name=dataset.get_name(), langs=dataset.langs(), num_labels=dataset.num_labels(), classification_type="singlelabel", @@ -31,7 +33,7 @@ def main(args): save_preds( preds=predictions, targets=None, - config="inference" + config=f"Inference dataset: {dataset.get_name()}" ) if __name__ == "__main__": diff --git a/main.py b/main.py index 8fa47a4..b4d452f 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import os from argparse import ArgumentParser from time import time @@ -138,6 +139,7 @@ def main(args): def save_preds(preds, targets, config="unk", dataset="unk"): + os.makedirs("results/preds") df = pd.DataFrame() langs = sorted(preds.keys()) _preds = []