Data Classes for GLAMI-1M Dataset
This commit is contained in:
parent 7ed98346a5
commit 298f31669d
dataManager/glamiDataset.py
@@ -0,0 +1,175 @@
import os

import pandas as pd
from sklearn.preprocessing import LabelBinarizer

# From https://github.com/glami/glami-1m/blob/main/load_dataset.py

GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")

# DATASET_URL = os.environ.get(
#     "DATASET_URL",
#     "https://zenodo.org/record/7326406/files/GLAMI-1M-dataset.zip?download=1",
# )
# EXTRACT_DIR = os.environ.get("EXTRACT_DIR", "/tmp/GLAMI-1M")
# DATASET_SUBDIR = "GLAMI-1M-dataset"
# DATASET_DIR = dataset_dir = EXTRACT_DIR + "/" + DATASET_SUBDIR
# MODEL_DIR = os.environ.get("MODEL_DIR", "/tmp/GLAMI-1M/models")
# EMBS_DIR = EXTRACT_DIR + "/embs"
# CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-visual"
# CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-textual"
# # CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-visual"
# # CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-textual"
# CLIP_EN_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-en-textual"
# GENERATED_DIR = EXTRACT_DIR + "/generated_images"

COL_NAME_ITEM_ID = "item_id"
COL_NAME_IMAGE_ID = "image_id"
COL_NAME_IMAGE_FILE = "image_file"
COL_NAME_IMAGE_URL = "image_url"
COL_NAME_NAME = "name"
COL_NAME_DESCRIPTION = "description"
COL_NAME_GEO = "geo"
COL_NAME_CATEGORY = "category"
COL_NAME_CAT_NAME = "category_name"
COL_NAME_LABEL_SOURCE = "label_source"
COL_NAME_EMB_FILE = "emb_file"
COL_NAME_MASK_FILE = "mask_file"
DEFAULT_IMAGE_SIZE = (298, 228)

COUNTRY_CODE_TO_COUNTRY_NAME = {
    "cz": "Czechia",
    "sk": "Slovakia",
    "ro": "Romania",
    "gr": "Greece",
    "si": "Slovenia",
    "hu": "Hungary",
    "hr": "Croatia",
    "es": "Spain",
    "lt": "Lithuania",
    "lv": "Latvia",
    "tr": "Turkey",
    "ee": "Estonia",
    "bg": "Bulgaria",
}

# Maps a country code to a display name, e.g. "cz" -> "Czechia (cz)".
COUNTRY_CODE_TO_COUNTRY_NAME_W_CC = {
    cc: f"{name} ({cc})" for cc, name in COUNTRY_CODE_TO_COUNTRY_NAME.items()
}

def get_dataframe(split_type: str, dataset_dir=GLAMI_DATAPATH, nrows=None):
    """Load the GLAMI-1M train or test CSV and attach local image file paths."""
    assert split_type in ("train", "test")
    df = pd.read_csv(dataset_dir + f"/GLAMI-1M-{split_type}.csv", nrows=nrows)
    df[COL_NAME_IMAGE_FILE] = (
        dataset_dir + "/images/" + df[COL_NAME_IMAGE_ID].astype(str) + ".jpg"
    )
    df[COL_NAME_DESCRIPTION] = df[COL_NAME_DESCRIPTION].fillna("")
    # Sanity check: the dataset images must already be extracted on disk.
    assert os.path.exists(df.loc[0, COL_NAME_IMAGE_FILE])
    return df[
        [
            COL_NAME_ITEM_ID,
            COL_NAME_IMAGE_ID,
            COL_NAME_NAME,
            COL_NAME_DESCRIPTION,
            COL_NAME_GEO,
            COL_NAME_CAT_NAME,
            COL_NAME_LABEL_SOURCE,
            COL_NAME_IMAGE_FILE,
        ]
    ]

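# A minimal usage sketch (assuming the GLAMI-1M CSVs and extracted images
# already sit under GLAMI_DATAPATH; the "geo" column doubles as the language
# identifier throughout this module):
#
#   df = get_dataframe("train", nrows=1000)
#   print(df[COL_NAME_GEO].value_counts())
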
class GlamiDataset:
    def __init__(self, dataset_dir, langs=None, labels=None, nrows=None):
        self.dataset_dir = dataset_dir
        self.data_langs = langs
        self.labels = labels
        self.nrows = nrows
        self.multilingual_dataset = {}

    def num_labels(self):
        return len(self.labels)

    def langs(self):
        return self.data_langs

    def get_label_binarizer(self, labels):
        mlb = LabelBinarizer()
        mlb.fit(labels)
        print(
            f"- Label binarizer initialized with the following labels:\n{mlb.classes_}"
        )
        return mlb

    def binarize_labels(self, labels):
        if hasattr(self, "mlb"):
            return self.mlb.transform(labels)
        else:
            raise ValueError("Label binarizer not initialized")

    def load_df(self, split, dataset_dir):
        return get_dataframe(split, dataset_dir=dataset_dir, nrows=self.nrows)

    def build_dataset(self):
        train_dataset = self.load_df("train", self.dataset_dir)
        test_dataset = self.load_df("test", self.dataset_dir)

        if self.data_langs is None:
            self.data_langs = train_dataset.geo.unique().tolist()

        if self.labels is None:
            self.labels = train_dataset.category_name.unique().tolist()
        self.mlb = self.get_label_binarizer(self.labels)

        # Maps lang -> [train_df, test_df]; the zip relies on the train and
        # test groupby yielding the same sorted "geo" keys.
        self.multilingual_dataset = {
            lang: [data_tr, data_te]
            for (lang, data_tr), (_, data_te) in zip(
                train_dataset.groupby("geo"), test_dataset.groupby("geo")
            )
            if lang in self.data_langs
        }

    def training(self):
        # TODO: tolist() or ???
        lXtr = {
            lang: (df.name + " " + df.description).tolist()
            for lang, (df, _) in self.multilingual_dataset.items()
        }
        lYtr = {
            lang: self.binarize_labels(df.category_name.tolist())
            for lang, (df, _) in self.multilingual_dataset.items()
        }
        return lXtr, lYtr

    def test(self):
        lXte = {
            lang: (df.name + " " + df.description).tolist()
            for lang, (_, df) in self.multilingual_dataset.items()
        }
        lYte = {
            lang: self.binarize_labels(df.category_name.tolist())
            for lang, (_, df) in self.multilingual_dataset.items()
        }
        return lXte, lYte

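# Sketch of the returned structures (illustrative keys): training() and test()
# yield per-language dicts, e.g. lXtr["cz"] is a list of concatenated
# "name description" strings and lYtr["cz"] is the aligned one-hot label
# matrix produced by LabelBinarizer.
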
if __name__ == "__main__":
    print("Hello glamiDataset")
    dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=None)
    dataset.build_dataset()
    lXtr, lYtr = dataset.training()
    lXte, lYte = dataset.test()
    exit(0)
main.py (22 changed lines)
@@ -6,6 +6,7 @@ from time import time
 from dataManager.amazonDataset import AmazonDataset
 from dataManager.multilingualDatset import MultilingualDataset
 from dataManager.multiNewsDataset import MultiNewsDataset
+from dataManager.glamiDataset import GlamiDataset
 from evaluation.evaluate import evaluate, log_eval
 from gfun.generalizedFunnelling import GeneralizedFunnelling

@@ -20,7 +21,12 @@ TODO:
 def get_dataset(datasetname):
-    assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"
+    assert datasetname in [
+        "multinews",
+        "amazon",
+        "rcv1-2",
+        "glami",
+    ], "dataset not supported"

     RCV_DATAPATH = expanduser(
         "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
@@ -30,6 +36,8 @@ def get_dataset(datasetname):
     )
     MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
+
+    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

     if datasetname == "multinews":
         dataset = MultiNewsDataset(
             expanduser(MULTINEWS_DATAPATH),
@@ -46,6 +54,9 @@ def get_dataset(datasetname):
         dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
         if args.nrows is not None:
             dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
+    elif datasetname == "glami":
+        dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=args.nrows)
+        dataset.build_dataset()
     else:
         raise NotImplementedError
     return dataset
@@ -53,8 +64,10 @@ def get_dataset(datasetname):

 def main(args):
     dataset = get_dataset(args.dataset)
-    if isinstance(dataset, MultilingualDataset) or isinstance(
-        dataset, MultiNewsDataset
+    if (
+        isinstance(dataset, MultilingualDataset)
+        or isinstance(dataset, MultiNewsDataset)
+        or isinstance(dataset, GlamiDataset)
     ):
         lX, lY = dataset.training()
         lX_te, lY_te = dataset.test()
@@ -109,7 +122,7 @@ def main(args):
     # gfun.get_config()
     gfun.fit(lX, lY)

-    if args.load_trained is None:
+    if args.load_trained is None and not args.nosave:
         gfun.save(save_first_tier=True, save_meta=True)

     preds = gfun.transform(lX)
@@ -131,6 +144,7 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-l", "--load_trained", type=str, default=None)
     parser.add_argument("--meta", action="store_true")
+    parser.add_argument("--nosave", action="store_true")
     # Dataset parameters -------------------
     parser.add_argument("-d", "--dataset", type=str, default="multinews")
     parser.add_argument("--domains", type=str, default="all")
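With these changes the GLAMI-1M path can be exercised end to end, e.g. (a sketch; --nrows is assumed to be an existing option of this parser, since args.nrows is already referenced above):

python main.py -d glami --nrows 1000 --nosave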