hardcodednlabels f or rai datasets
This commit is contained in:
parent
55e12505c0
commit
6995854e3d
|
|
@ -227,12 +227,10 @@ class MultilingualDataset:
|
||||||
from os.path import expanduser
|
from os.path import expanduser
|
||||||
train = pd.read_csv(expanduser(path_tr))
|
train = pd.read_csv(expanduser(path_tr))
|
||||||
test = pd.read_csv(expanduser(path_te))
|
test = pd.read_csv(expanduser(path_te))
|
||||||
all_labels = set(train.label.to_list()).union(set(test.label.to_list()))
|
|
||||||
for lang in train.lang.unique():
|
for lang in train.lang.unique():
|
||||||
tr_datalang = train.loc[train["lang"] == lang]
|
tr_datalang = train.loc[train["lang"] == lang]
|
||||||
Xtr = tr_datalang.text.to_list()
|
Xtr = tr_datalang.text.to_list()
|
||||||
tr_labels = tr_datalang.label.to_list()
|
tr_labels = tr_datalang.label.to_list()
|
||||||
# Ytr = np.zeros((len(Xtr), len(all_labels)), dtype=int)
|
|
||||||
Ytr = np.zeros((len(Xtr), 28), dtype=int)
|
Ytr = np.zeros((len(Xtr), 28), dtype=int)
|
||||||
for j, i in enumerate(tr_labels):
|
for j, i in enumerate(tr_labels):
|
||||||
Ytr[j, i] = 1
|
Ytr[j, i] = 1
|
||||||
|
|
@ -240,7 +238,6 @@ class MultilingualDataset:
|
||||||
te_datalang = test.loc[test["lang"] == lang]
|
te_datalang = test.loc[test["lang"] == lang]
|
||||||
Xte = te_datalang.text.to_list()
|
Xte = te_datalang.text.to_list()
|
||||||
te_labels = te_datalang.label.to_list()
|
te_labels = te_datalang.label.to_list()
|
||||||
# Yte = np.zeros((len(Xte), len(all_labels)), dtype=int)
|
|
||||||
Yte = np.zeros((len(Xte), 28), dtype=int)
|
Yte = np.zeros((len(Xte), 28), dtype=int)
|
||||||
for j, i in enumerate(te_labels):
|
for j, i in enumerate(te_labels):
|
||||||
Yte[j, i] = 1
|
Yte[j, i] = 1
|
||||||
|
|
@ -257,7 +254,6 @@ class MultilingualDataset:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _mask_numbers(data):
|
def _mask_numbers(data):
|
||||||
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
|
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
|
||||||
mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")
|
mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue