dataset get_name

This commit is contained in:
Andrea Pedrotti 2023-11-06 10:52:07 +01:00
parent 1b58fed14d
commit 4615bc3857
3 changed files with 10 additions and 4 deletions

View File

@ -106,7 +106,6 @@ class SimpleGfunDataset:
return self.tr_langs, self.te_langs, self.all_langs return self.tr_langs, self.te_langs, self.all_langs
def _set_datalang(self, data: pd.DataFrame): def _set_datalang(self, data: pd.DataFrame):
return {lang: data[data.lang == lang] for lang in self.all_langs} return {lang: data[data.lang == lang] for lang in self.all_langs}
@ -163,6 +162,9 @@ class SimpleGfunDataset:
one_hot_matrix[np.arange(len(indices)), indices] = 1 one_hot_matrix[np.arange(len(indices)), indices] = 1
return one_hot_matrix return one_hot_matrix
def get_name(self):
return self.name
def _mask_numbers(data): def _mask_numbers(data):
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b") mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")

View File

@ -2,9 +2,11 @@ from dataManager.gFunDataset import SimpleGfunDataset
from gfun.generalizedFunnelling import GeneralizedFunnelling from gfun.generalizedFunnelling import GeneralizedFunnelling
from main import save_preds from main import save_preds
DATASET_NAME = "inference-dataset" # can be anything
def main(args): def main(args):
dataset = SimpleGfunDataset( dataset = SimpleGfunDataset(
dataset_name="inference-dataset", dataset_name=DATASET_NAME,
datadir=args.datadir, datadir=args.datadir,
multilabel=False, multilabel=False,
only_inference=True, only_inference=True,
@ -15,7 +17,7 @@ def main(args):
print("Ok") print("Ok")
gfun = GeneralizedFunnelling( gfun = GeneralizedFunnelling(
dataset_name="inference", dataset_name=dataset.get_name(),
langs=dataset.langs(), langs=dataset.langs(),
num_labels=dataset.num_labels(), num_labels=dataset.num_labels(),
classification_type="singlelabel", classification_type="singlelabel",
@ -31,7 +33,7 @@ def main(args):
save_preds( save_preds(
preds=predictions, preds=predictions,
targets=None, targets=None,
config="inference" config=f"Inference dataset: {dataset.get_name()}"
) )
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,3 +1,4 @@
import os
from argparse import ArgumentParser from argparse import ArgumentParser
from time import time from time import time
@ -138,6 +139,7 @@ def main(args):
def save_preds(preds, targets, config="unk", dataset="unk"): def save_preds(preds, targets, config="unk", dataset="unk"):
os.makedirs("results/preds")
df = pd.DataFrame() df = pd.DataFrame()
langs = sorted(preds.keys()) langs = sorted(preds.keys())
_preds = [] _preds = []