Data Classes for GLAMI-1M Dataset
This commit is contained in:
parent
7ed98346a5
commit
298f31669d
|
@ -0,0 +1,175 @@
|
|||
import copy
|
||||
import logging
|
||||
import os
|
||||
import zipfile
|
||||
from tempfile import TemporaryFile
|
||||
from typing import BinaryIO, Optional, Dict
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
|
||||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
||||
|
||||
# From https://github.com/glami/glami-1m/blob/main/load_dataset.py
|
||||
|
||||
GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
|
||||
|
||||
# DATASET_URL = os.environ.get(
|
||||
# "DATASET_URL",
|
||||
# "https://zenodo.org/record/7326406/files/GLAMI-1M-dataset.zip?download=1",
|
||||
# )
|
||||
# EXTRACT_DIR = os.environ.get("EXTRACT_DIR", "/tmp/GLAMI-1M")
|
||||
# DATASET_SUBDIR = "GLAMI-1M-dataset"
|
||||
# DATASET_DIR = dataset_dir = EXTRACT_DIR + "/" + DATASET_SUBDIR
|
||||
# MODEL_DIR = os.environ.get("MODEL_DIR", "/tmp/GLAMI-1M/models")
|
||||
# EMBS_DIR = EXTRACT_DIR + "/embs"
|
||||
# CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-visual"
|
||||
# CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-textual"
|
||||
# # CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-visual"
|
||||
# # CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-textual"
|
||||
# CLIP_EN_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-en-textual"
|
||||
# GENERATED_DIR = EXTRACT_DIR + "/generated_images"
|
||||
|
||||
COL_NAME_ITEM_ID = "item_id"
|
||||
COL_NAME_IMAGE_ID = "image_id"
|
||||
COL_NAME_IMAGE_FILE = "image_file"
|
||||
COL_NAME_IMAGE_URL = "image_url"
|
||||
COL_NAME_NAME = "name"
|
||||
COL_NAME_DESCRIPTION = "description"
|
||||
COL_NAME_GEO = "geo"
|
||||
COL_NAME_CATEGORY = "category"
|
||||
COL_NAME_CAT_NAME = "category_name"
|
||||
COL_NAME_LABEL_SOURCE = "label_source"
|
||||
COL_NAME_EMB_FILE = "emb_file"
|
||||
COL_NAME_MASK_FILE = "mask_file"
|
||||
DEFAULT_IMAGE_SIZE = (298, 228)
|
||||
|
||||
|
||||
COUNTRY_CODE_TO_COUNTRY_NAME = {
|
||||
"cz": "Czechia",
|
||||
"sk": "Slovakia",
|
||||
"ro": "Romania",
|
||||
"gr": "Greece",
|
||||
"si": "Slovenia",
|
||||
"hu": "Hungary",
|
||||
"hr": "Croatia",
|
||||
"es": "Spain",
|
||||
"lt": "Lithuania",
|
||||
"lv": "Latvia",
|
||||
"tr": "Turkey",
|
||||
"ee": "Estonia",
|
||||
"bg": "Bulgaria",
|
||||
}
|
||||
|
||||
COUNTRY_CODE_TO_COUNTRY_NAME_W_CC = {
|
||||
name + f" ({cc})" for cc, name in COUNTRY_CODE_TO_COUNTRY_NAME
|
||||
}
|
||||
|
||||
|
||||
def get_dataframe(split_type: str, dataset_dir=GLAMI_DATAPATH, nrows=None):
|
||||
assert split_type in ("train", "test")
|
||||
if nrows is not None:
|
||||
df = pd.read_csv(dataset_dir + f"/GLAMI-1M-{split_type}.csv", nrows=nrows)
|
||||
else:
|
||||
df = pd.read_csv(dataset_dir + f"/GLAMI-1M-{split_type}.csv")
|
||||
df[COL_NAME_IMAGE_FILE] = (
|
||||
dataset_dir + "/images/" + df[COL_NAME_IMAGE_ID].astype(str) + ".jpg"
|
||||
)
|
||||
df[COL_NAME_DESCRIPTION] = df[COL_NAME_DESCRIPTION].fillna("")
|
||||
assert os.path.exists(df.loc[0, COL_NAME_IMAGE_FILE])
|
||||
return df[
|
||||
[
|
||||
COL_NAME_ITEM_ID,
|
||||
COL_NAME_IMAGE_ID,
|
||||
COL_NAME_NAME,
|
||||
COL_NAME_DESCRIPTION,
|
||||
COL_NAME_GEO,
|
||||
COL_NAME_CAT_NAME,
|
||||
COL_NAME_LABEL_SOURCE,
|
||||
COL_NAME_IMAGE_FILE,
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class GlamiDataset:
|
||||
def __init__(self, dataset_dir, langs=None, labels=None, nrows=None):
|
||||
self.dataset_dir = dataset_dir
|
||||
self.data_langs = langs
|
||||
self.labels = labels
|
||||
self.nrows = nrows
|
||||
self.multilingual_dataset = {}
|
||||
|
||||
def num_labels(self):
|
||||
return len(self.labels)
|
||||
|
||||
def langs(self):
|
||||
return self.data_langs
|
||||
|
||||
def get_label_binarizer(self, labels):
|
||||
mlb = LabelBinarizer()
|
||||
mlb.fit(labels)
|
||||
print(
|
||||
f"- Label binarizer initialized with the following labels:\n{mlb.classes_}"
|
||||
)
|
||||
return mlb
|
||||
|
||||
def binarize_labels(self, labels):
|
||||
if hasattr(self, "mlb"):
|
||||
return self.mlb.transform(labels)
|
||||
else:
|
||||
raise ValueError("Label binarizer not initialized")
|
||||
|
||||
def load_df(self, split, dataset_dir):
|
||||
return get_dataframe(split, dataset_dir=dataset_dir, nrows=self.nrows)
|
||||
|
||||
def build_dataset(self):
|
||||
train_dataset = self.load_df("train", self.dataset_dir)
|
||||
test_dataset = self.load_df("test", self.dataset_dir)
|
||||
|
||||
if self.data_langs is None:
|
||||
self.data_langs = train_dataset.geo.unique().tolist()
|
||||
|
||||
if self.labels is None:
|
||||
self.labels = train_dataset.category_name.unique().tolist()
|
||||
self.mlb = self.get_label_binarizer(self.labels)
|
||||
|
||||
self.multilingual_dataset = {
|
||||
lang: [data_tr, data_te]
|
||||
for (lang, data_tr), (_, data_te) in zip(
|
||||
train_dataset.groupby("geo"), test_dataset.groupby("geo")
|
||||
)
|
||||
if lang in self.data_langs
|
||||
}
|
||||
|
||||
def training(self):
|
||||
# TODO: tolist() or ???
|
||||
lXtr = {
|
||||
lang: (df.name + " " + df.description).tolist()
|
||||
for lang, (df, _) in self.multilingual_dataset.items()
|
||||
}
|
||||
lYtr = {
|
||||
lang: self.binarize_labels(df.category_name.tolist())
|
||||
for lang, (df, _) in self.multilingual_dataset.items()
|
||||
}
|
||||
return lXtr, lYtr
|
||||
|
||||
def test(self):
|
||||
lXte = {
|
||||
lang: (df.name + " " + df.description).tolist()
|
||||
for lang, (_, df) in self.multilingual_dataset.items()
|
||||
}
|
||||
lYte = {
|
||||
lang: self.binarize_labels(df.category_name.tolist())
|
||||
for lang, (_, df) in self.multilingual_dataset.items()
|
||||
}
|
||||
return lXte, lYte
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Hello glamiDataset")
|
||||
dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=None)
|
||||
dataset.build_dataset()
|
||||
lXtr, lYtr = dataset.training()
|
||||
lXte, lYte = dataset.testing()
|
||||
exit(0)
|
22
main.py
22
main.py
|
@ -6,6 +6,7 @@ from time import time
|
|||
from dataManager.amazonDataset import AmazonDataset
|
||||
from dataManager.multilingualDatset import MultilingualDataset
|
||||
from dataManager.multiNewsDataset import MultiNewsDataset
|
||||
from dataManager.glamiDataset import GlamiDataset
|
||||
from evaluation.evaluate import evaluate, log_eval
|
||||
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||
|
||||
|
@ -20,7 +21,12 @@ TODO:
|
|||
|
||||
|
||||
def get_dataset(datasetname):
|
||||
assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"
|
||||
assert datasetname in [
|
||||
"multinews",
|
||||
"amazon",
|
||||
"rcv1-2",
|
||||
"glami",
|
||||
], "dataset not supported"
|
||||
|
||||
RCV_DATAPATH = expanduser(
|
||||
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
||||
|
@ -30,6 +36,8 @@ def get_dataset(datasetname):
|
|||
)
|
||||
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
|
||||
|
||||
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
|
||||
|
||||
if datasetname == "multinews":
|
||||
dataset = MultiNewsDataset(
|
||||
expanduser(MULTINEWS_DATAPATH),
|
||||
|
@ -46,6 +54,9 @@ def get_dataset(datasetname):
|
|||
dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
|
||||
if args.nrows is not None:
|
||||
dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
|
||||
elif datasetname == "glami":
|
||||
dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=args.nrows)
|
||||
dataset.build_dataset()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return dataset
|
||||
|
@ -53,8 +64,10 @@ def get_dataset(datasetname):
|
|||
|
||||
def main(args):
|
||||
dataset = get_dataset(args.dataset)
|
||||
if isinstance(dataset, MultilingualDataset) or isinstance(
|
||||
dataset, MultiNewsDataset
|
||||
if (
|
||||
isinstance(dataset, MultilingualDataset)
|
||||
or isinstance(dataset, MultiNewsDataset)
|
||||
or isinstance(dataset, GlamiDataset)
|
||||
):
|
||||
lX, lY = dataset.training()
|
||||
lX_te, lY_te = dataset.test()
|
||||
|
@ -109,7 +122,7 @@ def main(args):
|
|||
# gfun.get_config()
|
||||
gfun.fit(lX, lY)
|
||||
|
||||
if args.load_trained is None:
|
||||
if args.load_trained is None and not args.nosave:
|
||||
gfun.save(save_first_tier=True, save_meta=True)
|
||||
|
||||
preds = gfun.transform(lX)
|
||||
|
@ -131,6 +144,7 @@ if __name__ == "__main__":
|
|||
parser = ArgumentParser()
|
||||
parser.add_argument("-l", "--load_trained", type=str, default=None)
|
||||
parser.add_argument("--meta", action="store_true")
|
||||
parser.add_argument("--nosave", action="store_true")
|
||||
# Dataset parameters -------------------
|
||||
parser.add_argument("-d", "--dataset", type=str, default="multinews")
|
||||
parser.add_argument("--domains", type=str, default="all")
|
||||
|
|
Loading…
Reference in New Issue