dataset get_name
This commit is contained in:
parent
1b58fed14d
commit
4615bc3857
|
|
@ -106,7 +106,6 @@ class SimpleGfunDataset:
|
||||||
|
|
||||||
return self.tr_langs, self.te_langs, self.all_langs
|
return self.tr_langs, self.te_langs, self.all_langs
|
||||||
|
|
||||||
|
|
||||||
def _set_datalang(self, data: pd.DataFrame):
|
def _set_datalang(self, data: pd.DataFrame):
|
||||||
return {lang: data[data.lang == lang] for lang in self.all_langs}
|
return {lang: data[data.lang == lang] for lang in self.all_langs}
|
||||||
|
|
||||||
|
|
@ -163,6 +162,9 @@ class SimpleGfunDataset:
|
||||||
one_hot_matrix[np.arange(len(indices)), indices] = 1
|
one_hot_matrix[np.arange(len(indices)), indices] = 1
|
||||||
return one_hot_matrix
|
return one_hot_matrix
|
||||||
|
|
||||||
|
def get_name(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
def _mask_numbers(data):
|
def _mask_numbers(data):
|
||||||
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
|
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
|
||||||
|
|
|
||||||
8
infer.py
8
infer.py
|
|
@ -2,9 +2,11 @@ from dataManager.gFunDataset import SimpleGfunDataset
|
||||||
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||||
from main import save_preds
|
from main import save_preds
|
||||||
|
|
||||||
|
DATASET_NAME = "inference-dataset" # can be anything
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
dataset = SimpleGfunDataset(
|
dataset = SimpleGfunDataset(
|
||||||
dataset_name="inference-dataset",
|
dataset_name=DATASET_NAME,
|
||||||
datadir=args.datadir,
|
datadir=args.datadir,
|
||||||
multilabel=False,
|
multilabel=False,
|
||||||
only_inference=True,
|
only_inference=True,
|
||||||
|
|
@ -15,7 +17,7 @@ def main(args):
|
||||||
print("Ok")
|
print("Ok")
|
||||||
|
|
||||||
gfun = GeneralizedFunnelling(
|
gfun = GeneralizedFunnelling(
|
||||||
dataset_name="inference",
|
dataset_name=dataset.get_name(),
|
||||||
langs=dataset.langs(),
|
langs=dataset.langs(),
|
||||||
num_labels=dataset.num_labels(),
|
num_labels=dataset.num_labels(),
|
||||||
classification_type="singlelabel",
|
classification_type="singlelabel",
|
||||||
|
|
@ -31,7 +33,7 @@ def main(args):
|
||||||
save_preds(
|
save_preds(
|
||||||
preds=predictions,
|
preds=predictions,
|
||||||
targets=None,
|
targets=None,
|
||||||
config="inference"
|
config=f"Inference dataset: {dataset.get_name()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
2
main.py
2
main.py
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
|
|
@ -138,6 +139,7 @@ def main(args):
|
||||||
|
|
||||||
|
|
||||||
def save_preds(preds, targets, config="unk", dataset="unk"):
|
def save_preds(preds, targets, config="unk", dataset="unk"):
|
||||||
|
os.makedirs("results/preds")
|
||||||
df = pd.DataFrame()
|
df = pd.DataFrame()
|
||||||
langs = sorted(preds.keys())
|
langs = sorted(preds.keys())
|
||||||
_preds = []
|
_preds = []
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue