support for cls dataset; update requirements
commit f9d4e50297 (parent 25fd67865d)
New file: CLS dataset loader

@@ -0,0 +1,81 @@
+import sys
+import os
+
+sys.path.append(os.getcwd())
+
+import numpy as np
+import re
+from dataManager.multilingualDataset import MultilingualDataset
+
+CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
+LANGS = ["de", "en", "fr", "jp"]
+DOMAINS = ["books", "dvd", "music"]
+
+regex = r":\d+"
+subst = ""
+
+
+def load_cls():
+    data = {}
+    for lang in LANGS:
+        data[lang] = {}
+        for domain in DOMAINS:
+            print(f"lang: {lang}, domain: {domain}")
+            train = (
+                open(
+                    os.path.join(
+                        CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed"
+                    ),
+                    "r",
+                )
+                .read()
+                .splitlines()
+            )
+            test = (
+                open(
+                    os.path.join(
+                        CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed"
+                    ),
+                    "r",
+                )
+                .read()
+                .splitlines()
+            )
+            print(f"train: {len(train)}, test: {len(test)}")
+            data[lang][domain] = {
+                "train": [process_data(t) for t in train],
+                "test": [process_data(t) for t in test],
+            }
+    return data
+
+
+def process_data(line):
+    # TODO: we are adding a space after each punctuation mark (e.g., ". es ich das , langweilig lustig")
+    result = re.sub(regex, subst, line, 0, re.MULTILINE)
+    text, label = result.split("#label#:")
+    label = 0 if label == "negative" else 1
+    return text, label
+
+
+if __name__ == "__main__":
+    print(f"datapath: {CLS_PROCESSED_DATA_DIR}")
+    data = load_cls()
+    multilingualDataset = MultilingualDataset(dataset_name="cls")
+    for lang in LANGS:
+        # TODO: just using book domain atm
+        Xtr = [text[0] for text in data[lang]["books"]["train"]]
+        Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
+        Xte = [text[0] for text in data[lang]["books"]["test"]]
+        Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
+        multilingualDataset.add(
+            lang=lang,
+            Xtr=Xtr,
+            Ytr=Ytr,
+            Xte=Xte,
+            Yte=Yte,
+            tr_ids=None,
+            te_ids=None,
+        )
+    multilingualDataset.save(
+        os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
+    )
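For reference, a minimal sketch of what process_data does to one line of the cls-acl10 ".processed" files. The sample line is invented for illustration; it assumes the files store "token:count" pairs followed by a "#label#:" tag, which is what the r":\d+" regex strips.

import re

sample = "buch:2 langweilig:1 nicht:3 #label#:negative"

# Strip the ":<count>" suffixes, as process_data above does.
cleaned = re.sub(r":\d+", "", sample)
text, label = cleaned.split("#label#:")
label = 0 if label == "negative" else 1

print(repr(text))  # 'buch langweilig nicht '
print(label)       # 0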
dataManager/gFunDataset.py

@@ -1,6 +1,6 @@
 from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
 from dataManager.glamiDataset import get_dataframe
-from dataManager.multilingualDatset import MultilingualDataset
+from dataManager.multilingualDataset import MultilingualDataset


 class gFunDataset:
@@ -25,7 +25,7 @@ class gFunDataset:
         self.load_dataset()

     def get_label_binarizer(self, labels):
-        if self.dataset_name in ["rcv1-2", "jrc"]:
+        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
             mlb = "Labels are already binarized for rcv1-2 dataset"
         elif self.is_multilabel:
             mlb = MultiLabelBinarizer()
@@ -60,6 +60,14 @@ class gFunDataset:
             )
             self.mlb = self.get_label_binarizer(self.labels)

+        elif "cls" in self.dataset_dir.lower():
+            print(f"- Loading CLS dataset from {self.dataset_dir}")
+            self.dataset_name = "cls"
+            self.dataset, self.labels, self.data_langs = self._load_multilingual(
+                self.dataset_name, self.dataset_dir, self.nrows
+            )
+            self.mlb = self.get_label_binarizer(self.labels)
+
         self.show_dimension()

         return
@@ -70,7 +78,7 @@ class gFunDataset:
             print(
                 f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
             )
-        if self.dataset_name in ["rcv1-2", "jrc"]:
+        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
             print(f"-- Labels: {self.labels}")
         else:
             print(f"-- Labels: {len(self.labels)}")
@@ -78,7 +86,10 @@ class gFunDataset:
     def _load_multilingual(self, dataset_name, dataset_dir, nrows):
         old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
         if nrows is not None:
-            old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
+            if dataset_name == "cls":
+                old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
+            else:
+                old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
         labels = old_dataset.num_labels()
         data_langs = old_dataset.langs()

@@ -151,7 +162,7 @@ class gFunDataset:
         return dataset, labels, data_langs

     def binarize_labels(self, labels):
-        if self.dataset_name in ["rcv1-2", "jrc"]:
+        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
             # labels are already binarized for rcv1-2 dataset
             return labels
         if hasattr(self, "mlb"):
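"cls" joins "rcv1-2" and "jrc" in the branches that bypass the label binarizer because the builder script above already stores labels as a binarized (n_docs, 1) 0/1 array via np.expand_dims. A minimal sketch of the label shape this assumes:

import numpy as np

# Labels as produced by the CLS builder script: one 0/1 column per document.
Ytr = np.expand_dims([1, 0, 1], axis=1)
print(Ytr.shape)    # (3, 1)
print(Ytr.ravel())  # [1 0 1]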
dataManager/multilingualDataset.py

@@ -9,32 +9,6 @@ import numpy as np
 from tqdm import tqdm


-class NewMultilingualDataset(ABC):
-    @abstractmethod
-    def get_training(self):
-        pass
-
-    @abstractmethod
-    def get_validation(self):
-        pass
-
-    @abstractmethod
-    def get_test(self):
-        pass
-
-    @abstractmethod
-    def mask_numbers(self):
-        pass
-
-    @abstractmethod
-    def save(self):
-        pass
-
-    @abstractmethod
-    def load(self):
-        pass
-
-
 # class RcvMultilingualDataset(MultilingualDataset):
 class RcvMultilingualDataset:
     def __init__(
@@ -242,7 +216,10 @@ class MultilingualDataset:
         new_data = []
         for split in multilingual_dataset:
             docs, labels, ids = split
-            new_data.append((docs[:maxn], labels[:maxn], ids[:maxn]))
+            if ids is not None:
+                new_data.append((docs[:maxn], labels[:maxn], ids[:maxn]))
+            else:
+                new_data.append((docs[:maxn], labels[:maxn], None))
         return new_data
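The reduce_data change guards against exactly the case the CLS builder introduces: it calls add(..., tr_ids=None, te_ids=None), and slicing None raises a TypeError, so the ids are now passed through as None instead. A minimal sketch of the failure the guard avoids:

# Slicing None, as the old `ids[:maxn]` would do for CLS data, fails:
ids = None
try:
    ids[:100]
except TypeError as err:
    print(err)  # 'NoneType' object is not subscriptable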
main.py

@@ -4,7 +4,7 @@ from os.path import expanduser
 from time import time

 from dataManager.amazonDataset import AmazonDataset
-from dataManager.multilingualDatset import MultilingualDataset
+from dataManager.multilingualDataset import MultilingualDataset
 from dataManager.multiNewsDataset import MultiNewsDataset
 from dataManager.glamiDataset import GlamiDataset
 from dataManager.gFunDataset import gFunDataset
@@ -13,6 +13,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling

 """
 TODO:
+    - [!] add support for Binary Datasets (e.g. cls)
     - [!] logging
     - add documentations sphinx
     - [!] zero-shot setup
@@ -23,12 +24,13 @@ TODO:
 """


-def get_dataset(datasetname):
+def get_dataset(datasetname, args):
     assert datasetname in [
         "multinews",
         "amazon",
         "rcv1-2",
         "glami",
+        "cls",
     ], "dataset not supported"

     RCV_DATAPATH = expanduser(
@@ -37,6 +39,8 @@ def get_dataset(datasetname):
     JRC_DATAPATH = expanduser(
         "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
     )
+    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")

     MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")

     GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
@@ -73,13 +77,21 @@ def get_dataset(datasetname):
             is_multilabel=False,
             nrows=args.nrows,
         )
+    elif datasetname == "cls":
+        dataset = gFunDataset(
+            dataset_dir=CLS_DATAPATH,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=False,
+            nrows=args.nrows,
+        )
     else:
         raise NotImplementedError
     return dataset


 def main(args):
-    dataset = get_dataset(args.dataset)
+    dataset = get_dataset(args.dataset, args)
     if (
         isinstance(dataset, MultilingualDataset)
         or isinstance(dataset, MultiNewsDataset)
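A hedged usage sketch of the new code path; the SimpleNamespace stands in for the real argparse namespace, which defines more fields than the two that get_dataset() touches here:

from types import SimpleNamespace

# Only the attributes get_dataset() actually reads are set.
args = SimpleNamespace(dataset="cls", nrows=None)
dataset = get_dataset(args.dataset, args)  # loads CLS_DATAPATH via gFunDataset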
requirements.txt

@@ -1,13 +1,13 @@
 beautifulsoup4==4.11.2
 joblib==1.2.0
-matplotlib==3.6.3
+matplotlib==3.7.1
-numpy==1.24.1
+numpy==1.24.2
+pandas==1.5.3
 Pillow==9.4.0
 requests==2.28.2
 scikit_learn==1.2.1
-scipy==1.10.0
+scipy==1.10.1
 torch==1.13.1
 torchtext==0.14.1
-torchvision==0.14.1
-tqdm==4.64.1
+tqdm==4.65.0
-transformers==4.26.0
+transformers==4.26.1