support for cls dataset; update requirements
This commit is contained in:
parent
25fd67865d
commit
f9d4e50297
|
@ -0,0 +1,81 @@
|
|||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import numpy as np
|
||||
import re
|
||||
from dataManager.multilingualDataset import MultilingualDataset
|
||||
|
||||
CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
|
||||
LANGS = ["de", "en", "fr", "jp"]
|
||||
DOMAINS = ["books", "dvd", "music"]
|
||||
|
||||
regex = r":\d+"
|
||||
subst = ""
|
||||
|
||||
|
||||
def load_cls():
|
||||
data = {}
|
||||
for lang in LANGS:
|
||||
data[lang] = {}
|
||||
for domain in DOMAINS:
|
||||
print(f"lang: {lang}, domain: {domain}")
|
||||
train = (
|
||||
open(
|
||||
os.path.join(
|
||||
CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed"
|
||||
),
|
||||
"r",
|
||||
)
|
||||
.read()
|
||||
.splitlines()
|
||||
)
|
||||
test = (
|
||||
open(
|
||||
os.path.join(
|
||||
CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed"
|
||||
),
|
||||
"r",
|
||||
)
|
||||
.read()
|
||||
.splitlines()
|
||||
)
|
||||
print(f"train: {len(train)}, test: {len(test)}")
|
||||
data[lang][domain] = {
|
||||
"train": [process_data(t) for t in train],
|
||||
"test": [process_data(t) for t in test],
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
def process_data(line):
|
||||
# TODO: we are adding a space after each pucntuation mark (e.g., ". es ich das , langweilig lustig")
|
||||
result = re.sub(regex, subst, line, 0, re.MULTILINE)
|
||||
text, label = result.split("#label#:")
|
||||
label = 0 if label == "negative" else 1
|
||||
return text, label
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"datapath: {CLS_PROCESSED_DATA_DIR}")
|
||||
data = load_cls()
|
||||
multilingualDataset = MultilingualDataset(dataset_name="cls")
|
||||
for lang in LANGS:
|
||||
# TODO: just using book domain atm
|
||||
Xtr = [text[0] for text in data[lang]["books"]["train"]]
|
||||
Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
|
||||
Xte = [text[0] for text in data[lang]["books"]["test"]]
|
||||
Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
|
||||
multilingualDataset.add(
|
||||
lang=lang,
|
||||
Xtr=Xtr,
|
||||
Ytr=Ytr,
|
||||
Xte=Xte,
|
||||
Yte=Yte,
|
||||
tr_ids=None,
|
||||
te_ids=None,
|
||||
)
|
||||
multilingualDataset.save(
|
||||
os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
|
||||
)
|
|
@ -1,6 +1,6 @@
|
|||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
||||
from dataManager.glamiDataset import get_dataframe
|
||||
from dataManager.multilingualDatset import MultilingualDataset
|
||||
from dataManager.multilingualDataset import MultilingualDataset
|
||||
|
||||
|
||||
class gFunDataset:
|
||||
|
@ -25,7 +25,7 @@ class gFunDataset:
|
|||
self.load_dataset()
|
||||
|
||||
def get_label_binarizer(self, labels):
|
||||
if self.dataset_name in ["rcv1-2", "jrc"]:
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
||||
mlb = "Labels are already binarized for rcv1-2 dataset"
|
||||
elif self.is_multilabel:
|
||||
mlb = MultiLabelBinarizer()
|
||||
|
@ -60,6 +60,14 @@ class gFunDataset:
|
|||
)
|
||||
self.mlb = self.get_label_binarizer(self.labels)
|
||||
|
||||
elif "cls" in self.dataset_dir.lower():
|
||||
print(f"- Loading CLS dataset from {self.dataset_dir}")
|
||||
self.dataset_name = "cls"
|
||||
self.dataset, self.labels, self.data_langs = self._load_multilingual(
|
||||
self.dataset_name, self.dataset_dir, self.nrows
|
||||
)
|
||||
self.mlb = self.get_label_binarizer(self.labels)
|
||||
|
||||
self.show_dimension()
|
||||
|
||||
return
|
||||
|
@ -70,7 +78,7 @@ class gFunDataset:
|
|||
print(
|
||||
f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
|
||||
)
|
||||
if self.dataset_name in ["rcv1-2", "jrc"]:
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
||||
print(f"-- Labels: {self.labels}")
|
||||
else:
|
||||
print(f"-- Labels: {len(self.labels)}")
|
||||
|
@ -78,6 +86,9 @@ class gFunDataset:
|
|||
def _load_multilingual(self, dataset_name, dataset_dir, nrows):
|
||||
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
|
||||
if nrows is not None:
|
||||
if dataset_name == "cls":
|
||||
old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
|
||||
else:
|
||||
old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
|
||||
labels = old_dataset.num_labels()
|
||||
data_langs = old_dataset.langs()
|
||||
|
@ -151,7 +162,7 @@ class gFunDataset:
|
|||
return dataset, labels, data_langs
|
||||
|
||||
def binarize_labels(self, labels):
|
||||
if self.dataset_name in ["rcv1-2", "jrc"]:
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
||||
# labels are already binarized for rcv1-2 dataset
|
||||
return labels
|
||||
if hasattr(self, "mlb"):
|
||||
|
|
|
@ -9,32 +9,6 @@ import numpy as np
|
|||
from tqdm import tqdm
|
||||
|
||||
|
||||
class NewMultilingualDataset(ABC):
|
||||
@abstractmethod
|
||||
def get_training(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_validation(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_test(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def mask_numbers(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
|
||||
# class RcvMultilingualDataset(MultilingualDataset):
|
||||
class RcvMultilingualDataset:
|
||||
def __init__(
|
||||
|
@ -242,7 +216,10 @@ class MultilingualDataset:
|
|||
new_data = []
|
||||
for split in multilingual_dataset:
|
||||
docs, labels, ids = split
|
||||
if ids is not None:
|
||||
new_data.append((docs[:maxn], labels[:maxn], ids[:maxn]))
|
||||
else:
|
||||
new_data.append((docs[:maxn], labels[:maxn], None))
|
||||
return new_data
|
||||
|
||||
|
18
main.py
18
main.py
|
@ -4,7 +4,7 @@ from os.path import expanduser
|
|||
from time import time
|
||||
|
||||
from dataManager.amazonDataset import AmazonDataset
|
||||
from dataManager.multilingualDatset import MultilingualDataset
|
||||
from dataManager.multilingualDataset import MultilingualDataset
|
||||
from dataManager.multiNewsDataset import MultiNewsDataset
|
||||
from dataManager.glamiDataset import GlamiDataset
|
||||
from dataManager.gFunDataset import gFunDataset
|
||||
|
@ -13,6 +13,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
|
|||
|
||||
"""
|
||||
TODO:
|
||||
- [!] add support for Binary Datasets (e.g. cls)
|
||||
- [!] logging
|
||||
- add documentations sphinx
|
||||
- [!] zero-shot setup
|
||||
|
@ -23,12 +24,13 @@ TODO:
|
|||
"""
|
||||
|
||||
|
||||
def get_dataset(datasetname):
|
||||
def get_dataset(datasetname, args):
|
||||
assert datasetname in [
|
||||
"multinews",
|
||||
"amazon",
|
||||
"rcv1-2",
|
||||
"glami",
|
||||
"cls",
|
||||
], "dataset not supported"
|
||||
|
||||
RCV_DATAPATH = expanduser(
|
||||
|
@ -37,6 +39,8 @@ def get_dataset(datasetname):
|
|||
JRC_DATAPATH = expanduser(
|
||||
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
|
||||
)
|
||||
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
|
||||
|
||||
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
|
||||
|
||||
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
|
||||
|
@ -73,13 +77,21 @@ def get_dataset(datasetname):
|
|||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
elif datasetname == "cls":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=CLS_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=False,
|
||||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return dataset
|
||||
|
||||
|
||||
def main(args):
|
||||
dataset = get_dataset(args.dataset)
|
||||
dataset = get_dataset(args.dataset, args)
|
||||
if (
|
||||
isinstance(dataset, MultilingualDataset)
|
||||
or isinstance(dataset, MultiNewsDataset)
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
beautifulsoup4==4.11.2
|
||||
joblib==1.2.0
|
||||
matplotlib==3.6.3
|
||||
numpy==1.24.1
|
||||
matplotlib==3.7.1
|
||||
numpy==1.24.2
|
||||
pandas==1.5.3
|
||||
Pillow==9.4.0
|
||||
requests==2.28.2
|
||||
scikit_learn==1.2.1
|
||||
scipy==1.10.0
|
||||
scipy==1.10.1
|
||||
torch==1.13.1
|
||||
torchtext==0.14.1
|
||||
torchvision==0.14.1
|
||||
tqdm==4.64.1
|
||||
transformers==4.26.0
|
||||
tqdm==4.65.0
|
||||
transformers==4.26.1
|
||||
|
|
Loading…
Reference in New Issue