cifar10 and cifar100 datasets added
parent beea23db14
commit 0bee49ccc3

150  quacc/dataset.py
--- a/quacc/dataset.py
+++ b/quacc/dataset.py
@@ -1,14 +1,99 @@
 import math
+import os
+import pickle
+import tarfile
 from typing import List
 
 import numpy as np
 import quapy as qp
 from quapy.data.base import LabelledCollection
 from sklearn.conftest import fetch_rcv1
+from sklearn.utils import Bunch
+
+from quacc import utils
 
 TRAIN_VAL_PROP = 0.5
 
 
+def fetch_cifar10() -> Bunch:
+    URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    data_home = utils.get_quacc_home()
+    unzipped_path = data_home / "cifar-10-batches-py"
+    if not unzipped_path.exists():
+        downloaded_path = data_home / URL.split("/")[-1]
+        utils.download_file(URL, downloaded_path)
+        with tarfile.open(downloaded_path) as f:
+            f.extractall(data_home)
+        os.remove(downloaded_path)
+
+    datas = []
+    data_names = sorted([f for f in os.listdir(unzipped_path) if f.startswith("data")])
+    for f in data_names:
+        with open(unzipped_path / f, "rb") as file:
+            datas.append(pickle.load(file, encoding="bytes"))
+
+    tests = []
+    test_names = sorted([f for f in os.listdir(unzipped_path) if f.startswith("test")])
+    for f in test_names:
+        with open(unzipped_path / f, "rb") as file:
+            tests.append(pickle.load(file, encoding="bytes"))
+
+    with open(unzipped_path / "batches.meta", "rb") as file:
+        meta = pickle.load(file, encoding="bytes")
+
+    return Bunch(
+        train=Bunch(
+            data=np.concatenate([d[b"data"] for d in datas], axis=0),
+            labels=np.concatenate([d[b"labels"] for d in datas]),
+        ),
+        test=Bunch(
+            data=np.concatenate([d[b"data"] for d in tests], axis=0),
+            labels=np.concatenate([d[b"labels"] for d in tests]),
+        ),
+        label_names=[cs.decode("utf-8") for cs in meta[b"label_names"]],
+    )
+
+
+def fetch_cifar100():
+    URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    data_home = utils.get_quacc_home()
+    unzipped_path = data_home / "cifar-100-python"
+    if not unzipped_path.exists():
+        downloaded_path = data_home / URL.split("/")[-1]
+        utils.download_file(URL, downloaded_path)
+        with tarfile.open(downloaded_path) as f:
+            f.extractall(data_home)
+        os.remove(downloaded_path)
+
+    with open(unzipped_path / "train", "rb") as file:
+        train_d = pickle.load(file, encoding="bytes")
+
+    with open(unzipped_path / "test", "rb") as file:
+        test_d = pickle.load(file, encoding="bytes")
+
+    with open(unzipped_path / "meta", "rb") as file:
+        meta_d = pickle.load(file, encoding="bytes")
+
+    train_bunch = Bunch(
+        data=train_d[b"data"],
+        fine_labels=np.array(train_d[b"fine_labels"]),
+        coarse_labels=np.array(train_d[b"coarse_labels"]),
+    )
+
+    test_bunch = Bunch(
+        data=test_d[b"data"],
+        fine_labels=np.array(test_d[b"fine_labels"]),
+        coarse_labels=np.array(test_d[b"coarse_labels"]),
+    )
+
+    return Bunch(
+        train=train_bunch,
+        test=test_bunch,
+        fine_label_names=meta_d[b"fine_label_names"],
+        coarse_label_names=meta_d[b"coarse_label_names"],
+    )
+
+
 class DatasetSample:
     def __init__(
         self,
@@ -71,13 +156,54 @@ class Dataset:
         return all_train, test
 
-    def get_raw(self) -> DatasetSample:
+    def __cifar10(self):
+        dataset = fetch_cifar10()
+        available_targets: list = dataset.label_names
+
+        if self._target is None or self._target not in available_targets:
+            raise ValueError(f"Invalid target {self._target}")
+
+        target_index = available_targets.index(self._target)
+        all_train_d = dataset.train.data
+        all_train_l = (dataset.train.labels == target_index).astype(int)
+        test_d = dataset.test.data
+        test_l = (dataset.test.labels == target_index).astype(int)
+        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
+        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+
+        return all_train, test
+
+    def __cifar100(self):
+        dataset = fetch_cifar100()
+        available_targets: list = dataset.coarse_label_names
+
+        if self._target is None or self._target not in available_targets:
+            raise ValueError(f"Invalid target {self._target}")
+
+        target_index = available_targets.index(self._target)
+        all_train_d = dataset.train.data
+        all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
+        test_d = dataset.test.data
+        test_l = (dataset.test.coarse_labels == target_index).astype(int)
+        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
+        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+
+        return all_train, test
+
+    def __train_test(self):
         all_train, test = {
             "spambase": self.__spambase,
             "imdb": self.__imdb,
             "rcv1": self.__rcv1,
+            "cifar10": self.__cifar10,
+            "cifar100": self.__cifar100,
         }[self._name]()
 
+        return all_train, test
+
+    def get_raw(self) -> DatasetSample:
+        all_train, test = self.__train_test()
+
         train, val = all_train.split_stratified(
             train_prop=TRAIN_VAL_PROP, random_state=0
         )
 
@@ -85,11 +211,7 @@ class Dataset:
         return DatasetSample(train, val, test)
 
     def get(self) -> List[DatasetSample]:
-        (all_train, test) = {
-            "spambase": self.__spambase,
-            "imdb": self.__imdb,
-            "rcv1": self.__rcv1,
-        }[self._name]()
+        all_train, test = self.__train_test()
 
         # resample all_train set to have (0.5, 0.5) prevalence
         at_positives = np.sum(all_train.y)
@@ -119,11 +241,15 @@ class Dataset:
 
     @property
    def name(self):
-        return (
-            f"{self._name}_{self._target}_{self.n_prevs}prevs"
-            if self._name == "rcv1"
-            else f"{self._name}_{self.n_prevs}prevs"
-        )
+        match (self._name, self.n_prevs):
+            case (("rcv1" | "cifar10" | "cifar100"), 9):
+                return f"{self._name}_{self._target}"
+            case (("rcv1" | "cifar10" | "cifar100"), _):
+                return f"{self._name}_{self._target}_{self.n_prevs}prevs"
+            case (_, 9):
+                return f"{self._name}"
+            case (_, _):
+                return f"{self._name}_{self.n_prevs}prevs"
 
 
 # >>> fetch_rcv1().target_names
@@ -168,4 +294,4 @@ def rcv1_info():
 
 
 if __name__ == "__main__":
-    rcv1_info()
+    fetch_cifar100()
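A minimal usage sketch of the new loader follows (not part of the commit itself). It assumes the package is importable as quacc.dataset and that the first call is allowed to download and unpack the CIFAR-10 archive into the quacc home directory; the "airplane" target and the print statements are purely illustrative. The sketch mirrors what Dataset.__cifar10 does above: take the Bunch returned by fetch_cifar10() and binarize its labels against a single target class.

from quapy.data.base import LabelledCollection

from quacc.dataset import fetch_cifar10

# Download (if not cached) and load CIFAR-10 as a Bunch with train/test data,
# integer labels and decoded label names.
cifar = fetch_cifar10()
print(cifar.label_names)       # ten class names, e.g. "airplane", "automobile", ...
print(cifar.train.data.shape)  # (50000, 3072): five batches of flattened 32x32x3 images

# Binarize against one target class, as Dataset.__cifar10 does internally.
target = "airplane"  # illustrative choice
target_index = cifar.label_names.index(target)

train = LabelledCollection(
    cifar.train.data,
    (cifar.train.labels == target_index).astype(int),
    classes=[0, 1],
)
test = LabelledCollection(
    cifar.test.data,
    (cifar.test.labels == target_index).astype(int),
    classes=[0, 1],
)
print(train.prevalence(), test.prevalence())  # (0.9, 0.1) for both splits, since CIFAR-10 is class-balanced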