import os

import pandas as pd
from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer

from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDatset import MultilingualDataset


class gFunDataset:
    def __init__(
        self,
        dataset_dir,
        is_textual,
        is_visual,
        is_multilabel,
        labels=None,
        nrows=None,
        data_langs=None,
    ):
        self.dataset_dir = dataset_dir
        self.data_langs = data_langs
        self.is_textual = is_textual
        self.is_visual = is_visual
        self.is_multilabel = is_multilabel
        self.labels = labels
        self.nrows = nrows
        self.dataset = {}
        self.load_dataset()

    def get_label_binarizer(self, labels):
        if self.dataset_name in ["rcv1-2", "jrc"]:
            # Labels are already binarized for the rcv1-2 and jrc datasets
            mlb = None
        elif self.is_multilabel:
            mlb = MultiLabelBinarizer()
            mlb.fit([labels])
        else:
            mlb = LabelBinarizer()
            mlb.fit(labels)
        return mlb

    def load_dataset(self):
        if "glami" in self.dataset_dir.lower():
            print(f"- Loading GLAMI dataset from {self.dataset_dir}")
            self.dataset_name = "glami"
            self.dataset, self.labels, self.data_langs = self._load_glami(
                self.dataset_dir, self.nrows
            )
        elif "rcv" in self.dataset_dir.lower():
            print(f"- Loading RCV1-2 dataset from {self.dataset_dir}")
            self.dataset_name = "rcv1-2"
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                self.dataset_name, self.dataset_dir, self.nrows
            )
        elif "jrc" in self.dataset_dir.lower():
            print(f"- Loading JRC dataset from {self.dataset_dir}")
            self.dataset_name = "jrc"
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                self.dataset_name, self.dataset_dir, self.nrows
            )
        else:
            raise ValueError(f"Unsupported dataset directory: {self.dataset_dir}")
        self.mlb = self.get_label_binarizer(self.labels)
        self.show_dimension()

    def show_dimension(self):
        print(f"\n[Dataset: {self.dataset_name.upper()}]")
        for lang, data in self.dataset.items():
            print(
                f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
            )
        if self.dataset_name in ["rcv1-2", "jrc"]:
            # For these datasets, self.labels stores the label count itself
            print(f"-- Labels: {self.labels}")
        else:
            print(f"-- Labels: {len(self.labels)}")

    def _load_multilingual(self, dataset_name, dataset_dir, nrows):
        old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
        if nrows is not None:
            old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
        labels = old_dataset.num_labels()
        data_langs = old_dataset.langs()

        def _format_multilingual(data):
            # data is a (text, labels) pair; these datasets carry no images
            text, labels = data[0], data[1]
            return {"text": text, "image": None, "label": labels}

        dataset = {
            k: {"train": _format_multilingual(v[0]), "test": _format_multilingual(v[1])}
            for k, v in old_dataset.multiling_dataset.items()
        }
        return dataset, labels, data_langs

    def _load_glami(self, dataset_dir, nrows):
        def _balanced_sample(data, n, remainder=0):
            # Sample n rows per language; the first language (in sorted order)
            # absorbs the remainder so that the total equals nrows
            langs = sorted(data.geo.unique().tolist())
            dict_n = {lang: n for lang in langs}
            dict_n[langs[0]] += remainder
            sampled = []
            for lang in langs:
                sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))
            return pd.concat(sampled, axis=0)

        if nrows is None:
            raise ValueError("nrows must be set when loading the GLAMI dataset")

        # TODO: make this sampling deterministic / dependent on the seed
        n_langs = 13 if self.data_langs is None else len(self.data_langs)  # GLAMI-1M has 13 languages
        lang_nrows = nrows // n_langs
        remainder = nrows % n_langs

        train_split = get_dataframe("train", dataset_dir=dataset_dir)
        if self.data_langs is not None:
            # Restrict to the requested languages before sampling
            train_split = train_split[train_split.geo.isin(self.data_langs)]
        train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)

        data_langs = (
            sorted(train_split.geo.unique().tolist())
            if self.data_langs is None
            else self.data_langs
        )
        labels = (
            train_split.category_name.unique().tolist()
            if self.labels is None
            else self.labels
        )

        # TODO: at the moment, test data must contain the same languages as the train data
        test_split = get_dataframe("test", dataset_dir=dataset_dir)
        if self.data_langs is not None:
            test_split = test_split[test_split.geo.isin(self.data_langs)]
        # TODO: at the moment we are using a 1:1 train-test split
        test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)

        gb_train = train_split.groupby("geo")
        gb_test = test_split.groupby("geo")

        def _format_glami(data_df):
            text = (data_df.name + " " + data_df.description).tolist()
            image = data_df.image_file.tolist()
            labels = data_df.category_name.tolist()
            return {"text": text, "image": image, "label": labels}

        dataset = {
            lang: {
                "train": _format_glami(data_tr),
                "test": _format_glami(gb_test.get_group(lang)),
            }
            for lang, data_tr in gb_train
            if lang in data_langs
        }

        return dataset, labels, data_langs

    def binarize_labels(self, labels):
        if self.dataset_name in ["rcv1-2", "jrc"]:
            # Labels are already binarized for these datasets
            return labels
        if getattr(self, "mlb", None) is not None:
            return self.mlb.transform(labels)
        raise AttributeError("Label binarizer not found")

    def training(self):
        lXtr = {}
        lYtr = {}
        for lang in self.data_langs:
            text = self.dataset[lang]["train"]["text"] if self.is_textual else None
            img = self.dataset[lang]["train"]["image"] if self.is_visual else None
            labels = self.dataset[lang]["train"]["label"]
            lXtr[lang] = {"text": text, "image": img}
            lYtr[lang] = self.binarize_labels(labels)
        return lXtr, lYtr

    def test(self):
        lXte = {}
        lYte = {}
        for lang in self.data_langs:
            text = self.dataset[lang]["test"]["text"] if self.is_textual else None
            img = self.dataset[lang]["test"]["image"] if self.is_visual else None
            labels = self.dataset[lang]["test"]["label"]
            lXte[lang] = {"text": text, "image": img}
            lYte[lang] = self.binarize_labels(labels)
        return lXte, lYte

    def langs(self):
        return self.data_langs

    def num_labels(self):
        # For rcv1-2 and jrc, self.labels already stores the label count
        if self.dataset_name in ["rcv1-2", "jrc"]:
            return self.labels
        return len(self.labels)


if __name__ == "__main__":
    GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
    RCV_DATAPATH = os.path.expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = os.path.expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )

    print("Hello gFunDataset")
    dataset = gFunDataset(
        # dataset_dir=GLAMI_DATAPATH,
        # dataset_dir=RCV_DATAPATH,
        dataset_dir=JRC_DATAPATH,
        data_langs=None,
        is_textual=True,
        is_visual=True,
        is_multilabel=False,
        labels=None,
        nrows=13,
    )
    lXtr, lYtr = dataset.training()
    lXte, lYte = dataset.test()
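    # Minimal sanity-check sketch (an illustrative addition, not part of the
    # original demo): it only assumes the documented return shape of
    # training()/test(), i.e. per-language dicts with "text"/"image" views
    # and binarized label matrices. Note that "text" may be None when
    # is_textual=False, hence the guards below.
    for lang in dataset.langs():
        n_tr = len(lXtr[lang]["text"]) if lXtr[lang]["text"] is not None else 0
        n_te = len(lXte[lang]["text"]) if lXte[lang]["text"] is not None else 0
        print(f"[{lang}] train docs: {n_tr} - test docs: {n_te} - labels: {dataset.num_labels()}")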