cifar10 and cifar100 datasets added

This commit is contained in:
Lorenzo Volpi 2023-11-22 19:24:59 +01:00
parent beea23db14
commit 0bee49ccc3
1 changed files with 138 additions and 12 deletions

View File

@ -1,14 +1,99 @@
import math
import os
import pickle
import tarfile
from typing import List
import numpy as np
import quapy as qp
from quapy.data.base import LabelledCollection
from sklearn.conftest import fetch_rcv1
from sklearn.utils import Bunch
from quacc import utils
TRAIN_VAL_PROP = 0.5
def fetch_cifar10() -> Bunch:
URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
data_home = utils.get_quacc_home()
unzipped_path = data_home / "cifar-10-batches-py"
if not unzipped_path.exists():
downloaded_path = data_home / URL.split("/")[-1]
utils.download_file(URL, downloaded_path)
with tarfile.open(downloaded_path) as f:
f.extractall(data_home)
os.remove(downloaded_path)
datas = []
data_names = sorted([f for f in os.listdir(unzipped_path) if f.startswith("data")])
for f in data_names:
with open(unzipped_path / f, "rb") as file:
datas.append(pickle.load(file, encoding="bytes"))
tests = []
test_names = sorted([f for f in os.listdir(unzipped_path) if f.startswith("test")])
for f in test_names:
with open(unzipped_path / f, "rb") as file:
tests.append(pickle.load(file, encoding="bytes"))
with open(unzipped_path / "batches.meta", "rb") as file:
meta = pickle.load(file, encoding="bytes")
return Bunch(
train=Bunch(
data=np.concatenate([d[b"data"] for d in datas], axis=0),
labels=np.concatenate([d[b"labels"] for d in datas]),
),
test=Bunch(
data=np.concatenate([d[b"data"] for d in tests], axis=0),
labels=np.concatenate([d[b"labels"] for d in tests]),
),
label_names=[cs.decode("utf-8") for cs in meta[b"label_names"]],
)
def fetch_cifar100():
URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
data_home = utils.get_quacc_home()
unzipped_path = data_home / "cifar-100-python"
if not unzipped_path.exists():
downloaded_path = data_home / URL.split("/")[-1]
utils.download_file(URL, downloaded_path)
with tarfile.open(downloaded_path) as f:
f.extractall(data_home)
os.remove(downloaded_path)
with open(unzipped_path / "train", "rb") as file:
train_d = pickle.load(file, encoding="bytes")
with open(unzipped_path / "test", "rb") as file:
test_d = pickle.load(file, encoding="bytes")
with open(unzipped_path / "meta", "rb") as file:
meta_d = pickle.load(file, encoding="bytes")
train_bunch = Bunch(
data=train_d[b"data"],
fine_labels=np.array(train_d[b"fine_labels"]),
coarse_labels=np.array(train_d[b"coarse_labels"]),
)
test_bunch = Bunch(
data=test_d[b"data"],
fine_labels=np.array(test_d[b"fine_labels"]),
coarse_labels=np.array(test_d[b"coarse_labels"]),
)
return Bunch(
train=train_bunch,
test=test_bunch,
fine_label_names=meta_d[b"fine_label_names"],
coarse_label_names=meta_d[b"coarse_label_names"],
)
class DatasetSample:
def __init__(
self,
@ -71,13 +156,54 @@ class Dataset:
return all_train, test
def get_raw(self) -> DatasetSample:
def __cifar10(self):
dataset = fetch_cifar10()
available_targets: list = dataset.label_names
if self._target is None or self._target not in available_targets:
raise ValueError(f"Invalid target {self._target}")
target_index = available_targets.index(self._target)
all_train_d = dataset.train.data
all_train_l = (dataset.train.labels == target_index).astype(int)
test_d = dataset.test.data
test_l = (dataset.test.labels == target_index).astype(int)
all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
test = LabelledCollection(test_d, test_l, classes=[0, 1])
return all_train, test
def __cifar100(self):
dataset = fetch_cifar100()
available_targets: list = dataset.coarse_label_names
if self._target is None or self._target not in available_targets:
raise ValueError(f"Invalid target {self._target}")
target_index = available_targets.index(self._target)
all_train_d = dataset.train.data
all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
test_d = dataset.test.data
test_l = (dataset.test.coarse_labels == target_index).astype(int)
all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
test = LabelledCollection(test_d, test_l, classes=[0, 1])
return all_train, test
def __train_test(self):
all_train, test = {
"spambase": self.__spambase,
"imdb": self.__imdb,
"rcv1": self.__rcv1,
"cifar10": self.__cifar10,
"cifar100": self.__cifar100,
}[self._name]()
return all_train, test
def get_raw(self) -> DatasetSample:
all_train, test = self.__train_test()
train, val = all_train.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=0
)
@ -85,11 +211,7 @@ class Dataset:
return DatasetSample(train, val, test)
def get(self) -> List[DatasetSample]:
(all_train, test) = {
"spambase": self.__spambase,
"imdb": self.__imdb,
"rcv1": self.__rcv1,
}[self._name]()
all_train, test = self.__train_test()
# resample all_train set to have (0.5, 0.5) prevalence
at_positives = np.sum(all_train.y)
@ -119,11 +241,15 @@ class Dataset:
@property
def name(self):
return (
f"{self._name}_{self._target}_{self.n_prevs}prevs"
if self._name == "rcv1"
else f"{self._name}_{self.n_prevs}prevs"
)
match (self._name, self.n_prevs):
case (("rcv1" | "cifar10" | "cifar100"), 9):
return f"{self._name}_{self._target}"
case (("rcv1" | "cifar10" | "cifar100"), _):
return f"{self._name}_{self._target}_{self.n_prevs}prevs"
case (_, 9):
return f"{self._name}"
case (_, _):
return f"{self._name}_{self.n_prevs}prevs"
# >>> fetch_rcv1().target_names
@ -168,4 +294,4 @@ def rcv1_info():
if __name__ == "__main__":
rcv1_info()
fetch_cifar100()