webis-cls unprocessed manager

This commit is contained in:
Andrea Pedrotti 2023-06-22 11:32:15 +02:00
parent de98926d00
commit 9437ccc837
2 changed files with 37 additions and 16 deletions

View File

@ -20,6 +20,10 @@ subst = ""
def load_unprocessed_cls(reduce_target_space=False): def load_unprocessed_cls(reduce_target_space=False):
data = {} data = {}
data_tr = []
data_te = []
c_tr = 0
c_te = 0
for lang in LANGS: for lang in LANGS:
data[lang] = {} data[lang] = {}
for domain in DOMAINS: for domain in DOMAINS:
@ -43,17 +47,29 @@ def load_unprocessed_cls(reduce_target_space=False):
else: else:
new_rating = 2 new_rating = 2
rating[new_rating - 1] = 1 rating[new_rating - 1] = 1
# rating = new_rating
else: else:
rating = np.zeros(5, dtype=int) rating = np.zeros(5, dtype=int)
rating[int(float(child.find("rating").text)) - 1] = 1 rating[int(float(child.find("rating").text)) - 1] = 1
# rating = new_rating
# if split == "train":
# target_data = data_tr
# current_count = len(target_data)
# c_tr = +1
# else:
# target_data = data_te
# current_count = len(target_data)
# c_te = +1
domain_data.append( domain_data.append(
# target_data.append(
{ {
"asin": child.find("asin").text "asin": child.find("asin").text
if child.find("asin") is not None if child.find("asin") is not None
else None, else None,
"category": child.find("category").text # "category": child.find("category").text
if child.find("category") is not None # if child.find("category") is not None
else None, # else None,
"category": domain,
# "rating": child.find("rating").text # "rating": child.find("rating").text
# if child.find("rating") is not None # if child.find("rating") is not None
# else None, # else None,
@ -67,6 +83,7 @@ def load_unprocessed_cls(reduce_target_space=False):
"summary": child.find("summary").text "summary": child.find("summary").text
if child.find("summary") is not None if child.find("summary") is not None
else None, else None,
"lang": lang,
} }
) )
data[lang][domain].update({split: domain_data}) data[lang][domain].update({split: domain_data})
@ -125,21 +142,25 @@ if __name__ == "__main__":
for lang in LANGS: for lang in LANGS:
# Xtr = [text["summary"] for text in data[lang]["books"]["train"]] # Xtr = [text["summary"] for text in data[lang]["books"]["train"]]
Xtr = [text["text"] for text in data[lang]["books"]["train"]] Xtr = [text["text"] for text in data[lang]["books"]["train"]]
Ytr = np.vstack([text["rating"] for text in data[lang]["books"]["train"]]) Xtr += [text["text"] for text in data[lang]["dvd"]["train"]]
Xtr += [text["text"] for text in data[lang]["music"]["train"]]
Ytr =[text["rating"] for text in data[lang]["books"]["train"]]
Ytr += [text["rating"] for text in data[lang]["dvd"]["train"]]
Ytr += [text["rating"] for text in data[lang]["music"]["train"]]
Ytr = np.vstack(Ytr)
# Xte = [text["summary"] for text in data[lang]["books"]["test"]]
Xte = [text["text"] for text in data[lang]["books"]["test"]] Xte = [text["text"] for text in data[lang]["books"]["test"]]
Yte = np.vstack([text["rating"] for text in data[lang]["books"]["test"]]) Xte += [text["text"] for text in data[lang]["dvd"]["test"]]
Xte += [text["text"] for text in data[lang]["music"]["test"]]
# for lang in LANGS:
# # TODO: just using book domain atm
# Xtr = [text[0] for text in data[lang]["books"]["train"]]
# # Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
# Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
# Xte = [text[0] for text in data[lang]["books"]["test"]] Yte = [text["rating"] for text in data[lang]["books"]["test"]]
# # Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1) Yte += [text["rating"] for text in data[lang]["dvd"]["test"]]
# Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]]) Yte += [text["rating"] for text in data[lang]["music"]["test"]]
Yte = np.vstack(Yte)
multilingualDataset.add( multilingualDataset.add(
lang=lang, lang=lang,
@ -152,6 +173,6 @@ if __name__ == "__main__":
) )
multilingualDataset.save( multilingualDataset.save(
os.path.expanduser( os.path.expanduser(
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl" "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
) )
) )

View File

@ -39,7 +39,7 @@ def get_dataset(dataset_name, args):
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset") GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
WEBIS_CLS = expanduser( WEBIS_CLS = expanduser(
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl" "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
) )
if dataset_name == "multinews": if dataset_name == "multinews":