gfun_multimodal/dataManager/multilingualDataset.py

251 lines
8.4 KiB
Python

# TODO: this should be a subclass of an abstract MultilingualDataset base class
import pickle
import re
from abc import ABC, abstractmethod
from os.path import expanduser, join

import numpy as np
from scipy.sparse import csr_matrix, issparse
from tqdm import tqdm
# TODO: derive from an abstract MultilingualDataset base class once one exists.
class RcvMultilingualDataset:
    """Handle on a pre-processed RCV1/RCV2 multilingual pickle.

    The constructor only computes the file location
    (``~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run{run}.pickle``);
    nothing is read from disk until :meth:`load` is called.

    Args:
        run: Identifier of the data split ("run"); it is interpolated into the
            pickle file name.
    """

    def __init__(
        self,
        run="0",
    ):
        self.dataset_name = "rcv1-2"
        self.dataset_path = expanduser(
            f"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run{run}.pickle"
        )
        # Unpickled content, populated by load(); None until then.
        self.data = None

    def load(self):
        """Unpickle the file at ``self.dataset_path`` into ``self.data``.

        Previously the unpickled object was bound to a local and discarded, so
        calling load() had no observable effect.

        Returns:
            self, so calls can be chained.

        Raises:
            FileNotFoundError: if ``self.dataset_path`` does not exist.
        """
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(self.dataset_path, "rb") as fin:
            self.data = pickle.load(fin)
        return self
class MultilingualDataset:
    """
    A multilingual dataset is a dictionary of training and test documents indexed by language code.
    Train and test sets are represented as tuples of the type (X,Y,ids), where X is a matrix representation of the
    documents (e.g., a document-by-term sparse csr_matrix), Y is a document-by-label binary np.array indicating the
    labels of each document, and ids is a list of document-identifiers from the original collection.

    Two optional "views", set through ``set_view``, restrict what the accessor
    methods expose without modifying the stored data:

    * ``languages_view`` -- the language codes returned by ``langs()``;
    * ``categories_view`` -- the label columns kept by ``cat_view()``.
    """

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        # lang -> ((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids))
        self.multiling_dataset = {}

    def add(self, lang, Xtr, Ytr, Xte, Yte, tr_ids=None, te_ids=None):
        """Store (or overwrite) the train and test splits of language ``lang``."""
        self.multiling_dataset[lang] = ((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids))

    def save(self, file):
        """Pickle this whole object to ``file`` and return ``self``.

        Sparse matrices get their indices sorted first (see ``sort_indexes``).
        """
        self.sort_indexes()
        # The context manager guarantees the file is flushed and closed.
        with open(file, "wb") as fout:
            pickle.dump(self, fout, pickle.HIGHEST_PROTOCOL)
        return self

    def __getitem__(self, item):
        """Return ``((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids))`` for language ``item``.

        Returns None when ``item`` is not among ``self.langs()``.
        """
        if item in self.langs():
            return self.multiling_dataset[item]
        return None

    @classmethod
    def load(cls, file):
        """Unpickle a dataset previously written by ``save``.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        with open(file, "rb") as fin:
            data = pickle.load(fin)
        data.sort_indexes()
        return data

    @classmethod
    def load_ids(cls, file):
        """Read a pickled dataset and return ``(tr_ids, te_ids)``.

        Each is a dict mapping every stored language (the language view is
        ignored) to the train / test document ids given to ``add``.
        """
        with open(file, "rb") as fin:
            data = pickle.load(fin)
        splits = data.multiling_dataset
        tr_ids = {lang: tr for lang, ((_, _, tr), _) in splits.items()}
        te_ids = {lang: te for lang, (_, (_, _, te)) in splits.items()}
        return tr_ids, te_ids

    def sort_indexes(self):
        """Sort, in place, the indices of every sparse train/test matrix."""
        for (Xtr, _, _), (Xte, _, _) in self.multiling_dataset.values():
            if issparse(Xtr):
                Xtr.sort_indices()
            if issparse(Xte):
                Xte.sort_indices()

    def set_view(self, categories=None, languages=None):
        """Restrict the label columns and/or languages exposed by the accessors.

        Args:
            categories: An int, a list, or an index array of label columns.
                None leaves the current category view unchanged.
            languages: A sequence of language codes. None leaves the current
                language view unchanged.
        """
        if categories is not None:
            if isinstance(categories, int):
                categories = np.array([categories])
            elif isinstance(categories, list):
                categories = np.array(categories)
            self.categories_view = categories
        if languages is not None:
            self.languages_view = languages

    def training(self, mask_numbers=False, target_as_csr=False):
        """Return ``(lXtr, lYtr)`` dicts for the languages in the view."""
        return self.lXtr(mask_numbers), self.lYtr(as_csr=target_as_csr)

    def test(self, mask_numbers=False, target_as_csr=False):
        """Return ``(lXte, lYte)`` dicts for the languages in the view."""
        return self.lXte(mask_numbers), self.lYte(as_csr=target_as_csr)

    def _split_view(self, split):
        """Yield ``(lang, (X, Y, ids))`` for viewed languages; split 0 = train, 1 = test."""
        langs = self.langs()
        for lang, splits in self.multiling_dataset.items():
            if lang in langs:
                yield lang, splits[split]

    def _lX(self, split, mask_numbers):
        """Map each viewed language to its X of ``split``, optionally number-masked."""
        # _mask_numbers is only looked up when masking is actually requested.
        proc = _mask_numbers if mask_numbers else (lambda x: x)
        return {lang: proc(X) for lang, (X, _, _) in self._split_view(split)}

    def _lY(self, split, as_csr):
        """Map each viewed language to its category-viewed Y of ``split``."""
        lY = {lang: self.cat_view(Y) for lang, (_, Y, _) in self._split_view(split)}
        if as_csr:
            lY = {lang: csr_matrix(Y) for lang, Y in lY.items()}
        return lY

    def lXtr(self, mask_numbers=False):
        """Training documents per viewed language (see ``_mask_numbers``)."""
        return self._lX(0, mask_numbers)

    def lXte(self, mask_numbers=False):
        """Test documents per viewed language (see ``_mask_numbers``)."""
        return self._lX(1, mask_numbers)

    def lYtr(self, as_csr=False):
        """Training label matrices per viewed language, optionally as csr_matrix."""
        return self._lY(0, as_csr)

    def lYte(self, as_csr=False):
        """Test label matrices per viewed language, optionally as csr_matrix."""
        return self._lY(1, as_csr)

    def cat_view(self, Y):
        """Return ``Y`` restricted to the columns of the category view, if one is set."""
        if hasattr(self, "categories_view"):
            return Y[:, self.categories_view]
        return Y

    def langs(self):
        """Return the language view if set, else all stored languages sorted."""
        if hasattr(self, "languages_view"):
            return self.languages_view
        return sorted(self.multiling_dataset.keys())

    def num_labels(self):
        """Alias of ``num_categories``."""
        return self.num_categories()

    def num_categories(self):
        """Number of (viewed) label columns of the first language in ``langs()``."""
        (_, Ytr, _), _ = self.multiling_dataset[self.langs()[0]]
        return self.cat_view(Ytr).shape[1]

    def show_dimensions(self):
        """Print the shapes of Xtr, Ytr, Xte and Yte for each viewed language."""

        def shape(X):
            # Raw document lists have no .shape; report their length instead.
            return X.shape if hasattr(X, "shape") else len(X)

        langs = self.langs()
        for lang, ((Xtr, Ytr, _), (Xte, Yte, _)) in self.multiling_dataset.items():
            if lang not in langs:
                continue
            print(
                "Lang {}, Xtr={}, ytr={}, Xte={}, yte={}".format(
                    lang,
                    shape(Xtr),
                    self.cat_view(Ytr).shape,
                    shape(Xte),
                    self.cat_view(Yte).shape,
                )
            )

    def show_category_prevalences(self):
        """Print and return per-category positive counts over the viewed languages.

        Returns:
            ``(accum_tr, accum_te, in_langs)``: summed train counts, summed test
            counts, and the number of languages with at least one positive
            training example, each an integer array with one entry per category.
        """
        nC = self.num_categories()
        # builtin int: the np.int alias was removed in NumPy 1.24.
        accum_tr = np.zeros(nC, dtype=int)
        accum_te = np.zeros(nC, dtype=int)
        # count languages with at least one positive example (per category)
        in_langs = np.zeros(nC, dtype=int)
        langs = self.langs()
        for lang, ((_, Ytr, _), (_, Yte, _)) in self.multiling_dataset.items():
            if lang not in langs:
                continue
            prev_train = np.sum(self.cat_view(Ytr), axis=0)
            prev_test = np.sum(self.cat_view(Yte), axis=0)
            accum_tr += prev_train
            accum_te += prev_test
            in_langs += (prev_train > 0) * 1
            print(lang + "-train", prev_train)
            print(lang + "-test", prev_test)
        print("all-train", accum_tr)
        print("all-test", accum_te)
        return accum_tr, accum_te, in_langs

    def set_labels(self, labels):
        """Attach the label names (stored as-is on ``self.labels``)."""
        self.labels = labels

    def reduce_data(self, langs=None, maxn=50):
        """Keep only ``langs`` (default ``["it", "en"]``) and their first ``maxn``
        documents per split, and set the language view to ``langs``.

        Modifies the dataset in place and returns ``self``.
        """
        if langs is None:
            # A fresh list per call: a mutable default would be stored as
            # languages_view and shared across every instance using it.
            langs = ["it", "en"]
        print(f"- Reducing data: {langs} with max {maxn} documents...\n")
        self.set_view(languages=langs)
        self.multiling_dataset = {
            lang: self._reduce(splits, maxn)
            for lang, splits in self.multiling_dataset.items()
            if lang in langs
        }
        return self

    def _reduce(self, multilingual_dataset, maxn):
        """Truncate each ``(X, Y, ids)`` split to its first ``maxn`` entries.

        Returns a tuple of splits, matching the layout stored by ``add``.
        """
        return tuple(
            (docs[:maxn], labels[:maxn], None if ids is None else ids[:maxn])
            for docs, labels, ids in multilingual_dataset
        )
def _mask_numbers(data):
    """Replace numeric tokens in each document with placeholder words.

    A number that starts after a whitespace character (optionally signed, and
    optionally followed by ``.``/``,`` separated digit groups) and ends at a word
    boundary is replaced by a token naming the length of its leading digit
    run: ``OneDigitMask``, ``TwoDigitMask``, ``ThreeDigitMask``,
    ``FourDigitMask``, or ``MoreDigitMask`` for five or more digits. Afterwards
    *every* period and comma is deleted from the text (not only those inside
    numbers), and surrounding whitespace is stripped.

    Args:
        data: Iterable of document strings.

    Returns:
        A new list of masked documents, in input order.
    """
    # Applied in this order, longest digit runs first. Every pattern needs a
    # preceding whitespace character, hence the space prepended to each document.
    substitutions = (
        (re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b"), " MoreDigitMask"),
        (re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b"), " FourDigitMask"),
        (re.compile(r"\s[\+-]?\d{3}([\.,]\d*)*\b"), " ThreeDigitMask"),
        (re.compile(r"\s[\+-]?\d{2}([\.,]\d*)*\b"), " TwoDigitMask"),
        (re.compile(r"\s[\+-]?\d{1}([\.,]\d*)*\b"), " OneDigitMask"),
    )
    # Deletion table for "." and ",".
    drop_punct = str.maketrans("", "", ".,")

    def mask_one(doc):
        doc = " " + doc
        for pattern, token in substitutions:
            doc = pattern.sub(token, doc)
        return doc.translate(drop_punct).strip()

    return [mask_one(doc) for doc in tqdm(data, desc="masking numbers")]
if __name__ == "__main__":
    # Smoke test: load a pre-processed RCV1/RCV2 split and print its shapes.
    DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    print(DATAPATH)
    # load() is a classmethod returning the unpickled instance. Calling it on
    # MultilingualDataset() raised TypeError because __init__ requires dataset_name.
    dataset = MultilingualDataset.load(DATAPATH)
    # show_dimensions() prints directly and returns None, so it is not wrapped in print().
    dataset.show_dimensions()