import sys
import os

sys.path.append(os.path.expanduser("~/devel/gfun_multimodal"))

import re

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer
from tqdm import tqdm

from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDataset import MultilingualDataset


class SimpleGfunDataset:
    def __init__(
        self,
        dataset_name=None,
        datadir="~/datasets/rai/csv/",
        textual=True,
        visual=False,
        multilabel=False,
        set_tr_langs=None,
        set_te_langs=None,
    ):
        self.name = dataset_name
        self.datadir = os.path.expanduser(datadir)
        self.textual = textual
        self.visual = visual
        self.multilabel = multilabel
        self.load_csv(set_tr_langs, set_te_langs)
        self.print_stats()

    def print_stats(self):
        print(f"Dataset statistics {'-' * 15}")
        tr = va = te = 0
        for lang in self.all_langs:
            # Languages that appear only in the test set have no train/val split.
            n_tr = len(self.train_data[lang]) if lang in self.tr_langs else 0
            n_va = len(self.val_data[lang]) if lang in self.tr_langs else 0
            n_te = len(self.test_data[lang])
            tr += n_tr
            va += n_va
            te += n_te
            print(f"{lang} - tr: {n_tr} - va: {n_va} - te: {n_te}")
        print(f"Total {'-' * 15}")
        print(f"tr: {tr} - va: {va} - te: {te}")

    def load_csv(self, set_tr_langs, set_te_langs):
        _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
        try:
            # Prefer stratifying on the class label; fall back to language when
            # a class is too rare to split (train_test_split raises ValueError).
            stratified = "class"
            train, val = train_test_split(
                _data_tr, test_size=0.2, random_state=42, stratify=_data_tr.label
            )
        except ValueError:
            stratified = "lang"
            train, val = train_test_split(
                _data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang
            )
        print(f"- dataset stratified by {stratified}")
        test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
        self._set_langs(train, test, set_tr_langs, set_te_langs)
        self._set_labels(_data_tr)
        self.full_train = _data_tr
        self.full_test = test
        self.train_data = self._set_datalang(train)
        self.val_data = self._set_datalang(val)
        self.test_data = self._set_datalang(test)

    def _set_labels(self, data):
        self.labels = sorted(data.label.unique())

    def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
        self.tr_langs = set(train.lang.unique().tolist())
        self.te_langs = set(test.lang.unique().tolist())
        if set_tr_langs is not None:
            print(f"-- [SETTING TRAINING LANGS TO: {list(set_tr_langs)}]")
            self.tr_langs = self.tr_langs.intersection(set(set_tr_langs))
        if set_te_langs is not None:
            print(f"-- [SETTING TESTING LANGS TO: {list(set_te_langs)}]")
            self.te_langs = self.te_langs.intersection(set(set_te_langs))
        self.all_langs = self.tr_langs.union(self.te_langs)
        return self.tr_langs, self.te_langs, self.all_langs

    def _set_datalang(self, data: pd.DataFrame):
        return {lang: data[data.lang == lang] for lang in self.all_langs}

    def training(self, merge_validation=False, mask_number=False, target_as_csr=False):
        apply_mask = lambda x: _mask_numbers(x) if mask_number else x
        lXtr = {
            lang: {"text": apply_mask(self.train_data[lang].text.tolist())}
            for lang in self.tr_langs
        }
        lYtr = {lang: self.train_data[lang].label.tolist() for lang in self.tr_langs}
        if merge_validation:
            for lang in self.tr_langs:
                lXtr[lang]["text"] += apply_mask(self.val_data[lang].text.tolist())
                lYtr[lang] += self.val_data[lang].label.tolist()
        for lang in self.tr_langs:
            lYtr[lang] = self.indices_to_one_hot(
                indices=lYtr[lang], n_labels=self.num_labels()
            )
        return lXtr, lYtr

    def test(self, mask_number=False, target_as_csr=False):
        apply_mask = lambda x: _mask_numbers(x) if mask_number else x
        lXte = {
            lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
            for lang in self.te_langs
        }
        lYte = {
            lang: self.indices_to_one_hot(
                indices=self.test_data[lang].label.tolist(),
                n_labels=self.num_labels(),
            )
            for lang in self.te_langs
        }
        return lXte, lYte

    def langs(self):
        return list(self.all_langs)

    def num_labels(self):
        return len(self.labels)

    def indices_to_one_hot(self, indices, n_labels):
        one_hot_matrix = np.zeros((len(indices), n_labels))
        one_hot_matrix[np.arange(len(indices)), indices] = 1
        return one_hot_matrix
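
# Usage sketch (assumptions flagged: the CSVs under ~/datasets/rai/csv/ carry
# `text`, `label`, and `lang` columns with integer label indices, which is what
# load_csv and indices_to_one_hot require; "it" below is an illustrative language):
#
#   dataset = SimpleGfunDataset(dataset_name="rai")
#   lXtr, lYtr = dataset.training(mask_number=True)
#   lXte, lYte = dataset.test(mask_number=True)
#   lXtr["it"]["text"]  # list of (number-masked) documents
#   lYtr["it"]          # one-hot ndarray of shape (n_docs, dataset.num_labels())
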
class gFunDataset:
    def __init__(
        self,
        dataset_dir,
        is_textual,
        is_visual,
        is_multilabel,
        labels=None,
        nrows=None,
        data_langs=None,
    ):
        self.dataset_dir = dataset_dir
        self.data_langs = data_langs
        self.is_textual = is_textual
        self.is_visual = is_visual
        self.is_multilabel = is_multilabel
        self.labels = labels
        self.nrows = nrows
        self.dataset = {}
        self._load_dataset()

    def get_label_binarizer(self, labels):
        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
            # These datasets ship with already-binarized labels, so no fitted
            # binarizer is needed; a placeholder string is stored instead.
            mlb = f"Labels are already binarized for {self.dataset_name} dataset"
        elif self.is_multilabel:
            mlb = MultiLabelBinarizer()
            mlb.fit([labels])
        else:
            mlb = LabelBinarizer()
            mlb.fit(labels)
        return mlb

    def _load_dataset(self):
        print(f"- Loading dataset from {self.dataset_dir}")
        self.dataset_name = "rai"
        self.dataset, self.labels, self.data_langs = self._load_multilingual(
            dataset_name=self.dataset_name,
            dataset_dir=self.dataset_dir,
            nrows=self.nrows,
        )
        self.mlb = self.get_label_binarizer(self.labels)
        self.show_dimension()

    def show_dimension(self):
        print(f"\n[Dataset: {self.dataset_name.upper()}]")
        for lang, data in self.dataset.items():
            print(
                f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
            )
        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
            # For these datasets self.labels already holds the label count.
            print(f"-- Labels: {self.labels}")
        else:
            print(f"-- Labels: {len(self.labels)}")

    def _load_multilingual(self, dataset_dir, nrows, dataset_name="rai"):
        if "csv" in dataset_dir:
            old_dataset = MultilingualDataset(dataset_name="rai").from_csv(
                path_tr=os.path.expanduser(os.path.join(dataset_dir, "train.small.csv")),
                path_te=os.path.expanduser(os.path.join(dataset_dir, "test.small.csv")),
            )
        else:
            raise ValueError(f"Unsupported dataset dir (expected a csv layout): {dataset_dir}")
        if nrows is not None:
            if dataset_name == "cls":
                old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
            else:
                old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
        labels = old_dataset.num_labels()
        data_langs = old_dataset.langs()

        def _format_multilingual(data):
            # Each split of the underlying MultilingualDataset is a (text, label) pair.
            text = data[0]
            image = None
            labels = data[1]
            return {"text": text, "image": image, "label": labels}

        dataset = {
            k: {"train": _format_multilingual(v[0]), "test": _format_multilingual(v[1])}
            for k, v in old_dataset.multiling_dataset.items()
        }
        return dataset, labels, data_langs

    def _load_glami(self, dataset_dir, nrows):
        train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
        test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
            n=int(nrows / 10)
        )
        gb_train = train_split.groupby("geo")
        gb_test = test_split.groupby("geo")
        data_langs = self.data_langs
        if data_langs is None:
            data_langs = sorted(train_split.geo.unique().tolist())
        labels = self.labels
        if labels is None:
            labels = train_split.category_name.unique().tolist()

        def _format_glami(data_df):
            text = (data_df.name + " " + data_df.description).tolist()
            image = data_df.image_file.tolist()
            labels = data_df.category_name.tolist()
            return {"text": text, "image": image, "label": labels}

        dataset = {
            lang: {
                "train": _format_glami(data_tr),
                "test": _format_glami(gb_test.get_group(lang)),
            }
            for lang, data_tr in gb_train
            if lang in data_langs
        }
        return dataset, labels, data_langs

    def binarize_labels(self, labels):
        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
            # Labels are already binarized for these datasets.
            return labels
        if hasattr(self, "mlb"):
            return self.mlb.transform(labels)
        else:
            raise AttributeError("Label binarizer not found")

    def training(self):
        lXtr = {}
        lYtr = {}
        for lang in self.data_langs:
            text = self.dataset[lang]["train"]["text"] if self.is_textual else None
            img = self.dataset[lang]["train"]["image"] if self.is_visual else None
            labels = self.dataset[lang]["train"]["label"]
            lXtr[lang] = {"text": text, "image": img}
            lYtr[lang] = self.binarize_labels(labels)
        return lXtr, lYtr

    def test(self):
        lXte = {}
        lYte = {}
        for lang in self.data_langs:
            text = self.dataset[lang]["test"]["text"] if self.is_textual else None
            img = self.dataset[lang]["test"]["image"] if self.is_visual else None
            labels = self.dataset[lang]["test"]["label"]
            lXte[lang] = {"text": text, "image": img}
            lYte[lang] = self.binarize_labels(labels)
        return lXte, lYte

    def langs(self):
        return self.data_langs

    def num_labels(self):
        # For the datasets listed below, self.labels stores the label count itself.
        if self.dataset_name not in ["rcv1-2", "jrc", "cls", "rai"]:
            return len(self.labels)
        return self.labels

    def save_as_pickle(self, path):
        import pickle

        filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
        with open(filepath, "wb") as f:
            print(f"- saving dataset in {filepath}")
            pickle.dump(self, f)
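
# Sketch of the two fitted-binarizer paths in get_label_binarizer (illustrative
# values, not part of the original module; the "rai"/"rcv1-2"/"jrc"/"cls"
# datasets skip the binarizer because their labels arrive already binarized):
#
#   MultiLabelBinarizer().fit([["a", "b", "c"]]).transform([["a", "c"]])
#   # -> array([[1, 0, 1]])   one row, several active classes
#   LabelBinarizer().fit(["a", "b", "c"]).transform(["a", "c"])
#   # -> array([[1, 0, 0],
#   #           [0, 0, 1]])   one active class per row
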
def _mask_numbers(data):
    # Replace numeric tokens with placeholder words bucketed by digit count;
    # patterns are applied from the longest digit run to the shortest, and
    # leftover "."/"," separators are stripped afterwards.
    mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
    mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")
    mask_3digit = re.compile(r"\s[\+-]?\d{3}([\.,]\d*)*\b")
    mask_2digit = re.compile(r"\s[\+-]?\d{2}([\.,]\d*)*\b")
    mask_1digit = re.compile(r"\s[\+-]?\d{1}([\.,]\d*)*\b")
    masked = []
    for text in tqdm(data, desc="masking numbers", disable=True):
        text = " " + text  # leading space so a number at position 0 can match \s
        text = mask_moredigit.sub(" MoreDigitMask", text)
        text = mask_4digit.sub(" FourDigitMask", text)
        text = mask_3digit.sub(" ThreeDigitMask", text)
        text = mask_2digit.sub(" TwoDigitMask", text)
        text = mask_1digit.sub(" OneDigitMask", text)
        masked.append(text.replace(".", "").replace(",", "").strip())
    return masked


if __name__ == "__main__":
    data_rai = SimpleGfunDataset()
    lXtr, lYtr = data_rai.training(mask_number=False)
    lXte, lYte = data_rai.test(mask_number=False)
    exit()
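
# For reference, an illustrative input/output pair for _mask_numbers above
# (hand-checked against the regexes; not part of the original module):
#
#   _mask_numbers(["call me at 555 1234"])
#   # -> ["call me at ThreeDigitMask FourDigitMask"]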