diff --git a/dataManager/clsDataset.py b/dataManager/clsDataset.py index 12d0a09..330bf20 100644 --- a/dataManager/clsDataset.py +++ b/dataManager/clsDataset.py @@ -20,6 +20,10 @@ subst = "" def load_unprocessed_cls(reduce_target_space=False): data = {} + data_tr = [] + data_te = [] + c_tr = 0 + c_te = 0 for lang in LANGS: data[lang] = {} for domain in DOMAINS: @@ -43,17 +47,29 @@ def load_unprocessed_cls(reduce_target_space=False): else: new_rating = 2 rating[new_rating - 1] = 1 + # rating = new_rating else: rating = np.zeros(5, dtype=int) rating[int(float(child.find("rating").text)) - 1] = 1 + # rating = new_rating + # if split == "train": + # target_data = data_tr + # current_count = len(target_data) + # c_tr = +1 + # else: + # target_data = data_te + # current_count = len(target_data) + # c_te = +1 domain_data.append( + # target_data.append( { "asin": child.find("asin").text if child.find("asin") is not None else None, - "category": child.find("category").text - if child.find("category") is not None - else None, + # "category": child.find("category").text + # if child.find("category") is not None + # else None, + "category": domain, # "rating": child.find("rating").text # if child.find("rating") is not None # else None, @@ -67,6 +83,7 @@ def load_unprocessed_cls(reduce_target_space=False): "summary": child.find("summary").text if child.find("summary") is not None else None, + "lang": lang, } ) data[lang][domain].update({split: domain_data}) @@ -125,21 +142,25 @@ if __name__ == "__main__": for lang in LANGS: # Xtr = [text["summary"] for text in data[lang]["books"]["train"]] Xtr = [text["text"] for text in data[lang]["books"]["train"]] - Ytr = np.vstack([text["rating"] for text in data[lang]["books"]["train"]]) + Xtr += [text["text"] for text in data[lang]["dvd"]["train"]] + Xtr += [text["text"] for text in data[lang]["music"]["train"]] + + Ytr =[text["rating"] for text in data[lang]["books"]["train"]] + Ytr += [text["rating"] for text in data[lang]["dvd"]["train"]] + Ytr += [text["rating"] for text in data[lang]["music"]["train"]] + + Ytr = np.vstack(Ytr) - # Xte = [text["summary"] for text in data[lang]["books"]["test"]] Xte = [text["text"] for text in data[lang]["books"]["test"]] - Yte = np.vstack([text["rating"] for text in data[lang]["books"]["test"]]) + Xte += [text["text"] for text in data[lang]["dvd"]["test"]] + Xte += [text["text"] for text in data[lang]["music"]["test"]] - # for lang in LANGS: - # # TODO: just using book domain atm - # Xtr = [text[0] for text in data[lang]["books"]["train"]] - # # Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1) - # Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]]) - # Xte = [text[0] for text in data[lang]["books"]["test"]] - # # Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1) - # Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]]) + Yte = [text["rating"] for text in data[lang]["books"]["test"]] + Yte += [text["rating"] for text in data[lang]["dvd"]["test"]] + Yte += [text["rating"] for text in data[lang]["music"]["test"]] + + Yte = np.vstack(Yte) multilingualDataset.add( lang=lang, @@ -152,6 +173,6 @@ if __name__ == "__main__": ) multilingualDataset.save( os.path.expanduser( - "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl" + "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl" ) ) diff --git a/dataManager/utils.py b/dataManager/utils.py index b5ee50b..6270b5f 100644 --- a/dataManager/utils.py +++ b/dataManager/utils.py @@ -39,7 +39,7 @@ def get_dataset(dataset_name, args): GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset") WEBIS_CLS = expanduser( - "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-book.pkl" + "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl" ) if dataset_name == "multinews":