dataset get_name

This commit is contained in:
Andrea Pedrotti 2023-11-06 10:52:07 +01:00
parent 1b58fed14d
commit 4615bc3857
3 changed files with 10 additions and 4 deletions

View File

@ -106,7 +106,6 @@ class SimpleGfunDataset:
return self.tr_langs, self.te_langs, self.all_langs
def _set_datalang(self, data: pd.DataFrame):
    """Partition `data` by language.

    For every language code in `self.all_langs`, select the rows of `data`
    whose `lang` column equals that code.

    Returns a dict mapping each language code to the matching sub-frame,
    in the iteration order of `self.all_langs`. A code with no matching
    rows maps to an empty selection.
    """
    per_language = {}
    for code in self.all_langs:
        # Boolean-mask selection: keep only the rows tagged with this code.
        matches_code = data.lang == code
        per_language[code] = data[matches_code]
    return per_language
@ -162,6 +161,9 @@ class SimpleGfunDataset:
one_hot_matrix = np.zeros((len(indices), n_labels))
one_hot_matrix[np.arange(len(indices)), indices] = 1
return one_hot_matrix
def get_name(self):
    """Return the dataset's name, i.e. the value held in `self.name`."""
    return self.name
def _mask_numbers(data):

View File

@ -2,9 +2,11 @@ from dataManager.gFunDataset import SimpleGfunDataset
from gfun.generalizedFunnelling import GeneralizedFunnelling
from main import save_preds
DATASET_NAME = "inference-dataset"  # arbitrary label (can be anything); passed as dataset_name to SimpleGfunDataset in main()
def main(args):
dataset = SimpleGfunDataset(
dataset_name="inference-dataset",
dataset_name=DATASET_NAME,
datadir=args.datadir,
multilabel=False,
only_inference=True,
@ -15,7 +17,7 @@ def main(args):
print("Ok")
gfun = GeneralizedFunnelling(
dataset_name="inference",
dataset_name=dataset.get_name(),
langs=dataset.langs(),
num_labels=dataset.num_labels(),
classification_type="singlelabel",
@ -31,7 +33,7 @@ def main(args):
save_preds(
preds=predictions,
targets=None,
config="inference"
config=f"Inference dataset: {dataset.get_name()}"
)
if __name__ == "__main__":

View File

@ -1,3 +1,4 @@
import os
from argparse import ArgumentParser
from time import time
@ -138,6 +139,7 @@ def main(args):
def save_preds(preds, targets, config="unk", dataset="unk"):
os.makedirs("results/preds")
df = pd.DataFrame()
langs = sorted(preds.keys())
_preds = []