diff --git a/dataManager/multilingualDataset.py b/dataManager/multilingualDataset.py index a4c6bc3..fca5dfb 100644 --- a/dataManager/multilingualDataset.py +++ b/dataManager/multilingualDataset.py @@ -227,12 +227,10 @@ class MultilingualDataset: from os.path import expanduser train = pd.read_csv(expanduser(path_tr)) test = pd.read_csv(expanduser(path_te)) - all_labels = set(train.label.to_list()).union(set(test.label.to_list())) for lang in train.lang.unique(): tr_datalang = train.loc[train["lang"] == lang] Xtr = tr_datalang.text.to_list() tr_labels = tr_datalang.label.to_list() - # Ytr = np.zeros((len(Xtr), len(all_labels)), dtype=int) Ytr = np.zeros((len(Xtr), 28), dtype=int) for j, i in enumerate(tr_labels): Ytr[j, i] = 1 @@ -240,7 +238,6 @@ class MultilingualDataset: te_datalang = test.loc[test["lang"] == lang] Xte = te_datalang.text.to_list() te_labels = te_datalang.label.to_list() - # Yte = np.zeros((len(Xte), len(all_labels)), dtype=int) Yte = np.zeros((len(Xte), 28), dtype=int) for j, i in enumerate(te_labels): Yte[j, i] = 1 @@ -257,7 +254,6 @@ class MultilingualDataset: return self - def _mask_numbers(data): mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b") mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")