# gfun_multimodal/dataManager/glamiDataset.py
#
# Source-viewer header: 194 lines, 6.0 KiB, Python

import copy
import logging
import os
import zipfile
from tempfile import TemporaryFile
from typing import BinaryIO, Optional, Dict
import requests
from tqdm import tqdm
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
# Constants adapted from https://github.com/glami/glami-1m/blob/main/load_dataset.py

# Root directory of the extracted GLAMI-1M dataset. get_dataframe() reads
# GLAMI-1M-{train,test}.csv and images/<image_id>.jpg from below it.
GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")

# NOTE: upstream download/extraction settings, kept for reference only.
# DATASET_URL = os.environ.get(
# "DATASET_URL",
# "https://zenodo.org/record/7326406/files/GLAMI-1M-dataset.zip?download=1",
# )
# EXTRACT_DIR = os.environ.get("EXTRACT_DIR", "/tmp/GLAMI-1M")
# DATASET_SUBDIR = "GLAMI-1M-dataset"
# DATASET_DIR = dataset_dir = EXTRACT_DIR + "/" + DATASET_SUBDIR
# MODEL_DIR = os.environ.get("MODEL_DIR", "/tmp/GLAMI-1M/models")
# EMBS_DIR = EXTRACT_DIR + "/embs"
# CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-visual"
# CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-textual"
# # CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-visual"
# # CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-textual"
# CLIP_EN_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-en-textual"
# GENERATED_DIR = EXTRACT_DIR + "/generated_images"

# Column names of the GLAMI-1M CSV files (plus derived columns).
COL_NAME_ITEM_ID = "item_id"
COL_NAME_IMAGE_ID = "image_id"
COL_NAME_IMAGE_FILE = "image_file"  # derived: full path of the item's image
COL_NAME_IMAGE_URL = "image_url"
COL_NAME_NAME = "name"
COL_NAME_DESCRIPTION = "description"
COL_NAME_GEO = "geo"  # country code; used as the language key by GlamiDataset
COL_NAME_CATEGORY = "category"
COL_NAME_CAT_NAME = "category_name"
COL_NAME_LABEL_SOURCE = "label_source"
COL_NAME_EMB_FILE = "emb_file"
COL_NAME_MASK_FILE = "mask_file"

# NOTE(review): axis order (width, height) vs (height, width) is not
# established in this module -- confirm before use.
DEFAULT_IMAGE_SIZE = (298, 228)

# Country code (the values of the "geo" column) -> country name.
COUNTRY_CODE_TO_COUNTRY_NAME = {
    "cz": "Czechia",
    "sk": "Slovakia",
    "ro": "Romania",
    "gr": "Greece",
    "si": "Slovenia",
    "hu": "Hungary",
    "hr": "Croatia",
    "es": "Spain",
    "lt": "Lithuania",
    "lv": "Latvia",
    "tr": "Turkey",
    "ee": "Estonia",
    "bg": "Bulgaria",
}

# Country code -> "Country name (cc)", e.g. "cz" -> "Czechia (cz)".
# Iterate .items(): iterating the dict directly yields the two-letter keys,
# which would be unpacked character by character.
COUNTRY_CODE_TO_COUNTRY_NAME_W_CC = {
    cc: f"{name} ({cc})" for cc, name in COUNTRY_CODE_TO_COUNTRY_NAME.items()
}
def get_dataframe(split_type: str, dataset_dir=GLAMI_DATAPATH, nrows=None):
    """Load one GLAMI-1M split as a DataFrame with resolved image paths.

    Reads ``<dataset_dir>/GLAMI-1M-<split_type>.csv``. It then adds an
    ``image_file`` column pointing at ``<dataset_dir>/images/<image_id>.jpg``,
    replaces missing descriptions with ``""``, and keeps only the columns
    used downstream.

    Args:
        split_type: Either ``"train"`` or ``"test"``.
        dataset_dir: Root directory of the extracted dataset.
        nrows: If not None, read only the first ``nrows`` rows of the CSV.

    Returns:
        DataFrame with columns item_id, image_id, name, description, geo,
        category_name, label_source, image_file (in that order).

    Raises:
        ValueError: If ``split_type`` is not ``"train"`` or ``"test"``.
        FileNotFoundError: If the image file of the first row does not exist
            (only that one file is checked, as a cheap sanity check).
    """
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if split_type not in ("train", "test"):
        raise ValueError(
            f"split_type must be 'train' or 'test', got {split_type!r}"
        )

    csv_path = os.path.join(dataset_dir, f"GLAMI-1M-{split_type}.csv")
    # read_csv treats nrows=None as "read everything", so no branch is needed.
    df = pd.read_csv(csv_path, nrows=nrows)

    images_dir = os.path.join(dataset_dir, "images")
    df[COL_NAME_IMAGE_FILE] = (
        images_dir + "/" + df[COL_NAME_IMAGE_ID].astype(str) + ".jpg"
    )
    df[COL_NAME_DESCRIPTION] = df[COL_NAME_DESCRIPTION].fillna("")

    # Sanity check on the first row only. iloc is position-based, so it does
    # not depend on index labels; an empty split has nothing to check.
    if not df.empty:
        first_image = df[COL_NAME_IMAGE_FILE].iloc[0]
        if not os.path.exists(first_image):
            raise FileNotFoundError(f"Image file not found: {first_image}")

    return df[
        [
            COL_NAME_ITEM_ID,
            COL_NAME_IMAGE_ID,
            COL_NAME_NAME,
            COL_NAME_DESCRIPTION,
            COL_NAME_GEO,
            COL_NAME_CAT_NAME,
            COL_NAME_LABEL_SOURCE,
            COL_NAME_IMAGE_FILE,
        ]
    ]
class GlamiDataset:
    """Multilingual, multimodal (text + image path) view of GLAMI-1M.

    After :meth:`build_dataset`, ``multilingual_dataset`` maps each selected
    ``geo`` value (used as the language key) to ``[train_df, test_df]``.
    :meth:`training` and :meth:`test` turn those frames into per-language
    inputs and binarized labels.

    Args:
        dataset_dir: Root directory of the extracted dataset, forwarded to
            ``get_dataframe``.
        langs: ``geo`` values to keep. ``None`` keeps every value seen in
            the train split.
        labels: Category names used to fit the label binarizer. ``None``
            uses every ``category_name`` seen in the train split.
        nrows: Optional row limit forwarded to ``get_dataframe``.
    """

    def __init__(self, dataset_dir, langs=None, labels=None, nrows=None):
        self.dataset_dir = dataset_dir
        self.data_langs = langs
        self.labels = labels
        self.nrows = nrows
        # lang (geo value) -> [train_df, test_df]; filled in by build_dataset().
        self.multilingual_dataset = {}
        # Design note (TODO): a multimodal layout such as
        #     self.multilingual_multimodal_dataset = {
        #         lang: {"text": {txt_data}, "image": {img_data}},
        #     }
        # would require changing both the training code (e.g. the vectorizer
        # should read "text") and the multilingual unimodal dataset (to
        # include the "text" field only), and would be a pain when
        # splitting/shuffling the datasets. A simpler alternative is
        #     self.ml_mm_dataset = {"lang": (txt_data, img_data)}
        # but then the unimodal dataset should also expose a
        # "lang": (txt_data, _) value.

    def num_labels(self):
        """Return the number of labels (``self.labels`` must be set)."""
        return len(self.labels)

    def langs(self):
        """Return the selected languages (``None`` before build_dataset if not given)."""
        return self.data_langs

    def get_label_binarizer(self, labels):
        """Fit and return a ``LabelBinarizer`` on ``labels``."""
        mlb = LabelBinarizer()
        mlb.fit(labels)
        print(f"- Label binarizer initialized with {len(mlb.classes_)} labels")
        return mlb

    def binarize_labels(self, labels):
        """Transform ``labels`` with the fitted binarizer.

        Raises:
            ValueError: If build_dataset() has not fitted ``self.mlb`` yet.
        """
        if hasattr(self, "mlb"):
            return self.mlb.transform(labels)
        raise ValueError("Label binarizer not initialized")

    def load_df(self, split, dataset_dir):
        """Load ``split`` ("train"/"test") via get_dataframe, honouring ``self.nrows``."""
        return get_dataframe(split, dataset_dir=dataset_dir, nrows=self.nrows)

    @staticmethod
    def _pair_splits_by_lang(train_df, test_df, langs):
        """Group both splits by ``geo`` and pair the groups by key.

        Returns ``{lang: [train_rows, test_rows]}`` for every ``lang`` in
        ``langs`` that occurs in both splits, ordered like the train groupby
        (sorted keys). Languages with train rows but no test rows are
        skipped with a warning.
        """
        # Pair by key, not by position: zipping two groupby iterators silently
        # misaligns languages when the splits contain different geo values
        # (e.g. when nrows truncates them differently).
        test_groups = {lang: group for lang, group in test_df.groupby("geo")}
        paired = {}
        for lang, train_group in train_df.groupby("geo"):
            if lang not in langs:
                continue
            test_group = test_groups.get(lang)
            if test_group is None:
                logging.getLogger(__name__).warning(
                    "Language %r has train rows but no test rows; skipping it",
                    lang,
                )
                continue
            paired[lang] = [train_group, test_group]
        return paired

    def build_dataset(self):
        """Load both splits, fit the label binarizer and group rows per language.

        Missing ``langs``/``labels`` are filled in from the train split.
        """
        train_dataset = self.load_df("train", self.dataset_dir)
        test_dataset = self.load_df("test", self.dataset_dir)
        if self.data_langs is None:
            self.data_langs = train_dataset["geo"].unique().tolist()
        if self.labels is None:
            self.labels = train_dataset["category_name"].unique().tolist()
        self.mlb = self.get_label_binarizer(self.labels)
        self.multilingual_dataset = self._pair_splits_by_lang(
            train_dataset, test_dataset, self.data_langs
        )

    def _split_xy(self, split_idx):
        """Build ``(lX, lY)`` from split ``split_idx`` (0 = train, 1 = test).

        ``lX[lang]`` is ``(texts, image_files)``, where each text is
        ``name + " " + description``. ``lY[lang]`` is the binarized
        ``category_name`` column.
        """
        # TODO: original note "tolist() or ???" -- are lists the right container?
        lX, lY = {}, {}
        for lang, splits in self.multilingual_dataset.items():
            df = splits[split_idx]
            # Bracket access: pandas resolves df.<col> only when no attribute
            # of that name exists, so df.name is fragile.
            texts = (df["name"] + " " + df["description"]).tolist()
            lX[lang] = (texts, df["image_file"].tolist())
            lY[lang] = self.binarize_labels(df["category_name"].tolist())
        return lX, lY

    def training(self):
        """Return ``(lXtr, lYtr)`` for the train split (see _split_xy)."""
        return self._split_xy(0)

    def test(self):
        """Return ``(lXte, lYte)`` for the test split (see _split_xy)."""
        return self._split_xy(1)
if __name__ == "__main__":
    # Smoke run: build the dataset and materialize both splits.
    print("Hello glamiDataset")
    dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=None)
    dataset.build_dataset()
    lXtr, lYtr = dataset.training()
    # GlamiDataset exposes the test split through test(); there is no testing().
    lXte, lYte = dataset.test()
    # SystemExit instead of the site-provided exit() helper.
    raise SystemExit(0)