# gfun_multimodal/dataManager/clsDataset.py
import sys
import os
import xml.etree.ElementTree as ET
sys.path.append(os.getcwd())
import numpy as np
import re
from dataManager.multilingualDataset import MultilingualDataset
# Root directories of the CLS-ACL10 (Webis cross-lingual sentiment) corpus.
CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
CLS_UNPROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-unprocessed/")
# LANGS = ["de", "en", "fr", "jp"]
# Languages and product domains to load (Japanese is currently excluded).
LANGS = ["de", "en", "fr"]
DOMAINS = ["books", "dvd", "music"]
# Pattern/replacement used by process_data() to strip ":<count>" token suffixes.
regex = r":\d+"
subst = ""
def _encode_rating(rating_text, reduce_target_space):
    """Turn a raw star-rating string (e.g. ``"4.0"``) into a one-hot vector.

    With ``reduce_target_space`` the 5-star scale collapses into three
    classes: negative (< 3 stars), neutral (== 3) and positive (> 3),
    yielding a length-3 one-hot vector; otherwise a length-5 one-hot of the
    original rating is returned.
    """
    original_rating = int(float(rating_text))
    if reduce_target_space:
        if original_rating < 3:
            new_rating = 1
        elif original_rating > 3:
            new_rating = 3
        else:
            new_rating = 2
        one_hot = np.zeros(3, dtype=int)
        one_hot[new_rating - 1] = 1
    else:
        one_hot = np.zeros(5, dtype=int)
        one_hot[original_rating - 1] = 1
    return one_hot


def load_unprocessed_cls(reduce_target_space=False):
    """Load the raw (unprocessed) CLS-ACL10 reviews from their XML files.

    Parameters
    ----------
    reduce_target_space : bool
        If True, collapse the 5-star ratings into 3 sentiment classes
        (see :func:`_encode_rating`).

    Returns
    -------
    dict
        Nested as ``data[lang][domain][split] -> list of review dicts`` with
        keys ``asin``, ``category``, ``rating`` (one-hot numpy array),
        ``title``, ``text``, ``summary`` and ``lang``.
    """
    data = {}
    for lang in LANGS:
        data[lang] = {}
        for domain in DOMAINS:
            data[lang][domain] = {}
            print(f"lang: {lang}, domain: {domain}")
            for split in ["train", "test"]:
                fdir = os.path.join(
                    CLS_UNPROCESSED_DATA_DIR, lang, domain, f"{split}.review"
                )
                root = ET.parse(fdir).getroot()

                def _child_text(node, tag):
                    # Optional fields may be absent in the XML; map to None.
                    found = node.find(tag)
                    return found.text if found is not None else None

                domain_data = []
                for child in root:
                    # NOTE(review): <rating> is assumed always present — a
                    # missing tag raises AttributeError (same as original).
                    rating = _encode_rating(
                        child.find("rating").text, reduce_target_space
                    )
                    domain_data.append(
                        {
                            "asin": _child_text(child, "asin"),
                            "category": domain,
                            "rating": rating,
                            "title": _child_text(child, "title"),
                            "text": _child_text(child, "text"),
                            "summary": _child_text(child, "summary"),
                            "lang": lang,
                        }
                    )
                data[lang][domain][split] = domain_data
    return data
def load_cls():
    """Load the pre-processed CLS-ACL10 corpus (one review per line).

    Returns
    -------
    dict
        Nested as ``data[lang][domain][split] -> list of (text, label)``
        pairs as produced by :func:`process_data`.
    """
    data = {}
    for lang in LANGS:
        data[lang] = {}
        for domain in DOMAINS:
            print(f"lang: {lang}, domain: {domain}")
            splits = {}
            for split in ("train", "test"):
                path = os.path.join(
                    CLS_UNPROCESSED_DATA_DIR, lang, domain, f"{split}.processed"
                )
                # `with` guarantees the handle is closed; the original leaked
                # it via a bare open(...).read().
                with open(path, "r") as fin:
                    lines = fin.read().splitlines()
                splits[split] = [process_data(line) for line in lines]
            print(
                f"train: {len(splits['train'])}, test: {len(splits['test'])}"
            )
            data[lang][domain] = splits
    return data
def process_data(line):
    """Split one pre-processed review line into (text, one-hot label).

    Each line looks like ``"tok:3 tok:1 ... #label#:negative"``: tokens carry
    a ``:<count>`` suffix and the sentiment label follows the ``#label#:``
    marker. The count suffixes are stripped and the label is one-hot encoded
    as ``[1, 0]`` (negative) / ``[0, 1]`` (otherwise).

    # TODO: we are adding a space after each punctuation mark
    # (e.g., ". es ich das , langweilig lustig")
    """
    # Keyword `flags=` — passing flags positionally to re.sub is deprecated.
    result = re.sub(r":\d+", "", line, flags=re.MULTILINE)
    text, label = result.split("#label#:")
    one_hot_label = [1, 0] if label == "negative" else [0, 1]
    return text, one_hot_label
if __name__ == "__main__":
    print(f"datapath: {CLS_UNPROCESSED_DATA_DIR}")
    # data = load_cls()
    data = load_unprocessed_cls(reduce_target_space=True)
    multilingual_dataset = MultilingualDataset(dataset_name="webis-cls-unprocessed")
    for lang in LANGS:
        # Concatenate the product domains (in DOMAINS order) into a single
        # per-language train/test split, instead of hard-coding each domain.
        Xtr, Ytr, Xte, Yte = [], [], [], []
        for domain in DOMAINS:
            Xtr += [review["text"] for review in data[lang][domain]["train"]]
            Ytr += [review["rating"] for review in data[lang][domain]["train"]]
            Xte += [review["text"] for review in data[lang][domain]["test"]]
            Yte += [review["rating"] for review in data[lang][domain]["test"]]
        multilingual_dataset.add(
            lang=lang,
            Xtr=Xtr,
            Ytr=np.vstack(Ytr),
            Xte=Xte,
            Yte=np.vstack(Yte),
            tr_ids=None,
            te_ids=None,
        )
    multilingual_dataset.save(
        os.path.expanduser(
            "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
        )
    )