From 86fbd90bd4574033b171b33fbeb857cae803d1a2 Mon Sep 17 00:00:00 2001 From: andreapdr Date: Thu, 29 Jun 2023 11:41:22 +0200 Subject: [PATCH] handling new data --- dataManager/gFunDataset.py | 29 ++++++++++++++++-------- dataManager/multilingualDataset.py | 36 +++++++++++++++++++++++++++++- dataManager/utils.py | 13 +++++++++-- 3 files changed, 66 insertions(+), 12 deletions(-) diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py index 243593d..942d718 100644 --- a/dataManager/gFunDataset.py +++ b/dataManager/gFunDataset.py @@ -27,8 +27,8 @@ class gFunDataset: self._load_dataset() def get_label_binarizer(self, labels): - if self.dataset_name in ["rcv1-2", "jrc", "cls"]: - mlb = "Labels are already binarized for rcv1-2 dataset" + if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]: + mlb = f"Labels are already binarized for {self.dataset_name} dataset" elif self.is_multilabel: mlb = MultiLabelBinarizer() mlb.fit([labels]) @@ -85,8 +85,16 @@ class gFunDataset: self.dataset_name, self.dataset_dir, self.nrows ) self.mlb = self.get_label_binarizer(self.labels) - self.show_dimension() + + elif "rai" in self.dataset_dir.lower(): + print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}") + self.dataset_name = "rai" + self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name, + dataset_dir="~/datasets/rai/csv/train-split-rai.csv", + nrows=self.nrows) + self.mlb = self.get_label_binarizer(self.labels) + self.show_dimension() return def show_dimension(self): @@ -95,13 +103,18 @@ class gFunDataset: print( f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}" ) - if self.dataset_name in ["rcv1-2", "jrc", "cls"]: + if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]: print(f"-- Labels: {self.labels}") else: print(f"-- Labels: {len(self.labels)}") def _load_multilingual(self, dataset_name, dataset_dir, nrows): - old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir) + if "csv" in dataset_dir: + old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv( + path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv", + path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv") + else: + old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir) if nrows is not None: if dataset_name == "cls": old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows) @@ -154,7 +167,7 @@ class gFunDataset: return dataset, labels, data_langs def binarize_labels(self, labels): - if self.dataset_name in ["rcv1-2", "jrc", "cls"]: + if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]: # labels are already binarized for rcv1-2 dataset return labels if hasattr(self, "mlb"): @@ -192,7 +205,7 @@ class gFunDataset: return self.data_langs def num_labels(self): - if self.dataset_name not in ["rcv1-2", "jrc", "cls"]: + if self.dataset_name not in ["rcv1-2", "jrc", "cls", "rai"]: return len(self.labels) else: return self.labels @@ -219,8 +232,6 @@ if __name__ == "__main__": print("Hello gFunDataset") dataset = gFunDataset( - # dataset_dir=GLAMI_DATAPATH, - # dataset_dir=RCV_DATAPATH, dataset_dir=JRC_DATAPATH, data_langs=None, is_textual=True, diff --git a/dataManager/multilingualDataset.py b/dataManager/multilingualDataset.py index 82a2c2a..a4c6bc3 100644 --- a/dataManager/multilingualDataset.py +++ b/dataManager/multilingualDataset.py @@ -222,6 +222,41 @@ class MultilingualDataset: new_data.append((docs[:maxn], labels[:maxn], None)) return new_data + def from_csv(self, path_tr, path_te): + import pandas as pd + from os.path import expanduser + train = pd.read_csv(expanduser(path_tr)) + test = pd.read_csv(expanduser(path_te)) + all_labels = set(train.label.to_list()).union(set(test.label.to_list())) + for lang in train.lang.unique(): + tr_datalang = train.loc[train["lang"] == lang] + Xtr = tr_datalang.text.to_list() + tr_labels = tr_datalang.label.to_list() + # Ytr = np.zeros((len(Xtr), len(all_labels)), dtype=int) + Ytr = np.zeros((len(Xtr), 28), dtype=int) + for j, i in enumerate(tr_labels): + Ytr[j, i] = 1 + tr_ids = tr_datalang.id.to_list() + te_datalang = test.loc[test["lang"] == lang] + Xte = te_datalang.text.to_list() + te_labels = te_datalang.label.to_list() + # Yte = np.zeros((len(Xte), len(all_labels)), dtype=int) + Yte = np.zeros((len(Xte), 28), dtype=int) + for j, i in enumerate(te_labels): + Yte[j, i] = 1 + te_ids = te_datalang.id.to_list() + self.add( + lang=lang, + Xtr=Xtr, + Ytr=Ytr, + Xte=Xte, + Yte=Yte, + tr_ids=tr_ids, + te_ids=te_ids + ) + return self + + def _mask_numbers(data): mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b") @@ -240,7 +275,6 @@ def _mask_numbers(data): masked.append(text.replace(".", "").replace(",", "").strip()) return masked - if __name__ == "__main__": DATAPATH = expanduser( "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle" diff --git a/dataManager/utils.py b/dataManager/utils.py index 6270b5f..f7c96b4 100644 --- a/dataManager/utils.py +++ b/dataManager/utils.py @@ -24,6 +24,7 @@ def get_dataset(dataset_name, args): "glami", "cls", "webis", + "rai", ], "dataset not supported" RCV_DATAPATH = expanduser( @@ -42,6 +43,8 @@ def get_dataset(dataset_name, args): "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl" ) + RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl") + if dataset_name == "multinews": # TODO: convert to gFunDataset raise NotImplementedError @@ -87,7 +90,6 @@ def get_dataset(dataset_name, args): ) dataset.save_as_pickle(GLAMI_DATAPATH) - elif dataset_name == "cls": dataset = gFunDataset( dataset_dir=CLS_DATAPATH, @@ -96,7 +98,6 @@ def get_dataset(dataset_name, args): is_multilabel=False, nrows=args.nrows, ) - elif dataset_name == "webis": dataset = gFunDataset( dataset_dir=WEBIS_CLS, @@ -105,6 +106,14 @@ def get_dataset(dataset_name, args): is_multilabel=False, nrows=args.nrows, ) + elif dataset_name == "rai": + dataset = gFunDataset( + dataset_dir=RAI_DATAPATH, + is_textual=True, + is_visual=False, + is_multilabel=False, + nrows=args.nrows + ) else: raise NotImplementedError return dataset