# gfun_multimodal/dataManager/glamiDataset.py
import copy
import logging
import os
import zipfile
from tempfile import TemporaryFile
from typing import BinaryIO, Optional, Dict
import requests
from tqdm import tqdm
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
# From https://github.com/glami/glami-1m/blob/main/load_dataset.py
# Root of the locally extracted GLAMI-1M dataset (csv splits + images/ subdir).
GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
# Download/extraction configuration from the upstream loader, kept for reference.
# DATASET_URL = os.environ.get(
# "DATASET_URL",
# "https://zenodo.org/record/7326406/files/GLAMI-1M-dataset.zip?download=1",
# )
# EXTRACT_DIR = os.environ.get("EXTRACT_DIR", "/tmp/GLAMI-1M")
# DATASET_SUBDIR = "GLAMI-1M-dataset"
# DATASET_DIR = dataset_dir = EXTRACT_DIR + "/" + DATASET_SUBDIR
# MODEL_DIR = os.environ.get("MODEL_DIR", "/tmp/GLAMI-1M/models")
# EMBS_DIR = EXTRACT_DIR + "/embs"
# CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-visual"
# CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-textual"
# # CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-visual"
# # CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-textual"
# CLIP_EN_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-en-textual"
# GENERATED_DIR = EXTRACT_DIR + "/generated_images"
# Column names shared by the csv splits and the dataframes built from them.
COL_NAME_ITEM_ID = "item_id"
COL_NAME_IMAGE_ID = "image_id"
COL_NAME_IMAGE_FILE = "image_file"  # derived column: local path to the jpg
COL_NAME_IMAGE_URL = "image_url"
COL_NAME_NAME = "name"
COL_NAME_DESCRIPTION = "description"
COL_NAME_GEO = "geo"  # country code; used here as the language identifier
COL_NAME_CATEGORY = "category"
COL_NAME_CAT_NAME = "category_name"  # classification target
COL_NAME_LABEL_SOURCE = "label_source"
COL_NAME_EMB_FILE = "emb_file"
COL_NAME_MASK_FILE = "mask_file"
# Image size in pixels; presumably (height, width) — TODO confirm against upstream.
DEFAULT_IMAGE_SIZE = (298, 228)
# Country code -> English country name for every market covered by GLAMI-1M.
COUNTRY_CODE_TO_COUNTRY_NAME = {
    "cz": "Czechia",
    "sk": "Slovakia",
    "ro": "Romania",
    "gr": "Greece",
    "si": "Slovenia",
    "hu": "Hungary",
    "hr": "Croatia",
    "es": "Spain",
    "lt": "Lithuania",
    "lv": "Latvia",
    "tr": "Turkey",
    "ee": "Estonia",
    "bg": "Bulgaria",
}
# Same mapping with the code appended, e.g. "cz" -> "Czechia (cz)".
# Fix: the original set comprehension iterated the dict directly, so each
# two-letter KEY was unpacked into (cc, name) — producing a meaningless set
# like {"z (c)", ...}. Iterating .items() restores the intended dict.
COUNTRY_CODE_TO_COUNTRY_NAME_W_CC = {
    cc: f"{name} ({cc})" for cc, name in COUNTRY_CODE_TO_COUNTRY_NAME.items()
}
def get_dataframe(split_type: str, dataset_dir=GLAMI_DATAPATH, nrows=None):
    """Load one GLAMI-1M split into a pandas DataFrame.

    Reads ``<dataset_dir>/GLAMI-1M-<split_type>.csv``, adds a local image-path
    column and normalises missing descriptions to the empty string.

    Args:
        split_type: either "train" or "test".
        dataset_dir: dataset root containing the csv files and images/ subdir.
        nrows: optional cap on the number of csv rows to read (None = all).

    Returns:
        DataFrame with columns item_id, image_id, name, description, geo,
        category_name, label_source, image_file.

    Raises:
        ValueError: if split_type is not "train"/"test".
        FileNotFoundError: if the first row's image file is missing (cheap
            sanity check that the image archive was extracted).
    """
    # Validate with an exception rather than `assert` (asserts vanish under -O).
    if split_type not in ("train", "test"):
        raise ValueError(f"split_type must be 'train' or 'test', got {split_type!r}")
    # nrows=None already means "read everything", so one read_csv call suffices.
    df = pd.read_csv(
        os.path.join(dataset_dir, f"GLAMI-1M-{split_type}.csv"), nrows=nrows
    )
    df[COL_NAME_IMAGE_FILE] = (
        dataset_dir + "/images/" + df[COL_NAME_IMAGE_ID].astype(str) + ".jpg"
    )
    df[COL_NAME_DESCRIPTION] = df[COL_NAME_DESCRIPTION].fillna("")
    # Guard against an empty csv before probing row 0.
    if len(df) and not os.path.exists(df.loc[0, COL_NAME_IMAGE_FILE]):
        raise FileNotFoundError(df.loc[0, COL_NAME_IMAGE_FILE])
    return df[
        [
            COL_NAME_ITEM_ID,
            COL_NAME_IMAGE_ID,
            COL_NAME_NAME,
            COL_NAME_DESCRIPTION,
            COL_NAME_GEO,
            COL_NAME_CAT_NAME,
            COL_NAME_LABEL_SOURCE,
            COL_NAME_IMAGE_FILE,
        ]
    ]
class GlamiDataset:
    """In-memory manager for the GLAMI-1M multilingual fashion dataset.

    Loads the train/test csv splits, groups them per country code ("geo",
    used here as the language identifier) and exposes per-language text
    inputs and binarized category labels.
    """

    def __init__(self, dataset_dir, langs=None, labels=None, nrows=None):
        self.dataset_dir = dataset_dir  # root dir containing the csv splits
        self.data_langs = langs  # None -> inferred from the train split
        self.labels = labels  # None -> inferred from the train split
        self.nrows = nrows  # optional row cap for quick experiments
        # lang -> [train_df, test_df]; populated by build_dataset()
        self.multilingual_dataset = {}

    def num_labels(self):
        """Number of distinct category labels (requires labels to be set)."""
        return len(self.labels)

    def langs(self):
        """List of language (geo) codes covered by the dataset."""
        return self.data_langs

    def get_label_binarizer(self, labels):
        """Fit and return a LabelBinarizer over the given label set."""
        mlb = LabelBinarizer()
        mlb.fit(labels)
        print(
            f"- Label binarizer initialized with the following labels:\n{mlb.classes_}"
        )
        return mlb

    def binarize_labels(self, labels):
        """One-hot encode labels; build_dataset() must have run first."""
        # Guard clause instead of if/else: fail fast when the binarizer is absent.
        if not hasattr(self, "mlb"):
            raise ValueError("Label binarizer not initialized")
        return self.mlb.transform(labels)

    def load_df(self, split, dataset_dir):
        return get_dataframe(split, dataset_dir=dataset_dir, nrows=self.nrows)

    def build_dataset(self):
        """Load both splits, fit the label binarizer, and index per language."""
        train_dataset = self.load_df("train", self.dataset_dir)
        test_dataset = self.load_df("test", self.dataset_dir)
        if self.data_langs is None:
            self.data_langs = train_dataset.geo.unique().tolist()
        if self.labels is None:
            self.labels = train_dataset.category_name.unique().tolist()
        self.mlb = self.get_label_binarizer(self.labels)
        # Group each split independently and match by key. The original
        # zip() over two groupby iterations silently mis-paired languages
        # whenever the two splits did not contain exactly the same geo values.
        train_by_lang = {lang: df for lang, df in train_dataset.groupby("geo")}
        test_by_lang = {lang: df for lang, df in test_dataset.groupby("geo")}
        self.multilingual_dataset = {
            lang: [train_by_lang[lang], test_by_lang[lang]]
            for lang in train_by_lang
            if lang in self.data_langs and lang in test_by_lang
        }

    def _split(self, idx):
        # Shared implementation for training()/test(): idx 0 = train, 1 = test.
        # Texts are "name description"; labels are binarized category names.
        lX = {
            lang: (dfs[idx]["name"] + " " + dfs[idx]["description"]).tolist()
            for lang, dfs in self.multilingual_dataset.items()
        }
        lY = {
            lang: self.binarize_labels(dfs[idx]["category_name"].tolist())
            for lang, dfs in self.multilingual_dataset.items()
        }
        return lX, lY

    def training(self):
        """Return ({lang: [texts]}, {lang: binarized labels}) for the train split."""
        return self._split(0)

    def test(self):
        """Return ({lang: [texts]}, {lang: binarized labels}) for the test split."""
        return self._split(1)
if __name__ == "__main__":
    # Smoke test: load the full dataset from disk and materialise both splits.
    print("Hello glamiDataset")
    dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=None)
    dataset.build_dataset()
    lXtr, lYtr = dataset.training()
    # Fix: the class defines test(), not testing() — the original call
    # raised AttributeError.
    lXte, lYte = dataset.test()