gfun_multimodal/dataManager/clsDataset.py

87 lines
2.6 KiB
Python

import sys
import os
sys.path.append(os.getcwd())
import numpy as np
import re
from dataManager.multilingualDataset import MultilingualDataset
# Root directory of the processed CLS data. load_cls() reads
# <root>/<lang>/<domain>/train.processed and .../test.processed beneath it.
CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
# Language sub-directory names iterated by load_cls().
LANGS = ["de", "en", "fr", "jp"]
# Domain sub-directory names iterated by load_cls() for every language.
DOMAINS = ["books", "dvd", "music"]
# Pattern matching a colon followed by one or more digits; process_data()
# replaces every match with `subst`, i.e. deletes it.
# NOTE(review): presumably these are per-token counts ("word:3") in the
# processed file format -- confirm against the data.
regex = r":\d+"
# Replacement text used with `regex` (empty string: matches are removed).
subst = ""
def _read_processed_lines(lang, domain, filename):
    """Return the lines of one processed CLS file, without line terminators."""
    path = os.path.join(CLS_PROCESSED_DATA_DIR, lang, domain, filename)
    # The context manager closes the handle deterministically (the previous
    # bare open(...).read() left that to the garbage collector).
    # NOTE(review): decodes explicitly as UTF-8 instead of the platform
    # default, since the corpus holds German, French and Japanese text;
    # this assumes the processed files are UTF-8 encoded -- confirm.
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read().splitlines()


def load_cls():
    """Load every language/domain split of the processed CLS data.

    For each language in ``LANGS`` and each domain in ``DOMAINS`` the files
    ``train.processed`` and ``test.processed`` under
    ``CLS_PROCESSED_DATA_DIR/<lang>/<domain>/`` are read and every line is
    parsed with ``process_data``. Progress is printed to stdout.

    Returns:
        dict: ``data[lang][domain]`` is a dict with the keys ``"train"`` and
        ``"test"``, each holding the list of values returned by
        ``process_data`` for the lines of the corresponding file.

    Raises:
        OSError: if one of the expected files cannot be opened.
    """
    data = {}
    for lang in LANGS:
        data[lang] = {}
        for domain in DOMAINS:
            print(f"lang: {lang}, domain: {domain}")
            train = _read_processed_lines(lang, domain, "train.processed")
            test = _read_processed_lines(lang, domain, "test.processed")
            print(f"train: {len(train)}, test: {len(test)}")
            data[lang][domain] = {
                "train": [process_data(t) for t in train],
                "test": [process_data(t) for t in test],
            }
    return data
def process_data(line):
    """Parse one line of a processed CLS file into a ``(text, label)`` pair.

    Every match of the module-level ``regex`` (a colon followed by digits) is
    replaced with ``subst``; the result is then split on the ``#label#:``
    marker.

    Args:
        line: One raw line read from a ``*.processed`` file.

    Returns:
        tuple: ``(text, label)`` where ``text`` is the part of the cleaned
        line before the marker and ``label`` is the one-hot list ``[1, 0]``
        when the part after the marker is ``"negative"`` (ignoring
        surrounding whitespace) and ``[0, 1]`` otherwise.

    Raises:
        ValueError: if the line does not contain exactly one ``#label#:``
            marker (raised by the tuple unpacking).
    """
    # TODO: we are adding a space after each punctuation mark (e.g., ". es ich das , langweilig lustig")
    # count/flags are passed by keyword: passing them positionally to
    # re.sub is deprecated since Python 3.13.
    result = re.sub(regex, subst, line, count=0, flags=re.MULTILINE)
    text, label = result.split("#label#:")
    # Strip before comparing so stray whitespace around the label cannot
    # silently turn a negative example into a positive one.
    label = [1, 0] if label.strip() == "negative" else [0, 1]
    return text, label
if __name__ == "__main__":
    # Script entry point: load the processed CLS files, register one
    # train/test split per language, and pickle the resulting dataset.
    print(f"datapath: {CLS_PROCESSED_DATA_DIR}")
    corpus = load_cls()
    dataset = MultilingualDataset(dataset_name="cls")

    for lang in LANGS:
        # TODO: just using book domain atm
        books = corpus[lang]["books"]
        train_pairs = books["train"]
        test_pairs = books["test"]

        # Each pair is (text, one-hot label); the labels are stacked row-wise
        # into a single array per split.
        train_texts = [text for text, _ in train_pairs]
        train_labels = np.vstack([label for _, label in train_pairs])
        test_texts = [text for text, _ in test_pairs]
        test_labels = np.vstack([label for _, label in test_pairs])

        dataset.add(
            lang=lang,
            Xtr=train_texts,
            Ytr=train_labels,
            Xte=test_texts,
            Yte=test_labels,
            tr_ids=None,
            te_ids=None,
        )

    output_path = os.path.expanduser(
        "~/datasets/cls-acl10-processed/cls-acl10-processed.pkl"
    )
    dataset.save(output_path)