dataset get_name
This commit is contained in:
parent
1b58fed14d
commit
4615bc3857
|
|
@ -106,7 +106,6 @@ class SimpleGfunDataset:
|
|||
|
||||
return self.tr_langs, self.te_langs, self.all_langs
|
||||
|
||||
|
||||
def _set_datalang(self, data: pd.DataFrame):
|
||||
return {lang: data[data.lang == lang] for lang in self.all_langs}
|
||||
|
||||
|
|
@ -162,6 +161,9 @@ class SimpleGfunDataset:
|
|||
one_hot_matrix = np.zeros((len(indices), n_labels))
|
||||
one_hot_matrix[np.arange(len(indices)), indices] = 1
|
||||
return one_hot_matrix
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
|
||||
|
||||
def _mask_numbers(data):
|
||||
|
|
|
|||
8
infer.py
8
infer.py
|
|
@ -2,9 +2,11 @@ from dataManager.gFunDataset import SimpleGfunDataset
|
|||
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||
from main import save_preds
|
||||
|
||||
DATASET_NAME = "inference-dataset" # can be anything
|
||||
|
||||
def main(args):
|
||||
dataset = SimpleGfunDataset(
|
||||
dataset_name="inference-dataset",
|
||||
dataset_name=DATASET_NAME,
|
||||
datadir=args.datadir,
|
||||
multilabel=False,
|
||||
only_inference=True,
|
||||
|
|
@ -15,7 +17,7 @@ def main(args):
|
|||
print("Ok")
|
||||
|
||||
gfun = GeneralizedFunnelling(
|
||||
dataset_name="inference",
|
||||
dataset_name=dataset.get_name(),
|
||||
langs=dataset.langs(),
|
||||
num_labels=dataset.num_labels(),
|
||||
classification_type="singlelabel",
|
||||
|
|
@ -31,7 +33,7 @@ def main(args):
|
|||
save_preds(
|
||||
preds=predictions,
|
||||
targets=None,
|
||||
config="inference"
|
||||
config=f"Inference dataset: {dataset.get_name()}"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue