Data Classes for GLAMI-1M Dataset

Andrea Pedrotti 2023-02-13 18:29:54 +01:00
parent 7ed98346a5
commit 298f31669d
2 changed files with 193 additions and 4 deletions

dataManager/glamiDataset.py (new file, 175 lines added)

@@ -0,0 +1,175 @@
import copy
import logging
import os
import zipfile
from tempfile import TemporaryFile
from typing import BinaryIO, Optional, Dict
import requests
from tqdm import tqdm
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
# From https://github.com/glami/glami-1m/blob/main/load_dataset.py
GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
# DATASET_URL = os.environ.get(
# "DATASET_URL",
# "https://zenodo.org/record/7326406/files/GLAMI-1M-dataset.zip?download=1",
# )
# EXTRACT_DIR = os.environ.get("EXTRACT_DIR", "/tmp/GLAMI-1M")
# DATASET_SUBDIR = "GLAMI-1M-dataset"
# DATASET_DIR = dataset_dir = EXTRACT_DIR + "/" + DATASET_SUBDIR
# MODEL_DIR = os.environ.get("MODEL_DIR", "/tmp/GLAMI-1M/models")
# EMBS_DIR = EXTRACT_DIR + "/embs"
# CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-visual"
# CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-textual"
# # CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-visual"
# # CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-textual"
# CLIP_EN_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-en-textual"
# GENERATED_DIR = EXTRACT_DIR + "/generated_images"
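# Illustrative sketch of a download helper in the spirit of the upstream
# load_dataset.py (which the requests/tqdm/zipfile/TemporaryFile imports
# above serve). The helper name and chunking details are assumptions; the
# URL is the Zenodo record from the commented-out config above, and the
# archive is assumed to unpack into a GLAMI-1M-dataset folder.
def download_dataset_sketch(
    url: str = "https://zenodo.org/record/7326406/files/GLAMI-1M-dataset.zip?download=1",
    extract_dir: str = os.path.dirname(GLAMI_DATAPATH),
) -> None:
    if os.path.exists(GLAMI_DATAPATH):
        return  # nothing to do: dataset already extracted
    os.makedirs(extract_dir, exist_ok=True)
    with TemporaryFile() as tmp:
        resp = requests.get(url, stream=True)
        resp.raise_for_status()
        # stream the zip to a temp file, showing chunk progress
        for chunk in tqdm(resp.iter_content(chunk_size=1 << 20), desc="GLAMI-1M"):
            tmp.write(chunk)
        tmp.seek(0)
        with zipfile.ZipFile(tmp) as zf:
            zf.extractall(extract_dir)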
COL_NAME_ITEM_ID = "item_id"
COL_NAME_IMAGE_ID = "image_id"
COL_NAME_IMAGE_FILE = "image_file"
COL_NAME_IMAGE_URL = "image_url"
COL_NAME_NAME = "name"
COL_NAME_DESCRIPTION = "description"
COL_NAME_GEO = "geo"
COL_NAME_CATEGORY = "category"
COL_NAME_CAT_NAME = "category_name"
COL_NAME_LABEL_SOURCE = "label_source"
COL_NAME_EMB_FILE = "emb_file"
COL_NAME_MASK_FILE = "mask_file"
DEFAULT_IMAGE_SIZE = (298, 228)
COUNTRY_CODE_TO_COUNTRY_NAME = {
    "cz": "Czechia",
    "sk": "Slovakia",
    "ro": "Romania",
    "gr": "Greece",
    "si": "Slovenia",
    "hu": "Hungary",
    "hr": "Croatia",
    "es": "Spain",
    "lt": "Lithuania",
    "lv": "Latvia",
    "tr": "Turkey",
    "ee": "Estonia",
    "bg": "Bulgaria",
}
COUNTRY_CODE_TO_COUNTRY_NAME_W_CC = {
    cc: name + f" ({cc})" for cc, name in COUNTRY_CODE_TO_COUNTRY_NAME.items()
}
def get_dataframe(split_type: str, dataset_dir=GLAMI_DATAPATH, nrows=None):
    assert split_type in ("train", "test")
    # pd.read_csv treats nrows=None as "read everything", so a single call
    # covers both the sampled and the full case
    df = pd.read_csv(dataset_dir + f"/GLAMI-1M-{split_type}.csv", nrows=nrows)
    df[COL_NAME_IMAGE_FILE] = (
        dataset_dir + "/images/" + df[COL_NAME_IMAGE_ID].astype(str) + ".jpg"
    )
    df[COL_NAME_DESCRIPTION] = df[COL_NAME_DESCRIPTION].fillna("")
    # fail fast if the images folder is missing or was extracted elsewhere
    assert os.path.exists(df.loc[0, COL_NAME_IMAGE_FILE])
    return df[
        [
            COL_NAME_ITEM_ID,
            COL_NAME_IMAGE_ID,
            COL_NAME_NAME,
            COL_NAME_DESCRIPTION,
            COL_NAME_GEO,
            COL_NAME_CAT_NAME,
            COL_NAME_LABEL_SOURCE,
            COL_NAME_IMAGE_FILE,
        ]
    ]
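# Illustrative usage (assumed, for a quick sanity check): read a handful of
# rows without touching the full CSV.
#   df = get_dataframe("train", nrows=5)
#   list(df.columns) -> [item_id, image_id, name, description, geo,
#                        category_name, label_source, image_file]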
class GlamiDataset:
    def __init__(self, dataset_dir, langs=None, labels=None, nrows=None):
        self.dataset_dir = dataset_dir
        self.data_langs = langs
        self.labels = labels
        self.nrows = nrows
        self.multilingual_dataset = {}

    def num_labels(self):
        return len(self.labels)

    def langs(self):
        return self.data_langs

    def get_label_binarizer(self, labels):
        mlb = LabelBinarizer()
        mlb.fit(labels)
        print(
            f"- Label binarizer initialized with the following labels:\n{mlb.classes_}"
        )
        return mlb

    def binarize_labels(self, labels):
        if hasattr(self, "mlb"):
            return self.mlb.transform(labels)
        else:
            raise ValueError("Label binarizer not initialized")

    def load_df(self, split, dataset_dir):
        return get_dataframe(split, dataset_dir=dataset_dir, nrows=self.nrows)

    def build_dataset(self):
        train_dataset = self.load_df("train", self.dataset_dir)
        test_dataset = self.load_df("test", self.dataset_dir)
        if self.data_langs is None:
            self.data_langs = train_dataset.geo.unique().tolist()
        if self.labels is None:
            self.labels = train_dataset.category_name.unique().tolist()
        self.mlb = self.get_label_binarizer(self.labels)
        # zip aligns the two groupbys; this assumes train and test cover the
        # same geos (groupby sorts keys, so matching key sets stay aligned)
        self.multilingual_dataset = {
            lang: [data_tr, data_te]
            for (lang, data_tr), (_, data_te) in zip(
                train_dataset.groupby("geo"), test_dataset.groupby("geo")
            )
            if lang in self.data_langs
        }

    def training(self):
        # TODO: tolist() or ???
        lXtr = {
            lang: (df[COL_NAME_NAME] + " " + df[COL_NAME_DESCRIPTION]).tolist()
            for lang, (df, _) in self.multilingual_dataset.items()
        }
        lYtr = {
            lang: self.binarize_labels(df[COL_NAME_CAT_NAME].tolist())
            for lang, (df, _) in self.multilingual_dataset.items()
        }
        return lXtr, lYtr

    def test(self):
        lXte = {
            lang: (df[COL_NAME_NAME] + " " + df[COL_NAME_DESCRIPTION]).tolist()
            for lang, (_, df) in self.multilingual_dataset.items()
        }
        lYte = {
            lang: self.binarize_labels(df[COL_NAME_CAT_NAME].tolist())
            for lang, (_, df) in self.multilingual_dataset.items()
        }
        return lXte, lYte
if __name__ == "__main__":
    print("Hello glamiDataset")
    dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=None)
    dataset.build_dataset()
    lXtr, lYtr = dataset.training()
    lXte, lYte = dataset.test()
    exit(0)
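A quick sketch of what GlamiDataset hands back (the "cz" key and shapes are illustrative; assumes the dataset is already extracted under GLAMI_DATAPATH):

    dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=1000)
    dataset.build_dataset()
    lXtr, lYtr = dataset.training()
    # lXtr["cz"] -> list of "name description" strings, one per item
    # lYtr["cz"] -> one-hot array of shape (n_items, dataset.num_labels())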

main.py (22 lines changed)

@@ -6,6 +6,7 @@ from time import time
 from dataManager.amazonDataset import AmazonDataset
 from dataManager.multilingualDatset import MultilingualDataset
 from dataManager.multiNewsDataset import MultiNewsDataset
+from dataManager.glamiDataset import GlamiDataset
 from evaluation.evaluate import evaluate, log_eval
 from gfun.generalizedFunnelling import GeneralizedFunnelling
@@ -20,7 +21,12 @@ TODO:
 def get_dataset(datasetname):
-    assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"
+    assert datasetname in [
+        "multinews",
+        "amazon",
+        "rcv1-2",
+        "glami",
+    ], "dataset not supported"
     RCV_DATAPATH = expanduser(
         "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
@@ -30,6 +36,8 @@ def get_dataset(datasetname):
     )
     MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
+    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
+
     if datasetname == "multinews":
         dataset = MultiNewsDataset(
             expanduser(MULTINEWS_DATAPATH),
@@ -46,6 +54,9 @@ def get_dataset(datasetname):
         dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
         if args.nrows is not None:
             dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
+    elif datasetname == "glami":
+        dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=args.nrows)
+        dataset.build_dataset()
     else:
         raise NotImplementedError
     return dataset
@@ -53,8 +64,10 @@ def get_dataset(datasetname):
 def main(args):
     dataset = get_dataset(args.dataset)
-    if isinstance(dataset, MultilingualDataset) or isinstance(
-        dataset, MultiNewsDataset
+    if (
+        isinstance(dataset, MultilingualDataset)
+        or isinstance(dataset, MultiNewsDataset)
+        or isinstance(dataset, GlamiDataset)
     ):
         lX, lY = dataset.training()
         lX_te, lY_te = dataset.test()
@@ -109,7 +122,7 @@ def main(args):
     # gfun.get_config()
     gfun.fit(lX, lY)
 
-    if args.load_trained is None:
+    if args.load_trained is None and not args.nosave:
         gfun.save(save_first_tier=True, save_meta=True)
 
     preds = gfun.transform(lX)
@@ -131,6 +144,7 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-l", "--load_trained", type=str, default=None)
     parser.add_argument("--meta", action="store_true")
+    parser.add_argument("--nosave", action="store_true")
     # Dataset parameters -------------------
     parser.add_argument("-d", "--dataset", type=str, default="multinews")
     parser.add_argument("--domains", type=str, default="all")