cifar10 and cifar100 datasets added

Lorenzo Volpi 2023-11-22 19:24:59 +01:00
parent beea23db14
commit 0bee49ccc3
1 changed file with 138 additions and 12 deletions
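The diff adds two module-level loaders, fetch_cifar10 and fetch_cifar100, which download the python versions of the CIFAR archives into the quacc data home on first use and return sklearn Bunch objects, and it wires "cifar10"/"cifar100" into the Dataset name lookup. A rough usage sketch of the new loaders (the import path is assumed, and the shapes and class names below are standard CIFAR facts, not values read from this diff):

    # Hypothetical usage; the module path quacc.dataset is an assumption.
    from quacc.dataset import fetch_cifar10, fetch_cifar100

    cifar10 = fetch_cifar10()
    cifar10.train.data.shape    # (50000, 3072): five pickled batches of 32x32 RGB images, flattened
    cifar10.test.labels.shape   # (10000,)
    cifar10.label_names         # ['airplane', 'automobile', ..., 'truck'] (decoded to str)

    cifar100 = fetch_cifar100()
    cifar100.train.coarse_labels.shape   # (50000,): ids of the 20 superclasses, binarized per target in the diff below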


@@ -1,14 +1,99 @@
 import math
+import os
+import pickle
+import tarfile
 from typing import List
 
 import numpy as np
 import quapy as qp
 from quapy.data.base import LabelledCollection
 from sklearn.conftest import fetch_rcv1
+from sklearn.utils import Bunch
+
+from quacc import utils
 
 TRAIN_VAL_PROP = 0.5
 
 
+def fetch_cifar10() -> Bunch:
+    URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    data_home = utils.get_quacc_home()
+    unzipped_path = data_home / "cifar-10-batches-py"
+
+    if not unzipped_path.exists():
+        downloaded_path = data_home / URL.split("/")[-1]
+        utils.download_file(URL, downloaded_path)
+        with tarfile.open(downloaded_path) as f:
+            f.extractall(data_home)
+        os.remove(downloaded_path)
+
+    datas = []
+    data_names = sorted([f for f in os.listdir(unzipped_path) if f.startswith("data")])
+    for f in data_names:
+        with open(unzipped_path / f, "rb") as file:
+            datas.append(pickle.load(file, encoding="bytes"))
+
+    tests = []
+    test_names = sorted([f for f in os.listdir(unzipped_path) if f.startswith("test")])
+    for f in test_names:
+        with open(unzipped_path / f, "rb") as file:
+            tests.append(pickle.load(file, encoding="bytes"))
+
+    with open(unzipped_path / "batches.meta", "rb") as file:
+        meta = pickle.load(file, encoding="bytes")
+
+    return Bunch(
+        train=Bunch(
+            data=np.concatenate([d[b"data"] for d in datas], axis=0),
+            labels=np.concatenate([d[b"labels"] for d in datas]),
+        ),
+        test=Bunch(
+            data=np.concatenate([d[b"data"] for d in tests], axis=0),
+            labels=np.concatenate([d[b"labels"] for d in tests]),
+        ),
+        label_names=[cs.decode("utf-8") for cs in meta[b"label_names"]],
+    )
+
+
+def fetch_cifar100():
+    URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    data_home = utils.get_quacc_home()
+    unzipped_path = data_home / "cifar-100-python"
+
+    if not unzipped_path.exists():
+        downloaded_path = data_home / URL.split("/")[-1]
+        utils.download_file(URL, downloaded_path)
+        with tarfile.open(downloaded_path) as f:
+            f.extractall(data_home)
+        os.remove(downloaded_path)
+
+    with open(unzipped_path / "train", "rb") as file:
+        train_d = pickle.load(file, encoding="bytes")
+    with open(unzipped_path / "test", "rb") as file:
+        test_d = pickle.load(file, encoding="bytes")
+    with open(unzipped_path / "meta", "rb") as file:
+        meta_d = pickle.load(file, encoding="bytes")
+
+    train_bunch = Bunch(
+        data=train_d[b"data"],
+        fine_labels=np.array(train_d[b"fine_labels"]),
+        coarse_labels=np.array(train_d[b"coarse_labels"]),
+    )
+    test_bunch = Bunch(
+        data=test_d[b"data"],
+        fine_labels=np.array(test_d[b"fine_labels"]),
+        coarse_labels=np.array(test_d[b"coarse_labels"]),
+    )
+
+    return Bunch(
+        train=train_bunch,
+        test=test_bunch,
+        fine_label_names=meta_d[b"fine_label_names"],
+        coarse_label_names=meta_d[b"coarse_label_names"],
+    )
+
+
 class DatasetSample:
     def __init__(
         self,
@@ -71,13 +156,54 @@ class Dataset:
 
         return all_train, test
 
-    def get_raw(self) -> DatasetSample:
+    def __cifar10(self):
+        dataset = fetch_cifar10()
+        available_targets: list = dataset.label_names
+
+        if self._target is None or self._target not in available_targets:
+            raise ValueError(f"Invalid target {self._target}")
+
+        target_index = available_targets.index(self._target)
+        all_train_d = dataset.train.data
+        all_train_l = (dataset.train.labels == target_index).astype(int)
+        test_d = dataset.test.data
+        test_l = (dataset.test.labels == target_index).astype(int)
+        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
+        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+
+        return all_train, test
+
+    def __cifar100(self):
+        dataset = fetch_cifar100()
+        available_targets: list = dataset.coarse_label_names
+
+        if self._target is None or self._target not in available_targets:
+            raise ValueError(f"Invalid target {self._target}")
+
+        target_index = available_targets.index(self._target)
+        all_train_d = dataset.train.data
+        all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
+        test_d = dataset.test.data
+        test_l = (dataset.test.coarse_labels == target_index).astype(int)
+        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
+        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+
+        return all_train, test
+
+    def __train_test(self):
         all_train, test = {
             "spambase": self.__spambase,
             "imdb": self.__imdb,
             "rcv1": self.__rcv1,
+            "cifar10": self.__cifar10,
+            "cifar100": self.__cifar100,
         }[self._name]()
 
+        return all_train, test
+
+    def get_raw(self) -> DatasetSample:
+        all_train, test = self.__train_test()
+
         train, val = all_train.split_stratified(
             train_prop=TRAIN_VAL_PROP, random_state=0
         )
@@ -85,11 +211,7 @@ class Dataset:
         return DatasetSample(train, val, test)
 
     def get(self) -> List[DatasetSample]:
-        (all_train, test) = {
-            "spambase": self.__spambase,
-            "imdb": self.__imdb,
-            "rcv1": self.__rcv1,
-        }[self._name]()
+        all_train, test = self.__train_test()
 
         # resample all_train set to have (0.5, 0.5) prevalence
         at_positives = np.sum(all_train.y)
@@ -119,11 +241,15 @@ class Dataset:
 
     @property
     def name(self):
-        return (
-            f"{self._name}_{self._target}_{self.n_prevs}prevs"
-            if self._name == "rcv1"
-            else f"{self._name}_{self.n_prevs}prevs"
-        )
+        match (self._name, self.n_prevs):
+            case (("rcv1" | "cifar10" | "cifar100"), 9):
+                return f"{self._name}_{self._target}"
+            case (("rcv1" | "cifar10" | "cifar100"), _):
+                return f"{self._name}_{self._target}_{self.n_prevs}prevs"
+            case (_, 9):
+                return f"{self._name}"
+            case (_, _):
+                return f"{self._name}_{self.n_prevs}prevs"
 
 
 # >>> fetch_rcv1().target_names
@@ -168,4 +294,4 @@ def rcv1_info():
 
 
 if __name__ == "__main__":
-    rcv1_info()
+    fetch_cifar100()
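And a sketch of how the new targets plug into the Dataset class: __cifar10 and __cifar100 binarize the chosen class (a CIFAR-10 label name, or a CIFAR-100 coarse label name) into a one-vs-rest LabelledCollection with classes [0, 1], mirroring the existing rcv1 path. The Dataset constructor is not part of this diff, so the arguments below are assumptions based on the _name, _target and n_prevs attributes used above; note also that fetch_cifar100 leaves the coarse label names as byte strings (only the CIFAR-10 names are decoded), so the cifar100 target has to match that representation.

    # Hypothetical constructor arguments; only self._name, self._target and
    # self.n_prevs appear in the diff, Dataset.__init__ does not.
    from quacc.dataset import Dataset   # module path assumed

    d = Dataset(name="cifar10", target="dog", n_prevs=9)
    sample = d.get_raw()   # DatasetSample(train, val, test); train/val split with TRAIN_VAL_PROP = 0.5
    d.name                 # "cifar10_dog": the new match statement drops the prevs suffix when n_prevs == 9

    # CIFAR-100 targets are the 20 coarse superclasses, kept as bytes by fetch_cifar100.
    d100 = Dataset(name="cifar100", target=b"trees", n_prevs=5)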