support for cls dataset; update requirements

This commit is contained in:
Andrea Pedrotti 2023-03-04 12:54:55 +01:00
parent 25fd67865d
commit f9d4e50297
5 changed files with 122 additions and 41 deletions

81
dataManager/clsDataset.py Normal file
View File

@ -0,0 +1,81 @@
import sys
import os
sys.path.append(os.getcwd())
import numpy as np
import re
from dataManager.multilingualDataset import MultilingualDataset
CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
LANGS = ["de", "en", "fr", "jp"]
DOMAINS = ["books", "dvd", "music"]
regex = r":\d+"
subst = ""
def load_cls():
data = {}
for lang in LANGS:
data[lang] = {}
for domain in DOMAINS:
print(f"lang: {lang}, domain: {domain}")
train = (
open(
os.path.join(
CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed"
),
"r",
)
.read()
.splitlines()
)
test = (
open(
os.path.join(
CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed"
),
"r",
)
.read()
.splitlines()
)
print(f"train: {len(train)}, test: {len(test)}")
data[lang][domain] = {
"train": [process_data(t) for t in train],
"test": [process_data(t) for t in test],
}
return data
def process_data(line):
# TODO: we are adding a space after each pucntuation mark (e.g., ". es ich das , langweilig lustig")
result = re.sub(regex, subst, line, 0, re.MULTILINE)
text, label = result.split("#label#:")
label = 0 if label == "negative" else 1
return text, label
if __name__ == "__main__":
print(f"datapath: {CLS_PROCESSED_DATA_DIR}")
data = load_cls()
multilingualDataset = MultilingualDataset(dataset_name="cls")
for lang in LANGS:
# TODO: just using book domain atm
Xtr = [text[0] for text in data[lang]["books"]["train"]]
Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
Xte = [text[0] for text in data[lang]["books"]["test"]]
Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
multilingualDataset.add(
lang=lang,
Xtr=Xtr,
Ytr=Ytr,
Xte=Xte,
Yte=Yte,
tr_ids=None,
te_ids=None,
)
multilingualDataset.save(
os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
)

View File

@ -1,6 +1,6 @@
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from dataManager.glamiDataset import get_dataframe from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDatset import MultilingualDataset from dataManager.multilingualDataset import MultilingualDataset
class gFunDataset: class gFunDataset:
@ -25,7 +25,7 @@ class gFunDataset:
self.load_dataset() self.load_dataset()
def get_label_binarizer(self, labels): def get_label_binarizer(self, labels):
if self.dataset_name in ["rcv1-2", "jrc"]: if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
mlb = "Labels are already binarized for rcv1-2 dataset" mlb = "Labels are already binarized for rcv1-2 dataset"
elif self.is_multilabel: elif self.is_multilabel:
mlb = MultiLabelBinarizer() mlb = MultiLabelBinarizer()
@ -60,6 +60,14 @@ class gFunDataset:
) )
self.mlb = self.get_label_binarizer(self.labels) self.mlb = self.get_label_binarizer(self.labels)
elif "cls" in self.dataset_dir.lower():
print(f"- Loading CLS dataset from {self.dataset_dir}")
self.dataset_name = "cls"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
self.show_dimension() self.show_dimension()
return return
@ -70,7 +78,7 @@ class gFunDataset:
print( print(
f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}" f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
) )
if self.dataset_name in ["rcv1-2", "jrc"]: if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
print(f"-- Labels: {self.labels}") print(f"-- Labels: {self.labels}")
else: else:
print(f"-- Labels: {len(self.labels)}") print(f"-- Labels: {len(self.labels)}")
@ -78,6 +86,9 @@ class gFunDataset:
def _load_multilingual(self, dataset_name, dataset_dir, nrows): def _load_multilingual(self, dataset_name, dataset_dir, nrows):
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir) old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
if nrows is not None: if nrows is not None:
if dataset_name == "cls":
old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
else:
old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows) old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
labels = old_dataset.num_labels() labels = old_dataset.num_labels()
data_langs = old_dataset.langs() data_langs = old_dataset.langs()
@ -151,7 +162,7 @@ class gFunDataset:
return dataset, labels, data_langs return dataset, labels, data_langs
def binarize_labels(self, labels): def binarize_labels(self, labels):
if self.dataset_name in ["rcv1-2", "jrc"]: if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
# labels are already binarized for rcv1-2 dataset # labels are already binarized for rcv1-2 dataset
return labels return labels
if hasattr(self, "mlb"): if hasattr(self, "mlb"):

View File

@ -9,32 +9,6 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
class NewMultilingualDataset(ABC):
@abstractmethod
def get_training(self):
pass
@abstractmethod
def get_validation(self):
pass
@abstractmethod
def get_test(self):
pass
@abstractmethod
def mask_numbers(self):
pass
@abstractmethod
def save(self):
pass
@abstractmethod
def load(self):
pass
# class RcvMultilingualDataset(MultilingualDataset): # class RcvMultilingualDataset(MultilingualDataset):
class RcvMultilingualDataset: class RcvMultilingualDataset:
def __init__( def __init__(
@ -242,7 +216,10 @@ class MultilingualDataset:
new_data = [] new_data = []
for split in multilingual_dataset: for split in multilingual_dataset:
docs, labels, ids = split docs, labels, ids = split
if ids is not None:
new_data.append((docs[:maxn], labels[:maxn], ids[:maxn])) new_data.append((docs[:maxn], labels[:maxn], ids[:maxn]))
else:
new_data.append((docs[:maxn], labels[:maxn], None))
return new_data return new_data

18
main.py
View File

@ -4,7 +4,7 @@ from os.path import expanduser
from time import time from time import time
from dataManager.amazonDataset import AmazonDataset from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDatset import MultilingualDataset from dataManager.multilingualDataset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.glamiDataset import GlamiDataset from dataManager.glamiDataset import GlamiDataset
from dataManager.gFunDataset import gFunDataset from dataManager.gFunDataset import gFunDataset
@ -13,6 +13,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
""" """
TODO: TODO:
- [!] add support for Binary Datasets (e.g. cls)
- [!] logging - [!] logging
- add documentations sphinx - add documentations sphinx
- [!] zero-shot setup - [!] zero-shot setup
@ -23,12 +24,13 @@ TODO:
""" """
def get_dataset(datasetname): def get_dataset(datasetname, args):
assert datasetname in [ assert datasetname in [
"multinews", "multinews",
"amazon", "amazon",
"rcv1-2", "rcv1-2",
"glami", "glami",
"cls",
], "dataset not supported" ], "dataset not supported"
RCV_DATAPATH = expanduser( RCV_DATAPATH = expanduser(
@ -37,6 +39,8 @@ def get_dataset(datasetname):
JRC_DATAPATH = expanduser( JRC_DATAPATH = expanduser(
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle" "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
) )
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/") MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset") GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
@ -73,13 +77,21 @@ def get_dataset(datasetname):
is_multilabel=False, is_multilabel=False,
nrows=args.nrows, nrows=args.nrows,
) )
elif datasetname == "cls":
dataset = gFunDataset(
dataset_dir=CLS_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=False,
nrows=args.nrows,
)
else: else:
raise NotImplementedError raise NotImplementedError
return dataset return dataset
def main(args): def main(args):
dataset = get_dataset(args.dataset) dataset = get_dataset(args.dataset, args)
if ( if (
isinstance(dataset, MultilingualDataset) isinstance(dataset, MultilingualDataset)
or isinstance(dataset, MultiNewsDataset) or isinstance(dataset, MultiNewsDataset)

View File

@ -1,13 +1,13 @@
beautifulsoup4==4.11.2 beautifulsoup4==4.11.2
joblib==1.2.0 joblib==1.2.0
matplotlib==3.6.3 matplotlib==3.7.1
numpy==1.24.1 numpy==1.24.2
pandas==1.5.3
Pillow==9.4.0 Pillow==9.4.0
requests==2.28.2 requests==2.28.2
scikit_learn==1.2.1 scikit_learn==1.2.1
scipy==1.10.0 scipy==1.10.1
torch==1.13.1 torch==1.13.1
torchtext==0.14.1 torchtext==0.14.1
torchvision==0.14.1 tqdm==4.65.0
tqdm==4.64.1 transformers==4.26.1
transformers==4.26.0