webis-unprocessed dataset
commit 9ce0001047
parent b3b7c69263
@@ -1,5 +1,6 @@
 import sys
 import os
+import xml.etree.ElementTree as ET

 sys.path.append(os.getcwd())

@@ -8,13 +9,70 @@ import re
 from dataManager.multilingualDataset import MultilingualDataset

 CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
-LANGS = ["de", "en", "fr", "jp"]
+CLS_UNPROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-unprocessed/")
+# LANGS = ["de", "en", "fr", "jp"]
+LANGS = ["de", "en", "fr"]
 DOMAINS = ["books", "dvd", "music"]

 regex = r":\d+"
 subst = ""

+def load_unprocessed_cls(reduce_target_space=False):
+    data = {}
+    for lang in LANGS:
+        data[lang] = {}
+        for domain in DOMAINS:
+            data[lang][domain] = {}
+            print(f"lang: {lang}, domain: {domain}")
+            for split in ["train", "test"]:
+                domain_data = []
+                fdir = os.path.join(
+                    CLS_UNPROCESSED_DATA_DIR, lang, domain, f"{split}.review"
+                )
+                tree = ET.parse(fdir)
+                root = tree.getroot()
+                for child in root:
+                    if reduce_target_space:
+                        rating = np.zeros(3, dtype=int)
+                        original_rating = int(float(child.find("rating").text))
+                        if original_rating < 3:
+                            new_rating = 1
+                        elif original_rating > 3:
+                            new_rating = 3
+                        else:
+                            new_rating = 2
+                        rating[new_rating - 1] = 1
+                    else:
+                        rating = np.zeros(5, dtype=int)
+                        rating[int(float(child.find("rating").text)) - 1] = 1
+                    domain_data.append(
+                        {
+                            "asin": child.find("asin").text
+                            if child.find("asin") is not None
+                            else None,
+                            "category": child.find("category").text
+                            if child.find("category") is not None
+                            else None,
+                            # "rating": child.find("rating").text
+                            # if child.find("rating") is not None
+                            # else None,
+                            "rating": rating,
+                            "title": child.find("title").text
+                            if child.find("title") is not None
+                            else None,
+                            "text": child.find("text").text
+                            if child.find("text") is not None
+                            else None,
+                            "summary": child.find("summary").text
+                            if child.find("summary") is not None
+                            else None,
+                        }
+                    )
+                data[lang][domain].update({split: domain_data})
+    return data
+
+
 def load_cls():
     data = {}
     for lang in LANGS:
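Note on the branch above: with reduce_target_space=True the five star ratings
collapse into three one-hot classes (1-2 stars -> class 1, 3 -> class 2,
4-5 stars -> class 3). A minimal standalone sketch of that mapping, runnable
outside the loader (the helper name is illustrative, not part of the codebase):

    import numpy as np

    def reduce_rating(original_rating: int) -> np.ndarray:
        # Collapse a 1-5 star rating into a 3-class one-hot vector:
        # index 0 = negative (1-2), index 1 = neutral (3), index 2 = positive (4-5).
        rating = np.zeros(3, dtype=int)
        if original_rating < 3:
            new_rating = 1
        elif original_rating > 3:
            new_rating = 3
        else:
            new_rating = 2
        rating[new_rating - 1] = 1
        return rating

    assert reduce_rating(1).tolist() == [1, 0, 0]  # negative
    assert reduce_rating(3).tolist() == [0, 1, 0]  # neutral
    assert reduce_rating(5).tolist() == [0, 0, 1]  # positive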
@@ -24,7 +82,7 @@ def load_cls():
         train = (
             open(
                 os.path.join(
-                    CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed"
+                    CLS_UNPROCESSED_DATA_DIR, lang, domain, "train.processed"
                 ),
                 "r",
             )
@@ -34,7 +92,7 @@ def load_cls():
         test = (
             open(
                 os.path.join(
-                    CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed"
+                    CLS_UNPROCESSED_DATA_DIR, lang, domain, "test.processed"
                 ),
                 "r",
             )
@@ -59,18 +117,29 @@ def process_data(line):


 if __name__ == "__main__":
-    print(f"datapath: {CLS_PROCESSED_DATA_DIR}")
-    data = load_cls()
-    multilingualDataset = MultilingualDataset(dataset_name="cls")
-    for lang in LANGS:
-        # TODO: just using book domain atm
-        Xtr = [text[0] for text in data[lang]["books"]["train"]]
-        # Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
-        Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
+    print(f"datapath: {CLS_UNPROCESSED_DATA_DIR}")
+    # data = load_cls()
+    data = load_unprocessed_cls(reduce_target_space=True)
+    multilingualDataset = MultilingualDataset(dataset_name="webis-cls-unprocessed")
+    for lang in LANGS:
+        # Xtr = [text["summary"] for text in data[lang]["books"]["train"]]
+        Xtr = [text["text"] for text in data[lang]["books"]["train"]]
+        Ytr = np.vstack([text["rating"] for text in data[lang]["books"]["train"]])

-        Xte = [text[0] for text in data[lang]["books"]["test"]]
-        # Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
-        Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]])
+        # Xte = [text["summary"] for text in data[lang]["books"]["test"]]
+        Xte = [text["text"] for text in data[lang]["books"]["test"]]
+        Yte = np.vstack([text["rating"] for text in data[lang]["books"]["test"]])
+
+        # for lang in LANGS:
+        #     # TODO: just using book domain atm
+        #     Xtr = [text[0] for text in data[lang]["books"]["train"]]
+        #     # Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
+        #     Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])

+        #     Xte = [text[0] for text in data[lang]["books"]["test"]]
+        #     # Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
+        #     Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]])

         multilingualDataset.add(
             lang=lang,
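Since each "rating" entry is a length-3 indicator vector, np.vstack turns the
per-split label lists above into an (n_docs, 3) matrix. A toy illustration:

    import numpy as np

    rows = [np.array([1, 0, 0]), np.array([0, 0, 1])]  # two hypothetical documents
    Y = np.vstack(rows)
    print(Y.shape)  # (2, 3): one row per document, one column per rating class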
@@ -82,5 +151,7 @@ if __name__ == "__main__":
             te_ids=None,
         )
     multilingualDataset.save(
-        os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
+        os.path.expanduser(
+            "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl"
+        )
     )
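The save path above ends in .pkl, so the dump can presumably be reloaded with
the standard pickle module as long as MultilingualDataset is importable; the
class's own loading helpers, if any, are not shown in this diff:

    import os
    import pickle

    # Path copied from the save() call above; how the dump is consumed
    # downstream is an assumption.
    PKL = os.path.expanduser(
        "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl"
    )
    with open(PKL, "rb") as f:
        dataset = pickle.load(f)  # a MultilingualDataset instance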
@@ -62,14 +62,29 @@ class gFunDataset:
             )
             self.mlb = self.get_label_binarizer(self.labels)

-        elif "cls" in self.dataset_dir.lower():
-            print(f"- Loading CLS dataset from {self.dataset_dir}")
+        # WEBIS-CLS (processed)
+        elif (
+            "cls" in self.dataset_dir.lower()
+            and "unprocessed" not in self.dataset_dir.lower()
+        ):
+            print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}")
             self.dataset_name = "cls"
             self.dataset, self.labels, self.data_langs = self._load_multilingual(
                 self.dataset_name, self.dataset_dir, self.nrows
             )
             self.mlb = self.get_label_binarizer(self.labels)

+        # WEBIS-CLS (unprocessed)
+        elif (
+            "cls" in self.dataset_dir.lower()
+            and "unprocessed" in self.dataset_dir.lower()
+        ):
+            print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}")
+            self.dataset_name = "cls"
+            self.dataset, self.labels, self.data_langs = self._load_multilingual(
+                self.dataset_name, self.dataset_dir, self.nrows
+            )
+            self.mlb = self.get_label_binarizer(self.labels)
         self.show_dimension()

         return
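The two new gFunDataset branches are identical apart from the log message;
which one fires depends only on substrings of the dataset directory. A small
sketch of that routing, with an illustrative function name that is not part
of the codebase:

    def route_cls_variant(dataset_dir: str) -> str:
        # Mirror the substring checks in gFunDataset.__init__: a CLS path
        # without "unprocessed" selects the processed release, one with
        # "unprocessed" selects the raw XML dump.
        path = dataset_dir.lower()
        if "cls" in path and "unprocessed" not in path:
            return "webis-cls-processed"
        elif "cls" in path:
            return "webis-cls-unprocessed"
        raise ValueError(f"not a CLS path: {dataset_dir}")

    assert route_cls_variant("~/datasets/cls-acl10-processed/") == "webis-cls-processed"
    assert route_cls_variant(
        "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl"
    ) == "webis-cls-unprocessed"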
@@ -23,6 +23,7 @@ def get_dataset(dataset_name, args):
         "rcv1-2",
         "glami",
         "cls",
+        "webis",
     ], "dataset not supported"

     RCV_DATAPATH = expanduser(
@@ -37,7 +38,9 @@ def get_dataset(dataset_name, args):

     GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

-    WEBIS_CLS = expanduser("~/dataset/cls-acl10-unprocessed")
+    WEBIS_CLS = expanduser(
+        "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl"
+    )

     if dataset_name == "multinews":
         # TODO: convert to gFunDataset
@@ -93,6 +96,15 @@ def get_dataset(dataset_name, args):
             is_multilabel=False,
             nrows=args.nrows,
         )
+
+    elif dataset_name == "webis":
+        dataset = gFunDataset(
+            dataset_dir=WEBIS_CLS,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=False,
+            nrows=args.nrows,
+        )
     else:
         raise NotImplementedError
     return dataset
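With "webis" registered in the supported-dataset list, the new pickle becomes
reachable through get_dataset. A hedged usage sketch; the import path and the
argparse namespace are assumptions, since the diff shows neither:

    from argparse import Namespace

    from dataManager.utils import get_dataset  # hypothetical module path

    args = Namespace(nrows=None)  # nrows is the only args attribute used here
    dataset = get_dataset("webis", args)
    # -> gFunDataset over ~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl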