"""Lightweight dataset wrapper for multilingual gFun experiments: loads
train/validation/test splits from per-language CSV files and exposes them as
{lang: data} dictionaries with one-hot label matrices."""

import os
import re

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm


class SimpleGfunDataset:
    def __init__(
        self,
        dataset_name=None,
        datadir="~/datasets/rai/csv/",
        textual=True,
        visual=False,
        multilabel=False,
        set_tr_langs=None,
        set_te_langs=None,
        reduced=False,
    ):
        self.name = dataset_name
        self.datadir = os.path.expanduser(datadir)
        self.textual = textual
        self.visual = visual
        self.multilabel = multilabel
        self.reduced = reduced
        self.load_csv(set_tr_langs, set_te_langs)
        self.print_stats()

    def print_stats(self):
        print(f"Dataset statistics {'-' * 15}")
        tr = va = te = 0
        for lang in self.all_langs:
            # Languages that appear only in the test set have no train/val split.
            n_tr = len(self.train_data[lang]) if lang in self.tr_langs else 0
            n_va = len(self.val_data[lang]) if lang in self.tr_langs else 0
            n_te = len(self.test_data[lang])
            tr += n_tr
            va += n_va
            te += n_te
            print(f"{lang} - tr: {n_tr} - va: {n_va} - te: {n_te}")
        print(f"Total {'-' * 15}")
        print(f"tr: {tr} - va: {va} - te: {te}")

    def load_csv(self, set_tr_langs, set_te_langs):
        _data_tr = pd.read_csv(
            os.path.join(self.datadir, "train.csv" if not self.reduced else "train.small.csv")
        )
        try:
            stratified = "class"
            train, val = train_test_split(
                _data_tr, test_size=0.2, random_state=42, stratify=_data_tr.label
            )
        except (AttributeError, ValueError):
            # Fall back to stratifying by language when class-based
            # stratification is not possible (e.g. classes with < 2 samples).
            stratified = "lang"
            train, val = train_test_split(
                _data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang
            )
        print(f"- dataset stratified by {stratified}")
        # The non-reduced filename is assumed to mirror the train split naming.
        test = pd.read_csv(
            os.path.join(self.datadir, "test.csv" if not self.reduced else "test.small.csv")
        )
        self._set_langs(train, test, set_tr_langs, set_te_langs)
        self._set_labels(_data_tr)
        self.full_train = _data_tr
        self.full_test = test
        self.train_data = self._set_datalang(train)
        self.val_data = self._set_datalang(val)
        self.test_data = self._set_datalang(test)

    def _set_labels(self, data):
        self.labels = sorted(data.label.unique())

    def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
        self.tr_langs = set(train.lang.unique().tolist())
        self.te_langs = set(test.lang.unique().tolist())
        if set_tr_langs is not None:
            print(f"-- [SETTING TRAINING LANGS TO: {list(set_tr_langs)}]")
            self.tr_langs = self.tr_langs.intersection(set(set_tr_langs))
        if set_te_langs is not None:
            print(f"-- [SETTING TESTING LANGS TO: {list(set_te_langs)}]")
            self.te_langs = self.te_langs.intersection(set(set_te_langs))
        self.all_langs = self.tr_langs.union(self.te_langs)
        return self.tr_langs, self.te_langs, self.all_langs

    def _set_datalang(self, data: pd.DataFrame):
        return {lang: data[data.lang == lang] for lang in self.all_langs}

    def training(self, merge_validation=False, mask_number=False, target_as_csr=False):
        # Mask numbers only when explicitly requested; target_as_csr is
        # accepted for API compatibility but currently unused.
        apply_mask = _mask_numbers if mask_number else (lambda x: x)
        lXtr = {
            lang: {"text": apply_mask(self.train_data[lang].text.tolist())}
            for lang in self.tr_langs
        }
        lYtr = {lang: self.train_data[lang].label.tolist() for lang in self.tr_langs}
        if merge_validation:
            for lang in self.tr_langs:
                lXtr[lang]["text"] += apply_mask(self.val_data[lang].text.tolist())
                lYtr[lang] += self.val_data[lang].label.tolist()
        for lang in self.tr_langs:
            lYtr[lang] = self.indices_to_one_hot(
                indices=lYtr[lang], n_labels=self.num_labels()
            )
        return lXtr, lYtr

    def test(self, mask_number=False, target_as_csr=False):
        apply_mask = _mask_numbers if mask_number else (lambda x: x)
        lXte = {
            lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
            for lang in self.te_langs
        }
        lYte = {
            lang: self.indices_to_one_hot(
                indices=self.test_data[lang].label.tolist(),
                n_labels=self.num_labels(),
            )
            for lang in self.te_langs
        }
        return lXte, lYte

    def langs(self):
        return list(self.all_langs)

    def num_labels(self):
        return len(self.labels)

    def indices_to_one_hot(self, indices, n_labels):
        # Assumes labels are integer indices in [0, n_labels).
        one_hot_matrix = np.zeros((len(indices), n_labels))
        one_hot_matrix[np.arange(len(indices)), indices] = 1
        return one_hot_matrix


def _mask_numbers(data):
    # Replace numeric tokens with placeholders, bucketed by digit count;
    # longer runs are masked first so e.g. a 4-digit year is not split up.
    mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
    mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")
    mask_3digit = re.compile(r"\s[\+-]?\d{3}([\.,]\d*)*\b")
    mask_2digit = re.compile(r"\s[\+-]?\d{2}([\.,]\d*)*\b")
    mask_1digit = re.compile(r"\s[\+-]?\d{1}([\.,]\d*)*\b")
    masked = []
    for text in tqdm(data, desc="masking numbers", disable=True):
        text = " " + text
        text = mask_moredigit.sub(" MoreDigitMask", text)
        text = mask_4digit.sub(" FourDigitMask", text)
        text = mask_3digit.sub(" ThreeDigitMask", text)
        text = mask_2digit.sub(" TwoDigitMask", text)
        text = mask_1digit.sub(" OneDigitMask", text)
        masked.append(text.replace(".", "").replace(",", "").strip())
    return masked


if __name__ == "__main__":
    data_rai = SimpleGfunDataset()
    lXtr, lYtr = data_rai.training(mask_number=False)
    lXte, lYte = data_rai.test(mask_number=False)
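    # Minimal sanity check of the number-masking helper; the sample sentence
    # below is an illustrative assumption, not drawn from the dataset.
    sample = ["Episode 3 aired in 1995 to 21000 viewers"]
    print(_mask_numbers(sample))
    # -> ['Episode OneDigitMask aired in FourDigitMask to MoreDigitMask viewers']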