From 6995854e3d7636469acbff1530a4c17e567ccef5 Mon Sep 17 00:00:00 2001 From: andreapdr Date: Mon, 3 Jul 2023 19:03:42 +0200 Subject: [PATCH] hardcodednlabels f or rai datasets --- dataManager/multilingualDataset.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dataManager/multilingualDataset.py b/dataManager/multilingualDataset.py index a4c6bc3..fca5dfb 100644 --- a/dataManager/multilingualDataset.py +++ b/dataManager/multilingualDataset.py @@ -227,12 +227,10 @@ class MultilingualDataset: from os.path import expanduser train = pd.read_csv(expanduser(path_tr)) test = pd.read_csv(expanduser(path_te)) - all_labels = set(train.label.to_list()).union(set(test.label.to_list())) for lang in train.lang.unique(): tr_datalang = train.loc[train["lang"] == lang] Xtr = tr_datalang.text.to_list() tr_labels = tr_datalang.label.to_list() - # Ytr = np.zeros((len(Xtr), len(all_labels)), dtype=int) Ytr = np.zeros((len(Xtr), 28), dtype=int) for j, i in enumerate(tr_labels): Ytr[j, i] = 1 @@ -240,7 +238,6 @@ class MultilingualDataset: te_datalang = test.loc[test["lang"] == lang] Xte = te_datalang.text.to_list() te_labels = te_datalang.label.to_list() - # Yte = np.zeros((len(Xte), len(all_labels)), dtype=int) Yte = np.zeros((len(Xte), 28), dtype=int) for j, i in enumerate(te_labels): Yte[j, i] = 1 @@ -257,7 +254,6 @@ class MultilingualDataset: return self - def _mask_numbers(data): mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b") mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")