import gzip
import os
import re
import warnings
from argparse import ArgumentParser
from collections import Counter

import numpy as np
from bs4 import BeautifulSoup
from sklearn.preprocessing import MultiLabelBinarizer

from plotters.distributions import plot_distribution

# TODO: AmazonDataset should be an instance of MultimodalDataset

warnings.filterwarnings("ignore", category=UserWarning, module="bs4")
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

BASEPATH = "/home/moreo/Datasets/raw"

with open("dataManager/excluded.csv", "r") as f:
    EXCLUDED = f.read().splitlines()

REGEX = re.compile(r"\s{2,}", re.MULTILINE)


def parse(dataset_name, ext="json.gz", nrows=0):
    dataset_name = dataset_name.replace(" ", "_")
    meta_path = os.path.join(BASEPATH, f"meta_{dataset_name}.{ext}")
    path = os.path.join(BASEPATH, f"{dataset_name}.{ext}")
    mapper = {"false": False, "true": True}
    data = []
    metadata = []
    _data = gzip.open(path, "r")
    _metadata = gzip.open(meta_path, "r")
    for i, (d, m) in enumerate(zip(_data, _metadata)):
        # undo HTML-escaped ampersands; `mapper` resolves the bare true/false literals
        data.append(eval(d.replace(b"&amp;", b"&"), mapper))
        metadata.append(eval(m.replace(b"&amp;", b"&"), mapper))
        if i + 1 == nrows:
            break
    return data, metadata


def get_categories(data, min_count=0):
    if data[0].get("category", None) is None:
        return [], set()
    categories = []
    for item in data:
        if item["category"] != "":
            categories.extend(item["category"])
    categories = list(filter(lambda x: x not in EXCLUDED, categories))
    # return categories, sorted(set(categories))
    return categories, _filter_counter(Counter(categories), min_count)


def _filter_counter(counter, min_count):
    return {k: v for k, v in counter.items() if v >= min_count}


def get_main_cat(data, min_count=0):
    if data[0].get("main_cat", None) is None:
        return [], set()
    main_cats = [item["main_cat"] for item in data if item["main_cat"] != ""]
    main_cats = list(filter(lambda x: x not in EXCLUDED, main_cats))
    # return main_cats, sorted(set(main_cats))
    return main_cats, _filter_counter(Counter(main_cats), min_count)


def filter_sample_with_images(metadata):
    # TODO: check whether images are really available and store them locally
    # print(f"(Pre-filter) Total items: {len(metadata)}")
    data = []
    for m in metadata:
        if "imageURL" not in m.keys():
            continue
        if len(m["imageURL"]) != 0 or len(m["imageURLHighRes"]) != 0:
            data.append(m)
    # print(f"(Post-filter) Total items: {len(data)}")
    return data


def select_description(descriptions):
    """
    Some items have multiple descriptions (len(item["description"]) > 1).
    Most of these descriptions are just empty strings, but some items do have
    several strings describing them. For now we rely on a simple heuristic:
    select the longest string and use it as the only description.
    """
    if len(descriptions) == 0:
        return [""]
    return [max(descriptions, key=len)]


def build_product_json(metadata, binarizer):
    data = []
    for item in metadata:
        if len(item["description"]) != 1:
            item["description"] = select_description(item["description"])
        product = {
            "asin": item["asin"],
            "title": item["title"],
            "description": item["description"],
            # TODO: some items have multiple descriptions (len(item["description"]) > 1)
            "cleaned_description": clean_description(
                BeautifulSoup(
                    item["title"] + ". " + item["description"][0],
                    features="html.parser",
                ).text
            ),
            # TODO: is it faster to call transform on the whole dataset?
"main_category": item["main_cat"], "categories": item["category"], "all_categories": _get_cats(item["main_cat"], item["category"]), "vect_categories": binarizer.transform( [_get_cats(item["main_cat"], item["category"])] )[0], } data.append(product) return data def _get_cats(main_cat, cats): return [main_cat] + cats def get_label_binarizer(cats): mlb = MultiLabelBinarizer() mlb.fit([cats]) return mlb def clean_description(description): description = re.sub(REGEX, " ", description) description = description.rstrip() description = description.replace("\t", "") description = description.replace("\n", " ") return description def construct_target_matrix(data): return np.stack([d["vect_categories"] for d in data], axis=0) def get_all_classes(counter_cats, counter_sub_cats): if len(counter_cats) == 0: return counter_sub_cats.keys() elif len(counter_sub_cats) == 0: return counter_cats.keys() else: return list(counter_cats.keys()) + list(counter_sub_cats.keys()) class AmazonDataset: def __init__( self, domains=["Appliances", "Automotive", "Movies and TV"], basepath="/home/moreo/Datasets/raw", min_count=10, max_labels=50, nrows=1000, ): print(f"[Init AmazonDataset]") print(f"- Domains: {domains}") self.REGEX = re.compile(r"\s{2,}", re.MULTILINE) with open("dataManager/excluded.csv", "r") as f: self.EXCLUDED = f.read().splitlines() self.basepath = basepath self.domains = self.parse_domains(domains) self.nrows = nrows self.min_count = min_count self.max_labels = max_labels self.len_data = 0 self.domain_data = self.load_data() self.labels, self.domain_labels = self.get_all_cats() self.label_binarizer = get_label_binarizer(self.labels) self.vectorized_labels = self.vecorize_labels() self.dX = self.construct_data_matrix() self.dY = self.construct_target_matrix() self.langs = ["en"] def parse_domains(self, domains): with open("amazon_categories.txt", "r") as f: all_domains = f.read().splitlines() if domains == "all": return all_domains else: assert all([d in all_domains for d in domains]), "Invalid domain name" return domains def parse(self, dataset_name, nrows, ext="json.gz"): dataset_name = dataset_name.replace(" ", "_") meta_path = os.path.join(self.basepath, f"meta_{dataset_name}.{ext}") path = os.path.join(self.basepath, f"{dataset_name}.{ext}") mapper = {"false": False, "true": True} data = [] metadata = [] _data = gzip.open(path, "r") _metadata = gzip.open(meta_path, "r") for i, (d, m) in enumerate(zip(_data, _metadata)): data.append(eval(d.replace(b"&", b"&"), mapper)) metadata.append(eval(m.replace(b"&", b"&"), mapper)) if i + 1 == nrows: break return data, metadata def load_data(self): print(f"- Loading up to {self.nrows} items per domain") domain_data = {} for domain in self.domains: _, metadata = self.parse(domain, nrows=self.nrows) metadata = filter_sample_with_images(metadata) domain_data[domain] = self.build_product_scheme(metadata) self.len_data += len(metadata) print(f"- Loaded {self.len_data} items") return domain_data def get_all_cats(self): assert len(self.domain_data) != 0, "Load data first" labels = set() domain_labels = {} for domain, data in self.domain_data.items(): _, counter_cats = self._get_counter_cats(data, self.min_count) labels.update(counter_cats.keys()) domain_labels[domain] = counter_cats print(f"- Found {len(labels)} labels") return labels, domain_labels def export_to_torch(self): pass def get_label_binarizer(self): mlb = MultiLabelBinarizer() mlb.fit([self.labels]) return mlb def vecorize_labels(self): for domain, data in self.domain_data.items(): for item in data: 
item["vect_categories"] = self.label_binarizer.transform( [item["all_categories"]] )[0] def build_product_scheme(self, metadata): data = [] for item in metadata: if len(item["description"]) != 1: _desc = self._select_description(item["description"]) else: _desc = item["description"][0] product = { "asin": item["asin"], "title": item["title"], "description": _desc, # TODO: some items have multiple descriptions (len(item["description"]) > 1)) "cleaned_text": self._clean_description( BeautifulSoup( item["title"] + ". " + _desc, features="html.parser", ).text ), # TODO: is it faster to call transform on the whole dataset? "main_category": item["main_cat"], "categories": item["category"], "all_categories": self._get_cats(item["main_cat"], item["category"]), # "vect_categories": binarizer.transform( # [_get_cats(item["main_cat"], item["category"])] # )[0], } data.append(product) return data def construct_data_matrix(self): dX = {} for domain, data in self.domain_data.items(): dX[domain] = [d["cleaned_text"] for d in data] return dX def construct_target_matrix(self): dY = {} for domain, data in self.domain_data.items(): dY[domain] = np.stack([d["vect_categories"] for d in data], axis=0) return dY def get_overall_label_matrix(self): assert hasattr(self, "label_matrices"), "Init label matrices first" return np.vstack([x for x in self.dY.values()]) def _get_counter_cats(self, data, min_count): cats = [] for item in data: cats.extend(item["all_categories"]) cats = list(filter(lambda x: x not in self.EXCLUDED, cats)) return cats, self._filter_counter(Counter(cats), min_count) def _filter_counter(self, counter, min_count): return {k: v for k, v in counter.items() if v >= min_count} def _clean_description(self, description): description = re.sub(self.REGEX, " ", description) description = description.rstrip() description = description.replace("\t", "") description = description.replace("\n", " ") return description def _get_cats(self, main_cat, cats): return [main_cat] + cats def _select_description(self, descriptions) -> str: """ Some items have multiple descriptions (len(item["description"]) > 1). Most of these descriptions are just empty strings. Some items instead actually have multiple strings describing them At the moment, we rely on a simple heuristic: select the longest string and use it the only description. """ if len(descriptions) == 0: return "" return max(descriptions, key=len) def plot_label_distribution(self): overall_mat = self.get_overall_label_matrix() plot_distribution( np.arange(len(self.labels)), np.sum(overall_mat, axis=0), title="Amazon Dataset", labels=self.labels, notes=overall_mat.shape, max_labels=args.max_labels, figsize=(10, 10), save=True, path="out", ) def plot_per_domain_label_distribution(self): for domain, matrix in self.vecorize_labels: pass def main(args): dataset = AmazonDataset( domains=args.domains, nrows=args.nrows, min_count=args.min_count, max_labels=args.max_labels, ) dataset.plot_label_distribution() exit() if __name__ == "__main__": import sys sys.path.append("/home/andreapdr/devel/gFunMultiModal/") parser = ArgumentParser() parser.add_argument("--domains", type=str, default="all") parser.add_argument("--nrows", type=int, default=10000) parser.add_argument("--min_count", type=int, default=10) parser.add_argument("--max_labels", type=int, default=50) args = parser.parse_args() main(args)