From f9d4e502978d8293922ce798f857375f161a6b39 Mon Sep 17 00:00:00 2001 From: Andrea Pedrotti Date: Sat, 4 Mar 2023 12:54:55 +0100 Subject: [PATCH] support for cls dataset; update requirements --- dataManager/clsDataset.py | 81 +++++++++++++++++++ dataManager/gFunDataset.py | 21 +++-- ...ingualDatset.py => multilingualDataset.py} | 31 +------ main.py | 18 ++++- requirements.txt | 12 +-- 5 files changed, 122 insertions(+), 41 deletions(-) create mode 100644 dataManager/clsDataset.py rename dataManager/{multilingualDatset.py => multilingualDataset.py} (94%) diff --git a/dataManager/clsDataset.py b/dataManager/clsDataset.py new file mode 100644 index 0000000..517b395 --- /dev/null +++ b/dataManager/clsDataset.py @@ -0,0 +1,81 @@ +import sys +import os + +sys.path.append(os.getcwd()) + +import numpy as np +import re +from dataManager.multilingualDataset import MultilingualDataset + +CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/") +LANGS = ["de", "en", "fr", "jp"] +DOMAINS = ["books", "dvd", "music"] + +regex = r":\d+" +subst = "" + + +def load_cls(): + data = {} + for lang in LANGS: + data[lang] = {} + for domain in DOMAINS: + print(f"lang: {lang}, domain: {domain}") + train = ( + open( + os.path.join( + CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed" + ), + "r", + ) + .read() + .splitlines() + ) + test = ( + open( + os.path.join( + CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed" + ), + "r", + ) + .read() + .splitlines() + ) + print(f"train: {len(train)}, test: {len(test)}") + data[lang][domain] = { + "train": [process_data(t) for t in train], + "test": [process_data(t) for t in test], + } + return data + + +def process_data(line): + # TODO: we are adding a space after each pucntuation mark (e.g., ". es ich das , langweilig lustig") + result = re.sub(regex, subst, line, 0, re.MULTILINE) + text, label = result.split("#label#:") + label = 0 if label == "negative" else 1 + return text, label + + +if __name__ == "__main__": + print(f"datapath: {CLS_PROCESSED_DATA_DIR}") + data = load_cls() + multilingualDataset = MultilingualDataset(dataset_name="cls") + for lang in LANGS: + # TODO: just using book domain atm + Xtr = [text[0] for text in data[lang]["books"]["train"]] + Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1) + Xte = [text[0] for text in data[lang]["books"]["test"]] + Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1) + multilingualDataset.add( + lang=lang, + Xtr=Xtr, + Ytr=Ytr, + Xte=Xte, + Yte=Yte, + tr_ids=None, + te_ids=None, + ) + multilingualDataset.save( + os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl") + ) diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py index 5532c6e..0bbf4c9 100644 --- a/dataManager/gFunDataset.py +++ b/dataManager/gFunDataset.py @@ -1,6 +1,6 @@ from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer from dataManager.glamiDataset import get_dataframe -from dataManager.multilingualDatset import MultilingualDataset +from dataManager.multilingualDataset import MultilingualDataset class gFunDataset: @@ -25,7 +25,7 @@ class gFunDataset: self.load_dataset() def get_label_binarizer(self, labels): - if self.dataset_name in ["rcv1-2", "jrc"]: + if self.dataset_name in ["rcv1-2", "jrc", "cls"]: mlb = "Labels are already binarized for rcv1-2 dataset" elif self.is_multilabel: mlb = MultiLabelBinarizer() @@ -60,6 +60,14 @@ class gFunDataset: ) self.mlb = self.get_label_binarizer(self.labels) + elif "cls" in self.dataset_dir.lower(): + print(f"- Loading CLS dataset from {self.dataset_dir}") + self.dataset_name = "cls" + self.dataset, self.labels, self.data_langs = self._load_multilingual( + self.dataset_name, self.dataset_dir, self.nrows + ) + self.mlb = self.get_label_binarizer(self.labels) + self.show_dimension() return @@ -70,7 +78,7 @@ class gFunDataset: print( f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}" ) - if self.dataset_name in ["rcv1-2", "jrc"]: + if self.dataset_name in ["rcv1-2", "jrc", "cls"]: print(f"-- Labels: {self.labels}") else: print(f"-- Labels: {len(self.labels)}") @@ -78,7 +86,10 @@ class gFunDataset: def _load_multilingual(self, dataset_name, dataset_dir, nrows): old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir) if nrows is not None: - old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows) + if dataset_name == "cls": + old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows) + else: + old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows) labels = old_dataset.num_labels() data_langs = old_dataset.langs() @@ -151,7 +162,7 @@ class gFunDataset: return dataset, labels, data_langs def binarize_labels(self, labels): - if self.dataset_name in ["rcv1-2", "jrc"]: + if self.dataset_name in ["rcv1-2", "jrc", "cls"]: # labels are already binarized for rcv1-2 dataset return labels if hasattr(self, "mlb"): diff --git a/dataManager/multilingualDatset.py b/dataManager/multilingualDataset.py similarity index 94% rename from dataManager/multilingualDatset.py rename to dataManager/multilingualDataset.py index 6ec2d78..82a2c2a 100644 --- a/dataManager/multilingualDatset.py +++ b/dataManager/multilingualDataset.py @@ -9,32 +9,6 @@ import numpy as np from tqdm import tqdm -class NewMultilingualDataset(ABC): - @abstractmethod - def get_training(self): - pass - - @abstractmethod - def get_validation(self): - pass - - @abstractmethod - def get_test(self): - pass - - @abstractmethod - def mask_numbers(self): - pass - - @abstractmethod - def save(self): - pass - - @abstractmethod - def load(self): - pass - - # class RcvMultilingualDataset(MultilingualDataset): class RcvMultilingualDataset: def __init__( @@ -242,7 +216,10 @@ class MultilingualDataset: new_data = [] for split in multilingual_dataset: docs, labels, ids = split - new_data.append((docs[:maxn], labels[:maxn], ids[:maxn])) + if ids is not None: + new_data.append((docs[:maxn], labels[:maxn], ids[:maxn])) + else: + new_data.append((docs[:maxn], labels[:maxn], None)) return new_data diff --git a/main.py b/main.py index d880fa3..a21a364 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ from os.path import expanduser from time import time from dataManager.amazonDataset import AmazonDataset -from dataManager.multilingualDatset import MultilingualDataset +from dataManager.multilingualDataset import MultilingualDataset from dataManager.multiNewsDataset import MultiNewsDataset from dataManager.glamiDataset import GlamiDataset from dataManager.gFunDataset import gFunDataset @@ -13,6 +13,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling """ TODO: + - [!] add support for Binary Datasets (e.g. cls) - [!] logging - add documentations sphinx - [!] zero-shot setup @@ -23,12 +24,13 @@ TODO: """ -def get_dataset(datasetname): +def get_dataset(datasetname, args): assert datasetname in [ "multinews", "amazon", "rcv1-2", "glami", + "cls", ], "dataset not supported" RCV_DATAPATH = expanduser( @@ -37,6 +39,8 @@ def get_dataset(datasetname): JRC_DATAPATH = expanduser( "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle" ) + CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl") + MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/") GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset") @@ -73,13 +77,21 @@ def get_dataset(datasetname): is_multilabel=False, nrows=args.nrows, ) + elif datasetname == "cls": + dataset = gFunDataset( + dataset_dir=CLS_DATAPATH, + is_textual=True, + is_visual=False, + is_multilabel=False, + nrows=args.nrows, + ) else: raise NotImplementedError return dataset def main(args): - dataset = get_dataset(args.dataset) + dataset = get_dataset(args.dataset, args) if ( isinstance(dataset, MultilingualDataset) or isinstance(dataset, MultiNewsDataset) diff --git a/requirements.txt b/requirements.txt index c440797..3f39887 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ beautifulsoup4==4.11.2 joblib==1.2.0 -matplotlib==3.6.3 -numpy==1.24.1 +matplotlib==3.7.1 +numpy==1.24.2 +pandas==1.5.3 Pillow==9.4.0 requests==2.28.2 scikit_learn==1.2.1 -scipy==1.10.0 +scipy==1.10.1 torch==1.13.1 torchtext==0.14.1 -torchvision==0.14.1 -tqdm==4.64.1 -transformers==4.26.0 +tqdm==4.65.0 +transformers==4.26.1