import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer

from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDataset import MultilingualDataset


class gFunDataset:
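    """Unified loader for the GLAMI-1M, RCV1-2, JRC-Acquis and CLS datasets.

    Exposes a common per-language interface: `training()` and `test()` return
    dictionaries mapping each language to its textual/visual inputs and
    binarized labels.
    """
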
    def __init__(
        self,
        dataset_dir,
        is_textual,
        is_visual,
        is_multilabel,
        labels=None,
        nrows=None,
        data_langs=None,
    ):
        self.dataset_dir = dataset_dir
        self.data_langs = data_langs
        self.is_textual = is_textual
        self.is_visual = is_visual
        self.is_multilabel = is_multilabel
        self.labels = labels
        self.nrows = nrows
        self.dataset = {}
        self.load_dataset()

    def get_label_binarizer(self, labels):
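        """Fit a label binarizer on `labels`, or skip fitting for datasets
        whose labels are already binarized."""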
        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
            # Labels come pre-binarized for these datasets, so no fitting is
            # needed (binarize_labels returns them unchanged).
            mlb = None
        elif self.is_multilabel:
            mlb = MultiLabelBinarizer()
            mlb.fit([labels])
        else:
            mlb = LabelBinarizer()
            mlb.fit(labels)
        return mlb

    def load_dataset(self):
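        """Infer the dataset type from `dataset_dir`, load it, and fit the
        label binarizer."""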
        if "glami" in self.dataset_dir.lower():
            print(f"- Loading GLAMI dataset from {self.dataset_dir}")
            self.dataset_name = "glami"
            self.dataset, self.labels, self.data_langs = self._load_glami(
                self.dataset_dir, self.nrows
            )
        elif "rcv" in self.dataset_dir.lower():
            print(f"- Loading RCV1-2 dataset from {self.dataset_dir}")
            self.dataset_name = "rcv1-2"
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                self.dataset_name, self.dataset_dir, self.nrows
            )
        elif "jrc" in self.dataset_dir.lower():
            print(f"- Loading JRC dataset from {self.dataset_dir}")
            self.dataset_name = "jrc"
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                self.dataset_name, self.dataset_dir, self.nrows
            )
        elif "cls" in self.dataset_dir.lower():
            print(f"- Loading CLS dataset from {self.dataset_dir}")
            self.dataset_name = "cls"
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                self.dataset_name, self.dataset_dir, self.nrows
            )
        else:
            raise ValueError(f"Unsupported dataset directory: {self.dataset_dir}")
        self.mlb = self.get_label_binarizer(self.labels)
        self.show_dimension()

    def show_dimension(self):
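        """Print the number of train/test documents per language and the
        label count."""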
        print(f"\n[Dataset: {self.dataset_name.upper()}]")
        for lang, data in self.dataset.items():
            print(
                f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
            )
        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
            # For these datasets self.labels already holds the label count.
            print(f"-- Labels: {self.labels}")
        else:
            print(f"-- Labels: {len(self.labels)}")

    def _load_multilingual(self, dataset_name, dataset_dir, nrows):
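        """Load a pickled MultilingualDataset and convert it to the common
        per-language {"train": ..., "test": ...} layout."""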
        old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
        if nrows is not None:
            if dataset_name == "cls":
                old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
            else:
                old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
        # NB: for these datasets `labels` is the label count, not a label list.
        labels = old_dataset.num_labels()
        data_langs = old_dataset.langs()

        def _format_multilingual(data):
            text = data[0]
            image = None
            labels = data[1]
            return {"text": text, "image": image, "label": labels}

        dataset = {
            k: {"train": _format_multilingual(v[0]), "test": _format_multilingual(v[1])}
            for k, v in old_dataset.multiling_dataset.items()
        }
        return dataset, labels, data_langs

    def _load_glami(self, dataset_dir, nrows):
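        """Load the GLAMI-1M train/test splits, sampling `nrows` documents
        balanced across languages (assumes `nrows` is not None)."""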
        def _balanced_sample(data, n, remainder=0):
            langs = sorted(data.geo.unique().tolist())
            dict_n = {lang: n for lang in langs}
            # Assign the remainder to the first language so the totals sum to nrows.
            dict_n[langs[0]] += remainder

            sampled = []
            for lang in langs:
                sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))

            return pd.concat(sampled, axis=0)

        # TODO: make this sampling deterministic / dependent on the seed
        lang_nrows = (
            nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
        )  # GLAMI-1M covers 13 languages
        remainder = (
            nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
        )

        train_split = get_dataframe("train", dataset_dir=dataset_dir)
        train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)

        if self.data_langs is None:
            data_langs = sorted(train_split.geo.unique().tolist())
        else:
            # TODO: if data_langs is NOT None we still need to filter the
            # dataframes by those languages before sampling.
            data_langs = self.data_langs
        if self.labels is None:
            labels = train_split.category_name.unique().tolist()
        else:
            labels = self.labels

        # TODO: at the moment, test data must contain the same languages as
        # the train data, and we use a 1:1 train-test split.
        test_split = get_dataframe("test", dataset_dir=dataset_dir)
        test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)

        gb_train = train_split.groupby("geo")
        gb_test = test_split.groupby("geo")

        def _format_glami(data_df):
            text = (data_df.name + " " + data_df.description).tolist()
            image = data_df.image_file.tolist()
            labels = data_df.category_name.tolist()
            return {"text": text, "image": image, "label": labels}

        dataset = {
            lang: {
                "train": _format_glami(data_tr),
                "test": _format_glami(gb_test.get_group(lang)),
            }
            for lang, data_tr in gb_train
            if lang in data_langs
        }

        return dataset, labels, data_langs

    def binarize_labels(self, labels):
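        """Return the binarized label matrix for `labels` (a no-op for
        datasets whose labels are already binarized)."""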
        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
            # Labels are already binarized for these datasets.
            return labels
        if hasattr(self, "mlb"):
            return self.mlb.transform(labels)
        else:
            raise AttributeError("Label binarizer not found")

    def training(self):
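        """Return per-language training inputs (lXtr) and binarized labels
        (lYtr)."""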
        lXtr = {}
        lYtr = {}
        for lang in self.data_langs:
            text = self.dataset[lang]["train"]["text"] if self.is_textual else None
            img = self.dataset[lang]["train"]["image"] if self.is_visual else None
            labels = self.dataset[lang]["train"]["label"]

            lXtr[lang] = {"text": text, "image": img}
            lYtr[lang] = self.binarize_labels(labels)

        return lXtr, lYtr

    def test(self):
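        """Return per-language test inputs (lXte) and binarized labels
        (lYte)."""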
        lXte = {}
        lYte = {}
        for lang in self.data_langs:
            text = self.dataset[lang]["test"]["text"] if self.is_textual else None
            img = self.dataset[lang]["test"]["image"] if self.is_visual else None
            labels = self.dataset[lang]["test"]["label"]

            lXte[lang] = {"text": text, "image": img}
            lYte[lang] = self.binarize_labels(labels)

        return lXte, lYte

    def langs(self):
        return self.data_langs

    def num_labels(self):
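        """Return the number of target labels."""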
        if self.dataset_name not in ["rcv1-2", "jrc", "cls"]:
            return len(self.labels)
        else:
            # For these datasets self.labels already stores the label count.
            return self.labels


if __name__ == "__main__":
    import os

    GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
    RCV_DATAPATH = os.path.expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = os.path.expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )

    print("Hello gFunDataset")
    dataset = gFunDataset(
        # dataset_dir=GLAMI_DATAPATH,
        # dataset_dir=RCV_DATAPATH,
        dataset_dir=JRC_DATAPATH,
        data_langs=None,
        is_textual=True,
        is_visual=True,
        is_multilabel=False,
        labels=None,
        nrows=13,
    )
    lXtr, lYtr = dataset.training()
    lXte, lYte = dataset.test()
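
    # Minimal sanity check, a sketch assuming at least one language was loaded
    # and that the binarized labels expose a numpy-style .shape attribute:
    first_lang = dataset.langs()[0]
    print(
        f"{first_lang}: {len(lXtr[first_lang]['text'])} train docs, "
        f"label matrix shape {lYtr[first_lang].shape}"
    )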
    exit(0)