# gfun_multimodal/dataManager/gFunDataset.py
#
# 326 lines
# 12 KiB
# Python

import sys
import os
sys.path.append(os.path.expanduser("~/devel/gfun_multimodal"))
from collections import defaultdict, Counter
import numpy as np
import re
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDataset import MultilingualDataset
class SimpleGfunDataset:
    """Multilingual text-classification dataset backed by two CSV files.

    ``train.small.csv`` and ``test.small.csv`` are read from ``datadir``; 20%
    of the training file is held out as a validation split and every split is
    indexed by language.  The code reads the ``lang``, ``text`` and ``label``
    columns of both files.

    NOTE(review): ``label`` values are used directly as column indices when
    the one-hot targets are built, so they are assumed to be integers in
    ``[0, num_labels)`` -- confirm against the data.
    """

    def __init__(
        self,
        dataset_name=None,
        datadir="~/datasets/rai/csv/",
        textual=True,
        visual=False,
        multilabel=False,
        set_tr_langs=None,
        set_te_langs=None
    ):
        """Load the CSVs found in ``datadir`` and print per-language counts.

        Args:
            dataset_name: free-form name stored in ``self.name``.
            datadir: directory holding the CSV files (``~`` is expanded).
            textual, visual, multilabel: stored as attributes; not otherwise
                consulted by this class.
            set_tr_langs: optional iterable restricting the training languages.
            set_te_langs: optional iterable restricting the test languages.
        """
        self.name = dataset_name
        self.datadir = os.path.expanduser(datadir)
        self.textual = textual
        self.visual = visual
        self.multilabel = multilabel
        self.load_csv(set_tr_langs, set_te_langs)
        self.print_stats()

    def print_stats(self):
        """Print train/validation/test document counts per language and overall."""
        print(f"Dataset statistics {'-' * 15}")
        tr = 0
        va = 0
        te = 0
        for lang in self.all_langs:
            in_train = lang in self.tr_langs
            n_tr = len(self.train_data[lang]) if in_train else 0
            n_va = len(self.val_data[lang]) if in_train else 0
            # Only count test documents for languages that test() will really
            # return, mirroring how the train/val counts are gated above.
            n_te = len(self.test_data[lang]) if lang in self.te_langs else 0
            tr += n_tr
            va += n_va
            te += n_te
            print(f"{lang} - tr: {n_tr} - va: {n_va} - te: {n_te}")
        print(f"Total {'-' * 15}")
        print(f"tr: {tr} - va: {va} - te: {te}")

    def load_csv(self, set_tr_langs, set_te_langs):
        """Read the CSV files, split off validation data and index by language.

        The train/validation split is stratified by label; when scikit-learn
        rejects that stratification with a ``ValueError`` the split falls back
        to stratifying by language.
        """
        _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
        try:
            train, val = train_test_split(
                _data_tr, test_size=0.2, random_state=42, stratify=_data_tr.label
            )
            stratified = "class"
        except ValueError:
            # Label stratification was not possible; stratify by language.
            train, val = train_test_split(
                _data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang
            )
            stratified = "lang"
        print(f"- dataset stratified by {stratified}")

        test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
        self._set_langs(train, test, set_tr_langs, set_te_langs)
        self._set_labels(_data_tr)

        self.full_train = _data_tr
        # Keep the test DataFrame itself (``self.test`` is the bound method).
        self.full_test = test
        self.train_data = self._set_datalang(train)
        self.val_data = self._set_datalang(val)
        self.test_data = self._set_datalang(test)
        return

    def _set_labels(self, data):
        """Store the sorted distinct values of ``data.label`` in ``self.labels``."""
        self.labels = sorted(list(data.label.unique()))

    def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
        """Compute the training, test and combined language sets.

        Each set starts from the languages present in the corresponding frame
        and is intersected with the optional restriction.

        Returns:
            The tuple ``(tr_langs, te_langs, all_langs)``.
        """
        self.tr_langs = set(train.lang.unique().tolist())
        self.te_langs = set(test.lang.unique().tolist())
        if set_tr_langs is not None:
            print(f"-- [SETTING TRAINING LANGS TO: {list(set_tr_langs)}]")
            self.tr_langs = self.tr_langs.intersection(set(set_tr_langs))
        if set_te_langs is not None:
            print(f"-- [SETTING TESTING LANGS TO: {list(set_te_langs)}]")
            self.te_langs = self.te_langs.intersection(set(set_te_langs))
        self.all_langs = self.tr_langs.union(self.te_langs)
        return self.tr_langs, self.te_langs, self.all_langs

    def _set_datalang(self, data: pd.DataFrame):
        """Return ``{lang: rows of data for that lang}`` for every known language."""
        return {lang: data[data.lang == lang] for lang in self.all_langs}

    def training(self, merge_validation=False, mask_number=False, target_as_csr=False):
        """Return the per-language training documents and one-hot targets.

        Args:
            merge_validation: append the validation split to the training data.
            mask_number: replace numbers in the text with mask tokens via
                ``_mask_numbers``; texts are returned untouched when False.
            target_as_csr: accepted for interface compatibility; unused here.

        Returns:
            ``(lXtr, lYtr)`` where ``lXtr[lang] == {"text": [...]}`` and
            ``lYtr[lang]`` is a ``(n_docs, n_labels)`` one-hot array.
        """
        lXtr = {}
        lYtr = {}
        for lang in self.tr_langs:
            texts = self.train_data[lang].text.tolist()
            targets = self.train_data[lang].label.tolist()
            if merge_validation:
                texts += self.val_data[lang].text.tolist()
                targets += self.val_data[lang].label.tolist()
            if mask_number:
                texts = _mask_numbers(texts)
            lXtr[lang] = {"text": texts}
            lYtr[lang] = self.indices_to_one_hot(
                indices=targets, n_labels=self.num_labels()
            )
        return lXtr, lYtr

    def test(self, mask_number=False, target_as_csr=False):
        """Return the per-language test documents and one-hot targets.

        Args:
            mask_number: replace numbers in the text with mask tokens via
                ``_mask_numbers``; texts are returned untouched when False.
            target_as_csr: accepted for interface compatibility; unused here.

        Returns:
            ``(lXte, lYte)`` shaped like the output of :meth:`training`.
        """
        lXte = {}
        lYte = {}
        for lang in self.te_langs:
            texts = self.test_data[lang].text.tolist()
            if mask_number:
                texts = _mask_numbers(texts)
            lXte[lang] = {"text": texts}
            lYte[lang] = self.indices_to_one_hot(
                indices=self.test_data[lang].label.tolist(),
                n_labels=self.num_labels(),
            )
        return lXte, lYte

    def langs(self):
        """Return every language seen in either split, as a list."""
        return list(self.all_langs)

    def num_labels(self):
        """Return the number of distinct labels found in the training CSV."""
        return len(self.labels)

    def indices_to_one_hot(self, indices, n_labels):
        """Build a ``(len(indices), n_labels)`` float matrix with one 1 per row.

        Row ``i`` has its 1 in column ``indices[i]``.
        """
        one_hot_matrix = np.zeros((len(indices), n_labels))
        one_hot_matrix[np.arange(len(indices)), indices] = 1
        return one_hot_matrix
class gFunDataset:
    """Per-language train/test container used by the gFun pipeline.

    ``self.dataset`` maps each language to
    ``{"train": {...}, "test": {...}}`` where every split holds the keys
    ``"text"``, ``"image"`` and ``"label"``.
    """

    # Dataset names whose labels are handled as already binarized.
    _PREBINARIZED = ["rcv1-2", "jrc", "cls", "rai"]

    def __init__(
        self,
        dataset_dir,
        is_textual,
        is_visual,
        is_multilabel,
        labels=None,
        nrows=None,
        data_langs=None,
    ):
        """Store the configuration and immediately load the dataset.

        Args:
            dataset_dir: directory the dataset is read from.
            is_textual: expose the text view in ``training()`` / ``test()``.
            is_visual: expose the image view in ``training()`` / ``test()``.
            is_multilabel: selects the binarizer type for datasets that are
                not pre-binarized.
            labels: optional label set; overwritten by ``_load_dataset``.
            nrows: optional cap on the number of documents that are kept.
            data_langs: optional languages; overwritten by ``_load_dataset``.
        """
        self.dataset_dir = dataset_dir
        self.data_langs = data_langs
        self.is_textual = is_textual
        self.is_visual = is_visual
        self.is_multilabel = is_multilabel
        self.labels = labels
        self.nrows = nrows
        self.dataset = {}
        self._load_dataset()

    def get_label_binarizer(self, labels):
        """Return a binarizer fitted on ``labels``.

        For pre-binarized datasets no binarizer is needed and a descriptive
        string is returned instead (kept for backward compatibility).
        """
        if self.dataset_name in self._PREBINARIZED:
            mlb = f"Labels are already binarized for {self.dataset_name} dataset"
        elif self.is_multilabel:
            mlb = MultiLabelBinarizer()
            mlb.fit([labels])
        else:
            mlb = LabelBinarizer()
            mlb.fit(labels)
        return mlb

    def _load_dataset(self):
        """Load the dataset (currently always "rai") and print its dimensions."""
        print(f"- Loading dataset from {self.dataset_dir}")
        self.dataset_name = "rai"
        self.dataset, self.labels, self.data_langs = self._load_multilingual(
            dataset_name=self.dataset_name,
            dataset_dir=self.dataset_dir,
            nrows=self.nrows,
        )
        self.mlb = self.get_label_binarizer(self.labels)
        self.show_dimension()
        return

    def show_dimension(self):
        """Print the number of train/test documents per language and the labels."""
        print(f"\n[Dataset: {self.dataset_name.upper()}]")
        for lang, data in self.dataset.items():
            print(
                f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
            )
        if self.dataset_name in self._PREBINARIZED:
            print(f"-- Labels: {self.labels}")
        else:
            print(f"-- Labels: {len(self.labels)}")

    def _load_multilingual(self, dataset_dir, nrows, dataset_name="rai"):
        """Load ``train.small.csv`` / ``test.small.csv`` via MultilingualDataset.

        Args:
            dataset_dir: directory with the CSV files; its path must contain
                the substring ``"csv"``.
            nrows: when not None, the data is reduced to at most ``nrows``
                through ``reduce_data``.
            dataset_name: only used to pick the languages passed to
                ``reduce_data``.

        Returns:
            ``(dataset, labels, data_langs)`` where ``labels`` is whatever
            ``MultilingualDataset.num_labels()`` returns.

        Raises:
            ValueError: if ``dataset_dir`` does not contain ``"csv"``; no
                other source format is supported here.
        """
        if "csv" not in dataset_dir:
            # Previously this fell through and died on an unbound local.
            raise ValueError(
                f"Unsupported dataset directory (expected a csv folder): {dataset_dir}"
            )
        # NOTE(review): the loader is always built with the name "rai",
        # regardless of ``dataset_name`` -- confirm this is intended.
        old_dataset = MultilingualDataset(dataset_name="rai").from_csv(
            path_tr=os.path.expanduser(os.path.join(dataset_dir, "train.small.csv")),
            path_te=os.path.expanduser(os.path.join(dataset_dir, "test.small.csv")),
        )
        if nrows is not None:
            reduce_langs = ["de", "en", "fr"] if dataset_name == "cls" else ["en", "it", "fr"]
            old_dataset.reduce_data(langs=reduce_langs, maxn=nrows)
        labels = old_dataset.num_labels()
        data_langs = old_dataset.langs()

        def _format_multilingual(data):
            # ``data`` is indexed as (text, label); there is no image view.
            return {"text": data[0], "image": None, "label": data[1]}

        dataset = {
            k: {"train": _format_multilingual(v[0]), "test": _format_multilingual(v[1])}
            for k, v in old_dataset.multiling_dataset.items()
        }
        return dataset, labels, data_langs

    def _load_glami(self, dataset_dir, nrows):
        """Load the GLAMI train/test frames and group them by ``geo``.

        Args:
            dataset_dir: directory passed to ``get_dataframe``.
            nrows: number of training rows to sample (a tenth as many test
                rows are sampled); ``None`` keeps both frames whole.

        Returns:
            ``(dataset, labels, data_langs)``.  Languages and labels given at
            construction time take precedence over those found in the data.

        NOTE(review): ``get_group`` fails for a training language that is
        absent from the test sample -- presumably a ``KeyError``; confirm.
        """
        train_split = get_dataframe("train", dataset_dir=dataset_dir)
        test_split = get_dataframe("test", dataset_dir=dataset_dir)
        if nrows is not None:
            train_split = train_split.sample(n=nrows)
            test_split = test_split.sample(n=int(nrows / 10))

        gb_train = train_split.groupby("geo")
        gb_test = test_split.groupby("geo")

        # Fall back to the data only when the caller did not fix these; the
        # names used to be left unbound whenever they were supplied.
        data_langs = self.data_langs
        if data_langs is None:
            data_langs = sorted(train_split.geo.unique().tolist())
        labels = self.labels
        if labels is None:
            labels = train_split.category_name.unique().tolist()

        def _format_glami(data_df):
            return {
                "text": (data_df.name + " " + data_df.description).tolist(),
                "image": data_df.image_file.tolist(),
                "label": data_df.category_name.tolist(),
            }

        dataset = {
            lang: {
                "train": _format_glami(data_tr),
                "test": _format_glami(gb_test.get_group(lang)),
            }
            for lang, data_tr in gb_train
            if lang in data_langs
        }
        return dataset, labels, data_langs

    def binarize_labels(self, labels):
        """Binarize ``labels`` with the fitted binarizer.

        Pre-binarized datasets are returned unchanged.

        Raises:
            AttributeError: if no binarizer has been fitted yet.
        """
        if self.dataset_name in self._PREBINARIZED:
            return labels
        if hasattr(self, "mlb"):
            return self.mlb.transform(labels)
        raise AttributeError("Label binarizer not found")

    def _split_views(self, split):
        """Build the ``(lX, lY)`` dictionaries for ``split`` ("train"/"test")."""
        lX = {}
        lY = {}
        for lang in self.data_langs:
            part = self.dataset[lang][split]
            lX[lang] = {
                "text": part["text"] if self.is_textual else None,
                "image": part["image"] if self.is_visual else None,
            }
            lY[lang] = self.binarize_labels(part["label"])
        return lX, lY

    def training(self):
        """Return ``(lXtr, lYtr)``: per-language training views and targets.

        Views that are disabled (``is_textual`` / ``is_visual``) are ``None``.
        """
        return self._split_views("train")

    def test(self):
        """Return ``(lXte, lYte)``: per-language test views and targets.

        Views that are disabled (``is_textual`` / ``is_visual``) are ``None``.
        """
        return self._split_views("test")

    def langs(self):
        """Return the languages of the loaded dataset."""
        return self.data_langs

    def num_labels(self):
        """Return the number of labels.

        For pre-binarized datasets ``self.labels`` is returned unchanged (it
        is set from ``MultilingualDataset.num_labels()``); otherwise the
        length of the label collection is returned.
        """
        if self.dataset_name not in self._PREBINARIZED:
            return len(self.labels)
        return self.labels

    def save_as_pickle(self, path):
        """Pickle this object to ``<path>/<dataset_name>_<nrows>.pkl``."""
        import pickle

        filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
        with open(filepath, "wb") as f:
            print(f"- saving dataset in {filepath}")
            pickle.dump(self, f)
def _mask_numbers(data):
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")
mask_3digit = re.compile(r"\s[\+-]?\d{3}([\.,]\d*)*\b")
mask_2digit = re.compile(r"\s[\+-]?\d{2}([\.,]\d*)*\b")
mask_1digit = re.compile(r"\s[\+-]?\d{1}([\.,]\d*)*\b")
masked = []
for text in tqdm(data, desc="masking numbers", disable=True):
text = " " + text
text = mask_moredigit.sub(" MoreDigitMask", text)
text = mask_4digit.sub(" FourDigitMask", text)
text = mask_3digit.sub(" ThreeDigitMask", text)
text = mask_2digit.sub(" TwoDigitMask", text)
text = mask_1digit.sub(" OneDigitMask", text)
masked.append(text.replace(".", "").replace(",", "").strip())
return masked
if __name__ == "__main__":
    # Smoke run: build the default dataset and materialise both splits
    # without number masking, then leave the interpreter.
    rai_dataset = SimpleGfunDataset()
    train_docs, train_targets = rai_dataset.training(mask_number=False)
    test_docs, test_targets = rai_dataset.test(mask_number=False)
    exit()