add support for cls dataset; update requirements

This commit is contained in:
Andrea Pedrotti 2023-03-04 12:54:55 +01:00
parent 25fd67865d
commit f9d4e50297
5 changed files with 122 additions and 41 deletions

dataManager/clsDataset.py (new file, 81 additions)

@@ -0,0 +1,81 @@
import sys
import os
sys.path.append(os.getcwd())
import numpy as np
import re
from dataManager.multilingualDataset import MultilingualDataset
CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
LANGS = ["de", "en", "fr", "jp"]
DOMAINS = ["books", "dvd", "music"]
regex = r":\d+"
subst = ""
def load_cls():
data = {}
for lang in LANGS:
data[lang] = {}
for domain in DOMAINS:
print(f"lang: {lang}, domain: {domain}")
            with open(
                os.path.join(CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed"),
                "r",
            ) as f:
                train = f.read().splitlines()
            with open(
                os.path.join(CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed"),
                "r",
            ) as f:
                test = f.read().splitlines()
print(f"train: {len(train)}, test: {len(test)}")
data[lang][domain] = {
"train": [process_data(t) for t in train],
"test": [process_data(t) for t in test],
}
return data
def process_data(line):
    # TODO: we are adding a space after each punctuation mark (e.g., ". es ich das , langweilig lustig")
    result = re.sub(regex, subst, line, count=0, flags=re.MULTILINE)
text, label = result.split("#label#:")
label = 0 if label == "negative" else 1
return text, label
if __name__ == "__main__":
print(f"datapath: {CLS_PROCESSED_DATA_DIR}")
data = load_cls()
multilingualDataset = MultilingualDataset(dataset_name="cls")
for lang in LANGS:
# TODO: just using book domain atm
        Xtr = [text for text, _ in data[lang]["books"]["train"]]
        Ytr = np.expand_dims([label for _, label in data[lang]["books"]["train"]], axis=1)
        Xte = [text for text, _ in data[lang]["books"]["test"]]
        Yte = np.expand_dims([label for _, label in data[lang]["books"]["test"]], axis=1)
multilingualDataset.add(
lang=lang,
Xtr=Xtr,
Ytr=Ytr,
Xte=Xte,
Yte=Yte,
tr_ids=None,
te_ids=None,
)
multilingualDataset.save(
os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
)
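Note: a minimal sketch of what process_data does to a single .processed line. The input below is illustrative, not taken from the corpus, but it follows the token:count ... #label#:polarity layout the code above parses.

import re

# Illustrative line in the cls-acl10 processed layout (hypothetical content).
line = "nicht:3 gut:2 langweilig:1 #label#:negative"

result = re.sub(r":\d+", "", line)      # strip the :count suffixes
text, label = result.split("#label#:")  # split off the polarity tag
label = 0 if label == "negative" else 1
print((text, label))                    # ('nicht gut langweilig ', 0)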

dataManager/gFunDataset.py

@@ -1,6 +1,6 @@
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from dataManager.glamiDataset import get_dataframe
-from dataManager.multilingualDatset import MultilingualDataset
+from dataManager.multilingualDataset import MultilingualDataset
class gFunDataset:
@@ -25,7 +25,7 @@ class gFunDataset:
        self.load_dataset()

    def get_label_binarizer(self, labels):
-        if self.dataset_name in ["rcv1-2", "jrc"]:
+        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
mlb = "Labels are already binarized for rcv1-2 dataset"
elif self.is_multilabel:
mlb = MultiLabelBinarizer()
@@ -60,6 +60,14 @@ class gFunDataset:
            )
            self.mlb = self.get_label_binarizer(self.labels)
+        elif "cls" in self.dataset_dir.lower():
+            print(f"- Loading CLS dataset from {self.dataset_dir}")
+            self.dataset_name = "cls"
+            self.dataset, self.labels, self.data_langs = self._load_multilingual(
+                self.dataset_name, self.dataset_dir, self.nrows
+            )
+            self.mlb = self.get_label_binarizer(self.labels)
self.show_dimension()
return
@@ -70,7 +78,7 @@ class gFunDataset:
            print(
                f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
            )
-        if self.dataset_name in ["rcv1-2", "jrc"]:
+        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
print(f"-- Labels: {self.labels}")
else:
print(f"-- Labels: {len(self.labels)}")
@@ -78,6 +86,9 @@ class gFunDataset:
    def _load_multilingual(self, dataset_name, dataset_dir, nrows):
        old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
        if nrows is not None:
-            old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
+            if dataset_name == "cls":
+                old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
+            else:
+                old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
labels = old_dataset.num_labels()
data_langs = old_dataset.langs()
@@ -151,7 +162,7 @@ class gFunDataset:
        return dataset, labels, data_langs

    def binarize_labels(self, labels):
-        if self.dataset_name in ["rcv1-2", "jrc"]:
+        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
# labels are already binarized for rcv1-2 dataset
return labels
if hasattr(self, "mlb"):
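Note: the three "cls" branches above assume cls labels arrive already binarized, which holds because clsDataset.py stores them as a single binary column. A minimal sketch of that invariant (incidentally, the sentinel string "Labels are already binarized for rcv1-2 dataset" now also covers jrc and cls):

import numpy as np

# cls labels are saved as one binary column (see np.expand_dims in
# clsDataset.py), so gFunDataset can skip the MultiLabelBinarizer.
Ytr = np.expand_dims([1, 0, 1], axis=1)
print(Ytr.shape)  # (3, 1)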

dataManager/multilingualDataset.py

@@ -9,32 +9,6 @@ import numpy as np
from tqdm import tqdm
-class NewMultilingualDataset(ABC):
-    @abstractmethod
-    def get_training(self):
-        pass
-
-    @abstractmethod
-    def get_validation(self):
-        pass
-
-    @abstractmethod
-    def get_test(self):
-        pass
-
-    @abstractmethod
-    def mask_numbers(self):
-        pass
-
-    @abstractmethod
-    def save(self):
-        pass
-
-    @abstractmethod
-    def load(self):
-        pass
# class RcvMultilingualDataset(MultilingualDataset):
class RcvMultilingualDataset:
def __init__(
@@ -242,7 +216,10 @@ class MultilingualDataset:
new_data = []
for split in multilingual_dataset:
docs, labels, ids = split
-            new_data.append((docs[:maxn], labels[:maxn], ids[:maxn]))
+            if ids is not None:
+                new_data.append((docs[:maxn], labels[:maxn], ids[:maxn]))
+            else:
+                new_data.append((docs[:maxn], labels[:maxn], None))
return new_data
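Note: the guard added to reduce_data matters because the cls dataset is saved without document ids (clsDataset.py passes tr_ids=None, te_ids=None), and the old unconditional slice fails on None. A minimal repro:

ids = None
try:
    ids[:10]
except TypeError as err:
    print(err)  # 'NoneType' object is not subscriptable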

main.py (18 lines changed)

@@ -4,7 +4,7 @@ from os.path import expanduser
from time import time
from dataManager.amazonDataset import AmazonDataset
-from dataManager.multilingualDatset import MultilingualDataset
+from dataManager.multilingualDataset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.glamiDataset import GlamiDataset
from dataManager.gFunDataset import gFunDataset
@@ -13,6 +13,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
+- [!] add support for Binary Datasets (e.g. cls)
- [!] logging
- add documentations sphinx
- [!] zero-shot setup
@@ -23,12 +24,13 @@ TODO:
"""

-def get_dataset(datasetname):
+def get_dataset(datasetname, args):
assert datasetname in [
"multinews",
"amazon",
"rcv1-2",
"glami",
"cls",
], "dataset not supported"
RCV_DATAPATH = expanduser(
@@ -37,6 +39,8 @@ def get_dataset(datasetname):
JRC_DATAPATH = expanduser(
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
)
+    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
@@ -73,13 +77,21 @@ def get_dataset(datasetname):
is_multilabel=False,
nrows=args.nrows,
)
+    elif datasetname == "cls":
+        dataset = gFunDataset(
+            dataset_dir=CLS_DATAPATH,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=False,
+            nrows=args.nrows,
+        )
else:
raise NotImplementedError
return dataset
def main(args):
-    dataset = get_dataset(args.dataset)
+    dataset = get_dataset(args.dataset, args)
if (
isinstance(dataset, MultilingualDataset)
or isinstance(dataset, MultiNewsDataset)
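Note: with args now threaded through get_dataset, the new branch can be exercised from a REPL. A hypothetical sketch; the Namespace fields mirror the argparse flags this diff relies on, but the exact flag names are unverified:

from argparse import Namespace
from main import get_dataset

args = Namespace(dataset="cls", nrows=None)  # hypothetical args object
dataset = get_dataset(args.dataset, args)    # loads the cls-acl10-processed pickle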

requirements.txt

@@ -1,13 +1,13 @@
beautifulsoup4==4.11.2
joblib==1.2.0
-matplotlib==3.6.3
-numpy==1.24.1
+matplotlib==3.7.1
+numpy==1.24.2
pandas==1.5.3
Pillow==9.4.0
requests==2.28.2
scikit_learn==1.2.1
-scipy==1.10.0
+scipy==1.10.1
torch==1.13.1
torchtext==0.14.1
torchvision==0.14.1
-tqdm==4.64.1
-transformers==4.26.0
+tqdm==4.65.0
+transformers==4.26.1