webis-cls unprocessed manager

Andrea Pedrotti 2023-06-22 11:32:15 +02:00
parent de98926d00
commit 9437ccc837
2 changed files with 37 additions and 16 deletions

View File

@@ -20,6 +20,10 @@ subst = ""
def load_unprocessed_cls(reduce_target_space=False):
    data = {}
    data_tr = []
    data_te = []
    c_tr = 0
    c_te = 0
    for lang in LANGS:
        data[lang] = {}
        for domain in DOMAINS:
@@ -43,17 +47,29 @@ def load_unprocessed_cls(reduce_target_space=False):
                        else:
                            new_rating = 2
                        rating[new_rating - 1] = 1
                        # rating = new_rating
                    else:
                        rating = np.zeros(5, dtype=int)
                        rating[int(float(child.find("rating").text)) - 1] = 1
                        # rating = new_rating
                    # if split == "train":
                    # target_data = data_tr
                    # current_count = len(target_data)
                    # c_tr = +1
                    # else:
                    # target_data = data_te
                    # current_count = len(target_data)
                    # c_te = +1
                    domain_data.append(
                        # target_data.append(
                        {
                            "asin": child.find("asin").text
                            if child.find("asin") is not None
                            else None,
                            "category": child.find("category").text
                            if child.find("category") is not None
                            else None,
                            # "category": child.find("category").text
                            # if child.find("category") is not None
                            # else None,
                            "category": domain,
                            # "rating": child.find("rating").text
                            # if child.find("rating") is not None
                            # else None,
@@ -67,6 +83,7 @@ def load_unprocessed_cls(reduce_target_space=False):
"summary": child.find("summary").text
if child.find("summary") is not None
else None,
"lang": lang,
}
)
data[lang][domain].update({split: domain_data})
@@ -125,21 +142,25 @@ if __name__ == "__main__":
    for lang in LANGS:
        # Xtr = [text["summary"] for text in data[lang]["books"]["train"]]
        Xtr = [text["text"] for text in data[lang]["books"]["train"]]
        Ytr = np.vstack([text["rating"] for text in data[lang]["books"]["train"]])
        Xtr += [text["text"] for text in data[lang]["dvd"]["train"]]
        Xtr += [text["text"] for text in data[lang]["music"]["train"]]
        Ytr = [text["rating"] for text in data[lang]["books"]["train"]]
        Ytr += [text["rating"] for text in data[lang]["dvd"]["train"]]
        Ytr += [text["rating"] for text in data[lang]["music"]["train"]]
        Ytr = np.vstack(Ytr)
        # Xte = [text["summary"] for text in data[lang]["books"]["test"]]
        Xte = [text["text"] for text in data[lang]["books"]["test"]]
        Yte = np.vstack([text["rating"] for text in data[lang]["books"]["test"]])
        Xte += [text["text"] for text in data[lang]["dvd"]["test"]]
        Xte += [text["text"] for text in data[lang]["music"]["test"]]
    # for lang in LANGS:
    # # TODO: just using book domain atm
    # Xtr = [text[0] for text in data[lang]["books"]["train"]]
    # # Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
    # Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
    # Xte = [text[0] for text in data[lang]["books"]["test"]]
    # # Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
    # Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]])
        Yte = [text["rating"] for text in data[lang]["books"]["test"]]
        Yte += [text["rating"] for text in data[lang]["dvd"]["test"]]
        Yte += [text["rating"] for text in data[lang]["music"]["test"]]
        Yte = np.vstack(Yte)
        multilingualDataset.add(
            lang=lang,
@@ -152,6 +173,6 @@ if __name__ == "__main__":
        )
    multilingualDataset.save(
        os.path.expanduser(
            "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl"
            "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
        )
    )

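For context on the first file: the updated __main__ block now pools reviews from all three Webis-CLS-10 domains (books, dvd, music) per language before stacking the one-hot rating vectors into a single label matrix, instead of using the books domain alone. The snippet below is a minimal, self-contained sketch of that pattern, not code from the repository; the toy data dict, the one_hot_rating helper, and the "de" language key are illustrative stand-ins for the structure produced by load_unprocessed_cls().

import numpy as np

DOMAINS = ["books", "dvd", "music"]

def one_hot_rating(rating_text, n_classes=5):
    # Mirrors the rating[int(float(...)) - 1] = 1 pattern in the diff above:
    # a raw rating string such as "4.0" becomes a one-hot vector of length n_classes.
    vec = np.zeros(n_classes, dtype=int)
    vec[int(float(rating_text)) - 1] = 1
    return vec

# Toy stand-in for the data[lang][domain][split] structure built by load_unprocessed_cls().
data = {
    "de": {
        domain: {"train": [{"text": f"a {domain} review", "rating": one_hot_rating("4.0")}]}
        for domain in DOMAINS
    }
}

# Concatenate texts and labels across domains, as the new __main__ block does per language.
Xtr, Ytr = [], []
for domain in DOMAINS:
    Xtr += [r["text"] for r in data["de"][domain]["train"]]
    Ytr += [r["rating"] for r in data["de"][domain]["train"]]
Ytr = np.vstack(Ytr)
print(len(Xtr), Ytr.shape)  # -> 3 (3, 5)
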
View File

@@ -39,7 +39,7 @@ def get_dataset(dataset_name, args):
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
    WEBIS_CLS = expanduser(
        "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl"
        "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
    )
    if dataset_name == "multinews":
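
The second file repoints the WEBIS_CLS constant at the new pickle that bundles all three domains, keeping the dataset manager and its consumer in sync. Below is a rough sketch of reading that artifact back, assuming it is a plain pickle file; the repository's MultilingualDataset class may provide its own loader, since only its add() and save() calls are visible in this diff.

import os
import pickle

# Path mirroring the constant in the diff; adjust to your local setup.
WEBIS_CLS = os.path.expanduser(
    "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
)

with open(WEBIS_CLS, "rb") as f:
    dataset = pickle.load(f)  # assumed to deserialize the saved MultilingualDataset

print(type(dataset))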