"""Unified loader for the gFun datasets (GLAMI-1M, RCV1-2, JRC-Acquis,
WEBIS-CLS, RAI), exposing per-language training/test splits with optionally
binarized labels."""

import os
import pickle

from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer

from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDataset import MultilingualDataset


class gFunDataset:
    def __init__(
        self,
        dataset_dir,
        is_textual,
        is_visual,
        is_multilabel,
        labels=None,
        nrows=None,
        data_langs=None,
    ):
        self.dataset_dir = dataset_dir
        self.data_langs = data_langs
        self.is_textual = is_textual
        self.is_visual = is_visual
        self.is_multilabel = is_multilabel
        self.labels = labels
        self.nrows = nrows
        self.dataset = {}
        self._load_dataset()

    def get_label_binarizer(self, labels):
        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
            # These datasets ship with already-binarized label matrices, so a
            # sentinel string is stored instead of a fitted binarizer.
            mlb = f"Labels are already binarized for {self.dataset_name} dataset"
        elif self.is_multilabel:
            mlb = MultiLabelBinarizer()
            mlb.fit([labels])
        else:
            mlb = LabelBinarizer()
            mlb.fit(labels)
        return mlb

    def _load_dataset(self):
        dataset_dir = self.dataset_dir.lower()
        if "glami" in dataset_dir:
            print(f"- Loading GLAMI dataset from {self.dataset_dir}")
            self.dataset_name = "glami"
            self.dataset, self.labels, self.data_langs = self._load_glami(
                self.dataset_dir, self.nrows
            )
        elif "rcv" in dataset_dir:
            print(f"- Loading RCV1-2 dataset from {self.dataset_dir}")
            self.dataset_name = "rcv1-2"
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                self.dataset_name, self.dataset_dir, self.nrows
            )
        elif "jrc" in dataset_dir:
            print(f"- Loading JRC dataset from {self.dataset_dir}")
            self.dataset_name = "jrc"
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                self.dataset_name, self.dataset_dir, self.nrows
            )
        elif "cls" in dataset_dir:
            # WEBIS-CLS, processed or unprocessed depending on the path.
            flavor = "unprocessed" if "unprocessed" in dataset_dir else "processed"
            print(f"- Loading WEBIS-CLS ({flavor}) dataset from {self.dataset_dir}")
            self.dataset_name = "cls"
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                self.dataset_name, self.dataset_dir, self.nrows
            )
        elif "rai" in dataset_dir:
            print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}")
            self.dataset_name = "rai"
            # NB: the RAI corpus is read from a hard-coded CSV path rather
            # than from self.dataset_dir.
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                dataset_name=self.dataset_name,
                dataset_dir="~/datasets/rai/csv/train-split-rai.csv",
                nrows=self.nrows,
            )
        else:
            raise ValueError(f"Unsupported dataset directory: {self.dataset_dir}")
        self.mlb = self.get_label_binarizer(self.labels)
        self.show_dimension()

    def show_dimension(self):
        print(f"\n[Dataset: {self.dataset_name.upper()}]")
        for lang, data in self.dataset.items():
            print(
                f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
            )
        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
            # For these datasets self.labels already holds the label count.
            print(f"-- Labels: {self.labels}")
        else:
            print(f"-- Labels: {len(self.labels)}")

    def _load_multilingual(self, dataset_name, dataset_dir, nrows):
        if "csv" in dataset_dir:
            # NB: the CSV loader reads from hard-coded train/test paths.
            old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv(
                path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
                path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv",
            )
        else:
            old_dataset = MultilingualDataset(dataset_name=dataset_name).load(
                dataset_dir
            )
        if nrows is not None:
            if dataset_name == "cls":
                old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
            else:
                old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
        labels = old_dataset.num_labels()
        data_langs = old_dataset.langs()

        def _format_multilingual(data):
            # Each split is a (documents, label-matrix) pair; these corpora
            # carry no images.
            text = data[0]
            image = None
            labels = data[1]
            return {"text": text, "image": image, "label": labels}

        dataset = {
            k: {"train": _format_multilingual(v[0]), "test": _format_multilingual(v[1])}
            for k, v in old_dataset.multiling_dataset.items()
        }

        return dataset, labels, data_langs

    def _load_glami(self, dataset_dir, nrows):
        train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
        test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
            n=int(nrows / 10)
        )

        gb_train = train_split.groupby("geo")
        gb_test = test_split.groupby("geo")

        # Use the constructor-supplied languages/labels when given, otherwise
        # derive them from the sampled training split.
        data_langs = (
            sorted(train_split.geo.unique().tolist())
            if self.data_langs is None
            else self.data_langs
        )
        labels = (
            train_split.category_name.unique().tolist()
            if self.labels is None
            else self.labels
        )

        def _format_glami(data_df):
            text = (data_df.name + " " + data_df.description).tolist()
            image = data_df.image_file.tolist()
            labels = data_df.category_name.tolist()
            return {"text": text, "image": image, "label": labels}

        dataset = {
            lang: {
                "train": _format_glami(data_tr),
                "test": _format_glami(gb_test.get_group(lang)),
            }
            for lang, data_tr in gb_train
            if lang in data_langs
        }

        return dataset, labels, data_langs

    def binarize_labels(self, labels):
        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
            # Labels are already binarized for these datasets.
            return labels
        if hasattr(self, "mlb"):
            return self.mlb.transform(labels)
        else:
            raise AttributeError("Label binarizer not found")

    def training(self):
        lXtr = {}
        lYtr = {}
        for lang in self.data_langs:
            text = self.dataset[lang]["train"]["text"] if self.is_textual else None
            img = self.dataset[lang]["train"]["image"] if self.is_visual else None
            labels = self.dataset[lang]["train"]["label"]
            lXtr[lang] = {"text": text, "image": img}
            lYtr[lang] = self.binarize_labels(labels)
        return lXtr, lYtr

    def test(self):
        lXte = {}
        lYte = {}
        for lang in self.data_langs:
            text = self.dataset[lang]["test"]["text"] if self.is_textual else None
            img = self.dataset[lang]["test"]["image"] if self.is_visual else None
            labels = self.dataset[lang]["test"]["label"]
            lXte[lang] = {"text": text, "image": img}
            lYte[lang] = self.binarize_labels(labels)
        return lXte, lYte

    def langs(self):
        return self.data_langs

    def num_labels(self):
        if self.dataset_name not in ["rcv1-2", "jrc", "cls", "rai"]:
            return len(self.labels)
        else:
            # For these datasets self.labels is already the label count.
            return self.labels

    def save_as_pickle(self, path):
        filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
        with open(filepath, "wb") as f:
            print(f"- saving dataset in {filepath}")
            pickle.dump(self, f)
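
# A minimal convenience sketch, not part of the original module: the inverse of
# gFunDataset.save_as_pickle(). Since save_as_pickle() serializes the whole
# instance with pickle.dump() under "{dataset_name}_{nrows}.pkl", reloading is
# a plain pickle.load(). The function name and signature are assumptions.
def load_from_pickle(path, dataset_name, nrows):
    filepath = os.path.join(path, f"{dataset_name}_{nrows}.pkl")
    with open(filepath, "rb") as f:
        print(f"- loading dataset from {filepath}")
        return pickle.load(f)
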

if __name__ == "__main__":
    GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
    RCV_DATAPATH = os.path.expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = os.path.expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )

    print("Hello gFunDataset")
    dataset = gFunDataset(
        dataset_dir=JRC_DATAPATH,
        data_langs=None,
        is_textual=True,
        is_visual=True,
        is_multilabel=False,
        labels=None,
        nrows=13,
    )
    lXtr, lYtr = dataset.training()
    lXte, lYte = dataset.test()
    exit(0)
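    # A possible round-trip check (commented out; "~/datasets/pickles" is an
    # assumed scratch directory, not a path from the original project):
    #
    #     pickle_dir = os.path.expanduser("~/datasets/pickles")
    #     dataset.save_as_pickle(pickle_dir)
    #     reloaded = load_from_pickle(pickle_dir, dataset.dataset_name, dataset.nrows)
    #     assert reloaded.langs() == dataset.langs()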