# TODO: this should be an instance of an abstract MultilingualDataset

from abc import ABC, abstractmethod
from os.path import join, expanduser
import pickle
import re

import numpy as np
from scipy.sparse import csr_matrix, issparse
from tqdm import tqdm


# class RcvMultilingualDataset(MultilingualDataset):
class RcvMultilingualDataset:
    def __init__(self, run="0"):
        self.dataset_name = "rcv1-2"
        self.dataset_path = expanduser(
            f"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run{run}.pickle"
        )

    def load(self):
        # keep a handle on the unpickled payload instead of discarding it
        with open(self.dataset_path, "rb") as f:
            self.data = pickle.load(f)
        return self


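# Minimal usage sketch for RcvMultilingualDataset (hypothetical; assumes the
# processed rcv1-2 pickle exists at the machine-specific path built in __init__):
#   rcv = RcvMultilingualDataset(run="0").load()
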
class MultilingualDataset:
    """
    A multilingual dataset is a dictionary of training and test documents indexed by language code.
    Train and test sets are represented as tuples of the type (X,Y,ids), where X is a matrix representation of the
    documents (e.g., a document-by-term sparse csr_matrix), Y is a document-by-label binary np.array indicating the
    labels of each document, and ids is a list of document-identifiers from the original collection.
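
    A minimal usage sketch (language codes, matrices, and the file name are
    hypothetical, shown for illustration only):

        dataset = MultilingualDataset("my-corpus")
        dataset.add("en", Xtr_en, Ytr_en, Xte_en, Yte_en)
        dataset.add("it", Xtr_it, Ytr_it, Xte_it, Yte_it)
        dataset.save("my-corpus.pickle")
        lXtr, lYtr = MultilingualDataset.load("my-corpus.pickle").training()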
    """

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        self.multiling_dataset = {}
        # print(f"[Init Multilingual Dataset: {self.dataset_name}]")

    def add(self, lang, Xtr, Ytr, Xte, Yte, tr_ids=None, te_ids=None):
        self.multiling_dataset[lang] = ((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids))

    def save(self, file):
        self.sort_indexes()
        with open(file, "wb") as f:
            pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
        return self

    def __getitem__(self, item):
        if item in self.langs():
            return self.multiling_dataset[item]
        return None

    @classmethod
    def load(cls, file):
        with open(file, "rb") as f:
            data = pickle.load(f)
        data.sort_indexes()
        return data

    @classmethod
    def load_ids(cls, file):
        with open(file, "rb") as f:
            data = pickle.load(f)
        tr_ids = {
            lang: tr_ids
            for (lang, ((_, _, tr_ids), (_, _, _))) in data.multiling_dataset.items()
        }
        te_ids = {
            lang: te_ids
            for (lang, ((_, _, _), (_, _, te_ids))) in data.multiling_dataset.items()
        }
        return tr_ids, te_ids

    def sort_indexes(self):
        for lang, ((Xtr, _, _), (Xte, _, _)) in self.multiling_dataset.items():
            if issparse(Xtr):
                Xtr.sort_indices()
            if issparse(Xte):
                Xte.sort_indices()

    def set_view(self, categories=None, languages=None):
        if categories is not None:
            if isinstance(categories, int):
                categories = np.array([categories])
            elif isinstance(categories, list):
                categories = np.array(categories)
            self.categories_view = categories
        if languages is not None:
            self.languages_view = languages
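
    # Illustrative use of set_view (category indices and language codes are
    # hypothetical): every accessor below then works on the restricted view.
    #
    #     dataset.set_view(categories=[0, 5], languages=["en", "it"])
    #     lXtr, lYtr = dataset.training()  # only "en"/"it"; label columns 0 and 5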

    def training(self, mask_numbers=False, target_as_csr=False):
        return self.lXtr(mask_numbers), self.lYtr(as_csr=target_as_csr)

    def test(self, mask_numbers=False, target_as_csr=False):
        return self.lXte(mask_numbers), self.lYte(as_csr=target_as_csr)

    def lXtr(self, mask_numbers=False):
        proc = lambda x: _mask_numbers(x) if mask_numbers else x
        return {
            lang: proc(Xtr)
            for (lang, ((Xtr, _, _), _)) in self.multiling_dataset.items()
            if lang in self.langs()
        }

    def lXte(self, mask_numbers=False):
        proc = lambda x: _mask_numbers(x) if mask_numbers else x
        return {
            lang: proc(Xte)
            for (lang, (_, (Xte, _, _))) in self.multiling_dataset.items()
            if lang in self.langs()
        }

    def lYtr(self, as_csr=False):
        lY = {
            lang: self.cat_view(Ytr)
            for (lang, ((_, Ytr, _), _)) in self.multiling_dataset.items()
            if lang in self.langs()
        }
        if as_csr:
            lY = {l: csr_matrix(Y) for l, Y in lY.items()}
        return lY

    def lYte(self, as_csr=False):
        lY = {
            lang: self.cat_view(Yte)
            for (lang, (_, (_, Yte, _))) in self.multiling_dataset.items()
            if lang in self.langs()
        }
        if as_csr:
            lY = {l: csr_matrix(Y) for l, Y in lY.items()}
        return lY

    def cat_view(self, Y):
        if hasattr(self, "categories_view"):
            return Y[:, self.categories_view]
        else:
            return Y

    def langs(self):
        if hasattr(self, "languages_view"):
            langs = self.languages_view
        else:
            langs = sorted(self.multiling_dataset.keys())
        return langs

    def num_labels(self):
        return self.num_categories()

    def num_categories(self):
        return self.lYtr()[self.langs()[0]].shape[1]

    def show_dimensions(self):
        def shape(X):
            return X.shape if hasattr(X, "shape") else len(X)

        for lang, (
            (Xtr, Ytr, IDtr),
            (Xte, Yte, IDte),
        ) in self.multiling_dataset.items():
            if lang not in self.langs():
                continue
            print(
                "Lang {}, Xtr={}, ytr={}, Xte={}, yte={}".format(
                    lang,
                    shape(Xtr),
                    self.cat_view(Ytr).shape,
                    shape(Xte),
                    self.cat_view(Yte).shape,
                )
            )

    def show_category_prevalences(self):
        nC = self.num_categories()
        accum_tr = np.zeros(nC, dtype=int)
        accum_te = np.zeros(nC, dtype=int)
        in_langs = np.zeros(
            nC, dtype=int
        )  # count languages with at least one positive example (per category)
        for lang, (
            (Xtr, Ytr, IDtr),
            (Xte, Yte, IDte),
        ) in self.multiling_dataset.items():
            if lang not in self.langs():
                continue
            prev_train = np.sum(self.cat_view(Ytr), axis=0)
            prev_test = np.sum(self.cat_view(Yte), axis=0)
            accum_tr += prev_train
            accum_te += prev_test
            in_langs += (prev_train > 0) * 1
            print(lang + "-train", prev_train)
            print(lang + "-test", prev_test)
        print("all-train", accum_tr)
        print("all-test", accum_te)

        return accum_tr, accum_te, in_langs

    def set_labels(self, labels):
        self.labels = labels

    def reduce_data(self, langs=("it", "en"), maxn=50):
        print(f"- Reducing data: {langs} with max {maxn} documents...\n")
        self.set_view(languages=langs)

        data = {
            lang: self._reduce(data, maxn)
            for lang, data in self.multiling_dataset.items()
            if lang in langs
        }
        self.multiling_dataset = data
        return self

    def _reduce(self, multilingual_dataset, maxn):
        new_data = []
        for split in multilingual_dataset:
            docs, labels, ids = split
            if ids is not None:
                new_data.append((docs[:maxn], labels[:maxn], ids[:maxn]))
            else:
                new_data.append((docs[:maxn], labels[:maxn], None))
        return new_data


def _mask_numbers(data):
    """Mask numeric tokens, replacing each with a placeholder encoding its digit count."""
    mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
    mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")
    mask_3digit = re.compile(r"\s[\+-]?\d{3}([\.,]\d*)*\b")
    mask_2digit = re.compile(r"\s[\+-]?\d{2}([\.,]\d*)*\b")
    mask_1digit = re.compile(r"\s[\+-]?\d{1}([\.,]\d*)*\b")
    masked = []
    for text in tqdm(data, desc="masking numbers"):
        text = " " + text
        text = mask_moredigit.sub(" MoreDigitMask", text)
        text = mask_4digit.sub(" FourDigitMask", text)
        text = mask_3digit.sub(" ThreeDigitMask", text)
        text = mask_2digit.sub(" TwoDigitMask", text)
        text = mask_1digit.sub(" OneDigitMask", text)
        masked.append(text.replace(".", "").replace(",", "").strip())
    return masked
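
# Illustrative effect of _mask_numbers (hypothetical input, not from the corpus):
#   _mask_numbers(["won 4 games in 1997"])
#   -> ["won OneDigitMask games in FourDigitMask"]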


if __name__ == "__main__":
    DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    print(DATAPATH)
    # load is a classmethod; instantiating MultilingualDataset() here would
    # fail, since __init__ requires a dataset_name
    dataset = MultilingualDataset.load(DATAPATH)
    dataset.show_dimensions()