# GLAMI-1M dataset loading utilities (train/test splits, texts and binarized category labels).
import copy
|
|
import logging
|
|
import os
|
|
import zipfile
|
|
from tempfile import TemporaryFile
|
|
from typing import BinaryIO, Optional, Dict
|
|
|
|
import requests
|
|
from tqdm import tqdm
|
|
import pandas as pd
|
|
|
|
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
|
|
|
# From https://github.com/glami/glami-1m/blob/main/load_dataset.py
|
|
|
|
# Root directory of the extracted GLAMI-1M dataset. Defaults to
# ~/datasets/GLAMI-1M-dataset; set the GLAMI_DATAPATH environment variable to
# use another location without editing this file. "~" is expanded either way.
GLAMI_DATAPATH = os.path.expanduser(
    os.environ.get("GLAMI_DATAPATH", "~/datasets/GLAMI-1M-dataset")
)
|
|
|
|
# DATASET_URL = os.environ.get(
|
|
# "DATASET_URL",
|
|
# "https://zenodo.org/record/7326406/files/GLAMI-1M-dataset.zip?download=1",
|
|
# )
|
|
# EXTRACT_DIR = os.environ.get("EXTRACT_DIR", "/tmp/GLAMI-1M")
|
|
# DATASET_SUBDIR = "GLAMI-1M-dataset"
|
|
# DATASET_DIR = dataset_dir = EXTRACT_DIR + "/" + DATASET_SUBDIR
|
|
# MODEL_DIR = os.environ.get("MODEL_DIR", "/tmp/GLAMI-1M/models")
|
|
# EMBS_DIR = EXTRACT_DIR + "/embs"
|
|
# CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-visual"
|
|
# CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-textual"
|
|
# # CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-visual"
|
|
# # CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-textual"
|
|
# CLIP_EN_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-en-textual"
|
|
# GENERATED_DIR = EXTRACT_DIR + "/generated_images"
|
|
|
|
# Column names used for GLAMI-1M data frames.

# Identifier columns.
COL_NAME_ITEM_ID = "item_id"
COL_NAME_IMAGE_ID = "image_id"

# Image-related columns.
COL_NAME_IMAGE_FILE = "image_file"
COL_NAME_IMAGE_URL = "image_url"

# Text and locale columns.
COL_NAME_NAME = "name"
COL_NAME_DESCRIPTION = "description"
COL_NAME_GEO = "geo"

# Category / labelling columns.
COL_NAME_CATEGORY = "category"
COL_NAME_CAT_NAME = "category_name"
COL_NAME_LABEL_SOURCE = "label_source"

# Columns referring to auxiliary per-item files.
COL_NAME_EMB_FILE = "emb_file"
COL_NAME_MASK_FILE = "mask_file"

# Default image size as a 2-tuple (presumably width, height — confirm).
DEFAULT_IMAGE_SIZE = (298, 228)
|
|
|
|
|
|
# Two-letter country codes (presumably the values of the ``geo`` column —
# confirm) mapped to English country names.
COUNTRY_CODE_TO_COUNTRY_NAME = {
    "cz": "Czechia",
    "sk": "Slovakia",
    "ro": "Romania",
    "gr": "Greece",
    "si": "Slovenia",
    "hu": "Hungary",
    "hr": "Croatia",
    "es": "Spain",
    "lt": "Lithuania",
    "lv": "Latvia",
    "tr": "Turkey",
    "ee": "Estonia",
    "bg": "Bulgaria",
}

# Same keys, but each name carries its code as a suffix, e.g. "cz" -> "Czechia (cz)".
# ``.items()`` is required: iterating the dict directly yields only the keys.
COUNTRY_CODE_TO_COUNTRY_NAME_W_CC = {
    cc: f"{name} ({cc})" for cc, name in COUNTRY_CODE_TO_COUNTRY_NAME.items()
}
|
|
|
|
|
|
def get_dataframe(split_type: str, dataset_dir=GLAMI_DATAPATH, nrows=None):
    """Load one GLAMI-1M split as a DataFrame with a local image-path column.

    Args:
        split_type: ``"train"`` or ``"test"``; selects ``GLAMI-1M-{split_type}.csv``
            inside ``dataset_dir``.
        dataset_dir: directory holding the split CSVs and an ``images/`` folder.
            Accepts ``str`` or any path-like object.
        nrows: if given, only the first ``nrows`` rows of the CSV are read.

    Returns:
        DataFrame restricted to the item id, image id, name, description, geo,
        category name, label source and image file columns. Missing
        descriptions are replaced by ``""``; the image file column is
        ``<dataset_dir>/images/<image_id>.jpg``.

    Raises:
        ValueError: if ``split_type`` is not ``"train"`` or ``"test"``.
        FileNotFoundError: if the image file of the first row does not exist
            (e.g. the images have not been extracted).
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if split_type not in ("train", "test"):
        raise ValueError(f"split_type must be 'train' or 'test', got {split_type!r}")

    dataset_dir = os.fspath(dataset_dir)
    # read_csv treats nrows=None as "read everything", so one call covers both cases.
    df = pd.read_csv(
        os.path.join(dataset_dir, f"GLAMI-1M-{split_type}.csv"), nrows=nrows
    )

    images_dir = os.path.join(dataset_dir, "images")
    df[COL_NAME_IMAGE_FILE] = (
        images_dir + os.sep + df[COL_NAME_IMAGE_ID].astype(str) + ".jpg"
    )
    df[COL_NAME_DESCRIPTION] = df[COL_NAME_DESCRIPTION].fillna("")

    # Cheap sanity check that images are present on disk. It is skipped for an
    # empty frame (e.g. nrows=0), where there is no row to inspect.
    if not df.empty:
        first_image = df[COL_NAME_IMAGE_FILE].iloc[0]
        if not os.path.exists(first_image):
            raise FileNotFoundError(f"Image file not found: {first_image}")

    return df[
        [
            COL_NAME_ITEM_ID,
            COL_NAME_IMAGE_ID,
            COL_NAME_NAME,
            COL_NAME_DESCRIPTION,
            COL_NAME_GEO,
            COL_NAME_CAT_NAME,
            COL_NAME_LABEL_SOURCE,
            COL_NAME_IMAGE_FILE,
        ]
    ]
|
|
|
|
|
|
class GlamiDataset:
    """Per-language view over the GLAMI-1M train/test splits.

    After :meth:`build_dataset`, ``multilingual_dataset`` maps each language
    (a value of the ``geo`` column) to a ``[train_df, test_df]`` pair.
    :meth:`training` and :meth:`test` then return per-language texts and
    binarized category labels.
    """

    def __init__(self, dataset_dir, langs=None, labels=None, nrows=None):
        """
        Args:
            dataset_dir: directory containing the GLAMI-1M split CSVs and images.
            langs: ``geo`` values to keep; ``None`` keeps every value found in
                the training split (resolved in :meth:`build_dataset`).
            labels: category names to binarize; ``None`` uses every category
                found in the training split (resolved in :meth:`build_dataset`).
            nrows: optional row limit applied when reading each split's CSV.
        """
        self.dataset_dir = dataset_dir
        self.data_langs = langs
        self.labels = labels
        self.nrows = nrows
        # lang -> [train_df, test_df]; filled by build_dataset().
        self.multilingual_dataset = {}

    def num_labels(self):
        """Return the number of labels (requires ``labels`` to be set)."""
        return len(self.labels)

    def langs(self):
        """Return the selected languages (``None`` before build_dataset if not given)."""
        return self.data_langs

    def get_label_binarizer(self, labels):
        """Fit and return a ``LabelBinarizer`` on ``labels``, printing its classes."""
        mlb = LabelBinarizer()
        mlb.fit(labels)
        print(
            f"- Label binarizer initialized with the following labels:\n{mlb.classes_}"
        )
        return mlb

    def binarize_labels(self, labels):
        """Transform ``labels`` with the fitted binarizer.

        Raises:
            ValueError: if :meth:`build_dataset` has not fitted the binarizer yet.
        """
        if not hasattr(self, "mlb"):
            raise ValueError("Label binarizer not initialized")
        return self.mlb.transform(labels)

    def load_df(self, split, dataset_dir):
        """Load one split via ``get_dataframe``, honouring ``self.nrows``."""
        return get_dataframe(split, dataset_dir=dataset_dir, nrows=self.nrows)

    @staticmethod
    def _pair_by_lang(train_df, test_df, langs):
        """Pair train/test frames per ``geo`` value, matching groups by key.

        Only values listed in ``langs`` that occur in both splits are kept.
        A selected language with no test rows is skipped with a warning.
        Groups are matched by key, not by position: zipping two independent
        groupby iterations would pair different languages as soon as the
        splits' language sets differ.
        """
        test_groups = {lang: df for lang, df in test_df.groupby("geo")}
        pairs = {}
        for lang, data_tr in train_df.groupby("geo"):
            if lang not in langs:
                continue
            data_te = test_groups.get(lang)
            if data_te is None:
                logging.getLogger(__name__).warning(
                    "Language %r has no rows in the test split; skipping it.", lang
                )
                continue
            pairs[lang] = [data_tr, data_te]
        return pairs

    def build_dataset(self):
        """Load both splits, resolve langs/labels, fit the binarizer and group by language."""
        train_dataset = self.load_df("train", self.dataset_dir)
        test_dataset = self.load_df("test", self.dataset_dir)

        if self.data_langs is None:
            self.data_langs = train_dataset["geo"].unique().tolist()

        if self.labels is None:
            self.labels = train_dataset["category_name"].unique().tolist()
        self.mlb = self.get_label_binarizer(self.labels)

        self.multilingual_dataset = self._pair_by_lang(
            train_dataset, test_dataset, self.data_langs
        )

    def _texts_and_targets(self, split_index):
        """Return ``(texts, targets)`` dicts for split 0 (train) or 1 (test).

        Texts are ``name + " " + description``; targets are the binarized
        ``category_name`` values. Columns are accessed by key because ``df.name``
        can collide with DataFrame attributes.
        """
        texts = {}
        targets = {}
        for lang, frames in self.multilingual_dataset.items():
            df = frames[split_index]
            texts[lang] = (df["name"] + " " + df["description"]).tolist()
            targets[lang] = self.binarize_labels(df["category_name"].tolist())
        return texts, targets

    def training(self):
        """Return ``(lXtr, lYtr)``: per-language training texts and binarized labels."""
        return self._texts_and_targets(0)

    def test(self):
        """Return ``(lXte, lYte)``: per-language test texts and binarized labels."""
        return self._texts_and_targets(1)

    # Alias so the test accessor pairs with ``training()``.
    testing = test
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke run: build the per-language splits from the default dataset location.
    print("Hello glamiDataset")
    dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=None)
    dataset.build_dataset()
    lXtr, lYtr = dataset.training()
    # The test-split accessor is ``test()``.
    lXte, lYte = dataset.test()
|