From 298f31669d129a13062a21f1c2e8873c6bd7048f Mon Sep 17 00:00:00 2001
From: andreapdr
Date: Mon, 13 Feb 2023 18:29:54 +0100
Subject: [PATCH] Data Classes for GLAMI-1M Dataset

---
 dataManager/glamiDataset.py | 175 ++++++++++++++++++++++++++++++++++++
 main.py                     |  22 ++++-
 2 files changed, 193 insertions(+), 4 deletions(-)
 create mode 100644 dataManager/glamiDataset.py

diff --git a/dataManager/glamiDataset.py b/dataManager/glamiDataset.py
new file mode 100644
index 0000000..9f24f6d
--- /dev/null
+++ b/dataManager/glamiDataset.py
@@ -0,0 +1,175 @@
+import copy
+import logging
+import os
+import zipfile
+from tempfile import TemporaryFile
+from typing import BinaryIO, Optional, Dict
+
+import requests
+from tqdm import tqdm
+import pandas as pd
+
+from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
+
+# From https://github.com/glami/glami-1m/blob/main/load_dataset.py
+
+GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
+
+# DATASET_URL = os.environ.get(
+#     "DATASET_URL",
+#     "https://zenodo.org/record/7326406/files/GLAMI-1M-dataset.zip?download=1",
+# )
+# EXTRACT_DIR = os.environ.get("EXTRACT_DIR", "/tmp/GLAMI-1M")
+# DATASET_SUBDIR = "GLAMI-1M-dataset"
+# DATASET_DIR = dataset_dir = EXTRACT_DIR + "/" + DATASET_SUBDIR
+# MODEL_DIR = os.environ.get("MODEL_DIR", "/tmp/GLAMI-1M/models")
+# EMBS_DIR = EXTRACT_DIR + "/embs"
+# CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-visual"
+# CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-textual"
+# # CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-visual"
+# # CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-textual"
+# CLIP_EN_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-en-textual"
+# GENERATED_DIR = EXTRACT_DIR + "/generated_images"
+
+COL_NAME_ITEM_ID = "item_id"
+COL_NAME_IMAGE_ID = "image_id"
+COL_NAME_IMAGE_FILE = "image_file"
+COL_NAME_IMAGE_URL = "image_url"
+COL_NAME_NAME = "name"
+COL_NAME_DESCRIPTION = "description"
+COL_NAME_GEO = "geo"
+COL_NAME_CATEGORY = "category"
+COL_NAME_CAT_NAME = "category_name"
+COL_NAME_LABEL_SOURCE = "label_source"
+COL_NAME_EMB_FILE = "emb_file"
+COL_NAME_MASK_FILE = "mask_file"
+DEFAULT_IMAGE_SIZE = (298, 228)
+
+
+COUNTRY_CODE_TO_COUNTRY_NAME = {
+    "cz": "Czechia",
+    "sk": "Slovakia",
+    "ro": "Romania",
+    "gr": "Greece",
+    "si": "Slovenia",
+    "hu": "Hungary",
+    "hr": "Croatia",
+    "es": "Spain",
+    "lt": "Lithuania",
+    "lv": "Latvia",
+    "tr": "Turkey",
+    "ee": "Estonia",
+    "bg": "Bulgaria",
+}
+
+COUNTRY_CODE_TO_COUNTRY_NAME_W_CC = {
+    name + f" ({cc})" for cc, name in COUNTRY_CODE_TO_COUNTRY_NAME.items()
+}
+
+
+def get_dataframe(split_type: str, dataset_dir=GLAMI_DATAPATH, nrows=None):
+    assert split_type in ("train", "test")
+    if nrows is not None:
+        df = pd.read_csv(dataset_dir + f"/GLAMI-1M-{split_type}.csv", nrows=nrows)
+    else:
+        df = pd.read_csv(dataset_dir + f"/GLAMI-1M-{split_type}.csv")
+    df[COL_NAME_IMAGE_FILE] = (
+        dataset_dir + "/images/" + df[COL_NAME_IMAGE_ID].astype(str) + ".jpg"
+    )
+    df[COL_NAME_DESCRIPTION] = df[COL_NAME_DESCRIPTION].fillna("")
+    assert os.path.exists(df.loc[0, COL_NAME_IMAGE_FILE])
+    return df[
+        [
+            COL_NAME_ITEM_ID,
+            COL_NAME_IMAGE_ID,
+            COL_NAME_NAME,
+            COL_NAME_DESCRIPTION,
+            COL_NAME_GEO,
+            COL_NAME_CAT_NAME,
+            COL_NAME_LABEL_SOURCE,
+            COL_NAME_IMAGE_FILE,
+        ]
+    ]
+
+
+class GlamiDataset:
+    def __init__(self, dataset_dir, langs=None, labels=None, nrows=None):
+        self.dataset_dir = dataset_dir
+        self.data_langs = langs
+        self.labels = labels
+        self.nrows = nrows
+        self.multilingual_dataset = {}
+
+    def num_labels(self):
+        return len(self.labels)
+
+    def langs(self):
+        return self.data_langs
+
+    def get_label_binarizer(self, labels):
+        mlb = LabelBinarizer()
+        mlb.fit(labels)
+        print(
+            f"- Label binarizer initialized with the following labels:\n{mlb.classes_}"
+        )
+        return mlb
+
+    def binarize_labels(self, labels):
+        if hasattr(self, "mlb"):
+            return self.mlb.transform(labels)
+        else:
+            raise ValueError("Label binarizer not initialized")
+
+    def load_df(self, split, dataset_dir):
+        return get_dataframe(split, dataset_dir=dataset_dir, nrows=self.nrows)
+
+    def build_dataset(self):
+        train_dataset = self.load_df("train", self.dataset_dir)
+        test_dataset = self.load_df("test", self.dataset_dir)
+
+        if self.data_langs is None:
+            self.data_langs = train_dataset.geo.unique().tolist()
+
+        if self.labels is None:
+            self.labels = train_dataset.category_name.unique().tolist()
+            self.mlb = self.get_label_binarizer(self.labels)
+
+        self.multilingual_dataset = {
+            lang: [data_tr, data_te]
+            for (lang, data_tr), (_, data_te) in zip(
+                train_dataset.groupby("geo"), test_dataset.groupby("geo")
+            )
+            if lang in self.data_langs
+        }
+
+    def training(self):
+        # TODO: tolist() or ???
+        lXtr = {
+            lang: (df.name + " " + df.description).tolist()
+            for lang, (df, _) in self.multilingual_dataset.items()
+        }
+        lYtr = {
+            lang: self.binarize_labels(df.category_name.tolist())
+            for lang, (df, _) in self.multilingual_dataset.items()
+        }
+        return lXtr, lYtr
+
+    def test(self):
+        lXte = {
+            lang: (df.name + " " + df.description).tolist()
+            for lang, (_, df) in self.multilingual_dataset.items()
+        }
+        lYte = {
+            lang: self.binarize_labels(df.category_name.tolist())
+            for lang, (_, df) in self.multilingual_dataset.items()
+        }
+        return lXte, lYte
+
+
+if __name__ == "__main__":
+    print("Hello glamiDataset")
+    dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=None)
+    dataset.build_dataset()
+    lXtr, lYtr = dataset.training()
+    lXte, lYte = dataset.test()
+    exit(0)
diff --git a/main.py b/main.py
index 3d039cf..7a2bfbe 100644
--- a/main.py
+++ b/main.py
@@ -6,6 +6,7 @@ from time import time
 from dataManager.amazonDataset import AmazonDataset
 from dataManager.multilingualDatset import MultilingualDataset
 from dataManager.multiNewsDataset import MultiNewsDataset
+from dataManager.glamiDataset import GlamiDataset
 from evaluation.evaluate import evaluate, log_eval
 from gfun.generalizedFunnelling import GeneralizedFunnelling
 
@@ -20,7 +21,12 @@ TODO:
 
 
 def get_dataset(datasetname):
-    assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"
+    assert datasetname in [
+        "multinews",
+        "amazon",
+        "rcv1-2",
+        "glami",
+    ], "dataset not supported"
 
     RCV_DATAPATH = expanduser(
         "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
@@ -30,6 +36,8 @@ def get_dataset(datasetname):
     )
     MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
 
+    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
+
     if datasetname == "multinews":
         dataset = MultiNewsDataset(
             expanduser(MULTINEWS_DATAPATH),
@@ -46,6 +54,9 @@ def get_dataset(datasetname):
         dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
         if args.nrows is not None:
             dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
+    elif datasetname == "glami":
+        dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=args.nrows)
+        dataset.build_dataset()
     else:
         raise NotImplementedError
     return dataset
@@ -53,8 +64,10 @@ def get_dataset(datasetname):
 
 def main(args):
     dataset = get_dataset(args.dataset)
-    if isinstance(dataset, MultilingualDataset) or isinstance(
-        dataset, MultiNewsDataset
+    if (
+        isinstance(dataset, MultilingualDataset)
+        or isinstance(dataset, MultiNewsDataset)
+        or isinstance(dataset, GlamiDataset)
     ):
         lX, lY = dataset.training()
         lX_te, lY_te = dataset.test()
@@ -109,7 +122,7 @@ def main(args):
     # gfun.get_config()
     gfun.fit(lX, lY)
 
-    if args.load_trained is None:
+    if args.load_trained is None and not args.nosave:
         gfun.save(save_first_tier=True, save_meta=True)
 
     preds = gfun.transform(lX)
@@ -131,6 +144,7 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-l", "--load_trained", type=str, default=None)
    parser.add_argument("--meta", action="store_true")
+    parser.add_argument("--nosave", action="store_true")
     # Dataset parameters -------------------
     parser.add_argument("-d", "--dataset", type=str, default="multinews")
     parser.add_argument("--domains", type=str, default="all")
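
---
Usage sketch: a minimal smoke test of the new class, assuming the GLAMI-1M
CSVs and images have been extracted under ~/datasets/GLAMI-1M-dataset; the
nrows value is arbitrary and only keeps the test fast.

    from dataManager.glamiDataset import GLAMI_DATAPATH, GlamiDataset

    # build_dataset() loads the train/test CSVs, infers the language ("geo")
    # and category label sets, and fits the label binarizer
    dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=1000)
    dataset.build_dataset()

    # training()/test() return {geo: ["name description", ...]} for X
    # and {geo: one-hot category matrix} for Y
    lXtr, lYtr = dataset.training()
    lXte, lYte = dataset.test()
    print(dataset.langs(), dataset.num_labels())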