# QuAcc/quacc/dataset.py
#
# 379 lines
# 13 KiB
# Python

import itertools
import math
import os
import pickle
import tarfile
from typing import List, Tuple
import numpy as np
import quapy as qp
from quapy.data.base import LabelledCollection
from sklearn.conftest import fetch_rcv1
from sklearn.utils import Bunch
from quacc import utils
from quacc.environment import env
# Proportion passed as `train_prop` to `split_stratified` when each sampled
# training pool is divided into the train and validation collections.
TRAIN_VAL_PROP = 0.5
def fetch_cifar10() -> Bunch:
    """Load CIFAR-10 from the quacc home directory, fetching it on first use.

    If the extracted ``cifar-10-batches-py`` folder is missing, the archive is
    downloaded from cs.toronto.edu via ``utils.download_file``, extracted into
    ``utils.get_quacc_home()`` and the archive file is then removed.

    Returns:
        A ``Bunch`` with ``train`` and ``test`` sub-bunches, each holding
        ``data`` and ``labels`` arrays obtained by concatenating every batch
        file with the matching prefix (in sorted filename order), plus
        ``label_names``: the class names decoded as UTF-8 strings.
    """
    URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    home = utils.get_quacc_home()
    batches_dir = home / "cifar-10-batches-py"

    if not batches_dir.exists():
        archive = home / URL.split("/")[-1]
        utils.download_file(URL, archive)
        with tarfile.open(archive) as tar:
            tar.extractall(home)
        os.remove(archive)

    def _load_batches(prefix):
        # Batch pickles use bytes keys, hence encoding="bytes".
        loaded = []
        for fname in sorted(n for n in os.listdir(batches_dir) if n.startswith(prefix)):
            with open(batches_dir / fname, "rb") as fh:
                loaded.append(pickle.load(fh, encoding="bytes"))
        return loaded

    def _stack(batches):
        # Merge all batches of one split into single data/label arrays.
        return Bunch(
            data=np.concatenate([b[b"data"] for b in batches], axis=0),
            labels=np.concatenate([b[b"labels"] for b in batches]),
        )

    train_batches = _load_batches("data")
    test_batches = _load_batches("test")
    with open(batches_dir / "batches.meta", "rb") as fh:
        meta = pickle.load(fh, encoding="bytes")

    return Bunch(
        train=_stack(train_batches),
        test=_stack(test_batches),
        label_names=[raw.decode("utf-8") for raw in meta[b"label_names"]],
    )
def fetch_cifar100():
    """Load CIFAR-100 from the quacc home directory, fetching it on first use.

    If the extracted ``cifar-100-python`` folder is missing, the archive is
    downloaded from cs.toronto.edu via ``utils.download_file``, extracted into
    ``utils.get_quacc_home()`` and the archive file is then removed.

    Returns:
        A ``Bunch`` with ``train`` and ``test`` sub-bunches (each holding
        ``data``, ``fine_labels`` and ``coarse_labels``), plus
        ``fine_label_names`` and ``coarse_label_names``. The label names are
        decoded to ``str`` (as ``fetch_cifar10`` does for its names) so they
        can be matched against plain-string targets.
    """
    URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    data_home = utils.get_quacc_home()
    unzipped_path = data_home / "cifar-100-python"
    if not unzipped_path.exists():
        downloaded_path = data_home / URL.split("/")[-1]
        utils.download_file(URL, downloaded_path)
        with tarfile.open(downloaded_path) as f:
            f.extractall(data_home)
        os.remove(downloaded_path)

    def _load(name):
        # The pickles use bytes keys, hence encoding="bytes".
        with open(unzipped_path / name, "rb") as file:
            return pickle.load(file, encoding="bytes")

    def _split(d):
        # Build the per-split Bunch with label lists converted to arrays.
        return Bunch(
            data=d[b"data"],
            fine_labels=np.array(d[b"fine_labels"]),
            coarse_labels=np.array(d[b"coarse_labels"]),
        )

    def _decode(names):
        # The meta pickle stores names as bytes; without decoding, a str
        # target such as the one DatasetProvider checks with `in` never matches.
        return [n.decode("utf-8") for n in names]

    train_d = _load("train")
    test_d = _load("test")
    meta_d = _load("meta")

    return Bunch(
        train=_split(train_d),
        test=_split(test_d),
        fine_label_names=_decode(meta_d[b"fine_label_names"]),
        coarse_label_names=_decode(meta_d[b"coarse_label_names"]),
    )
class DatasetSample:
    """A train/validation pair of collections together with a test collection."""

    def __init__(
        self,
        train: LabelledCollection,
        validation: LabelledCollection,
        test: LabelledCollection,
    ):
        self.train, self.validation, self.test = train, validation, test

    @property
    def train_prev(self):
        """Class prevalence of the training collection."""
        return self.train.prevalence()

    @property
    def validation_prev(self):
        """Class prevalence of the validation collection."""
        return self.validation.prevalence()

    @property
    def prevs(self):
        """Dict with the ``train`` and ``validation`` prevalences."""
        return dict(train=self.train_prev, validation=self.validation_prev)
class DatasetProvider:
    """Load named datasets as ``(all_train, test)`` pairs of ``LabelledCollection``.

    ``rcv1``, ``cifar10`` and ``cifar100`` are turned into binary problems: the
    requested ``target`` class becomes label 1 and every other class label 0.
    """

    def __spambase(self, **kwargs):
        return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test

    # TODO: try min_df=5
    def __imdb(self, **kwargs):
        return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test

    def __rcv1(self, target, **kwargs):
        # Rows before n_train are used as the training pool, the rest as test.
        n_train = 23149
        available_targets = ["CCAT", "GCAT", "MCAT"]
        if target is None or target not in available_targets:
            raise ValueError(f"Invalid target {target}")

        dataset = fetch_rcv1()
        target_index = np.where(dataset.target_names == target)[0]
        all_train_d = dataset.data[:n_train, :]
        test_d = dataset.data[n_train:, :]
        labels = dataset.target[:, target_index].toarray().flatten()
        all_train_l, test_l = labels[:n_train], labels[n_train:]
        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
        test = LabelledCollection(test_d, test_l, classes=[0, 1])
        return all_train, test

    def __cifar10(self, target, **kwargs):
        dataset = fetch_cifar10()
        available_targets: list = dataset.label_names
        # Validate the `target` argument itself: `self._target` exists only on
        # the Dataset subclass and is not guaranteed to match the argument.
        if target is None or target not in available_targets:
            raise ValueError(f"Invalid target {target}")

        target_index = available_targets.index(target)
        all_train_d = dataset.train.data
        all_train_l = (dataset.train.labels == target_index).astype(int)
        test_d = dataset.test.data
        test_l = (dataset.test.labels == target_index).astype(int)
        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
        test = LabelledCollection(test_d, test_l, classes=[0, 1])
        return all_train, test

    def __cifar100(self, target, **kwargs):
        dataset = fetch_cifar100()
        available_targets: list = dataset.coarse_label_names
        if target is None or target not in available_targets:
            raise ValueError(f"Invalid target {target}")

        target_index = available_targets.index(target)
        all_train_d = dataset.train.data
        all_train_l = (dataset.train.coarse_labels == target_index).astype(int)
        test_d = dataset.test.data
        test_l = (dataset.test.coarse_labels == target_index).astype(int)
        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
        test = LabelledCollection(test_d, test_l, classes=[0, 1])
        return all_train, test

    def __twitter_gasp(self, **kwargs):
        return qp.datasets.fetch_twitter("gasp", min_df=3).train_test

    def alltrain_test(
        self, name: str, target: str | None
    ) -> Tuple[LabelledCollection, LabelledCollection]:
        """Return the ``(all_train, test)`` collections for dataset ``name``.

        Args:
            name: one of ``spambase``, ``imdb``, ``rcv1``, ``cifar10``,
                ``cifar100``, ``twitter_gasp``.
            target: class to binarize on (required by rcv1/cifar10/cifar100,
                ignored by the others).

        Raises:
            KeyError: if ``name`` is not a known dataset.
            ValueError: if a target-based dataset receives an invalid ``target``.
        """
        loaders = {
            "spambase": self.__spambase,
            "imdb": self.__imdb,
            "rcv1": self.__rcv1,
            "cifar10": self.__cifar10,
            "cifar100": self.__cifar100,
            "twitter_gasp": self.__twitter_gasp,
        }
        try:
            loader = loaders[name]
        except KeyError:
            raise KeyError(
                f"Unknown dataset {name!r}; expected one of {sorted(loaders)}"
            ) from None
        all_train, test = loader(target=target)
        return all_train, test
class Dataset(DatasetProvider):
def __init__(self, name, n_prevalences=9, prevs=None, target=None):
self._name = name
self._target = target
self.all_train, self.test = self.alltrain_test(self._name, self._target)
self.__resample_all_train()
self.prevs = None
self._n_prevs = n_prevalences
self.__check_prevs(prevs)
self.prevs = self.__build_prevs()
def __resample_all_train(self):
tr_counts, tr_ncl = self.all_train.counts(), self.all_train.n_classes
_resample_prevs = np.full((tr_ncl,), fill_value=1.0 / tr_ncl)
self.all_train = self.all_train.sampling(
np.min(tr_counts) * tr_ncl,
*_resample_prevs.tolist(),
random_state=env._R_SEED,
)
def __check_prevs(self, prevs):
try:
iter(prevs)
except TypeError:
return
if prevs is None or len(prevs) == 0:
return
def is_float_iterable(obj):
try:
it = iter(obj)
return all([isinstance(o, float) for o in it])
except TypeError:
return False
if not all([is_float_iterable(p) for p in prevs]):
return
if not all([len(p) == self.all_train.n_classes for p in prevs]):
return
if not all([sum(p) == 1.0 for p in prevs]):
return
self.prevs = np.unique(prevs, axis=0)
def __build_prevs(self):
if self.prevs is not None:
return self.prevs
dim = self.all_train.n_classes
lspace = np.linspace(0.0, 1.0, num=self._n_prevs + 1, endpoint=False)[1:]
mesh = np.array(np.meshgrid(*[lspace for _ in range(dim)])).T.reshape(-1, dim)
mesh = mesh[np.where(mesh.sum(axis=1) == 1.0)]
return np.around(np.unique(mesh, axis=0), decimals=4)
def __build_sample(
self,
p: np.ndarray,
at_size: int,
):
all_train_sampled = self.all_train.sampling(
at_size, *(p[:-1]), random_state=env._R_SEED
)
train, validation = all_train_sampled.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED
)
return DatasetSample(train, validation, self.test)
def get(self) -> List[DatasetSample]:
at_size = min(
math.floor(len(self.all_train) * (1.0 / self.all_train.n_classes) / p)
for _prev in self.prevs
for p in _prev
)
return [self.__build_sample(p, at_size) for p in self.prevs]
def __call__(self):
return self.get()
@property
def name(self):
match (self._name, self._n_prevs):
case (("rcv1" | "cifar10" | "cifar100"), 9):
return f"{self._name}_{self._target}"
case (("rcv1" | "cifar10" | "cifar100"), _):
return f"{self._name}_{self._target}_{self._n_prevs}prevs"
case (_, 9):
return f"{self._name}"
case (_, _):
return f"{self._name}_{self._n_prevs}prevs"
@property
def nprevs(self):
return self.prevs.shape[0]
# >>> fetch_rcv1().target_names
# array(['C11', 'C12', 'C13', 'C14', 'C15', 'C151', 'C1511', 'C152', 'C16',
# 'C17', 'C171', 'C172', 'C173', 'C174', 'C18', 'C181', 'C182',
# 'C183', 'C21', 'C22', 'C23', 'C24', 'C31', 'C311', 'C312', 'C313',
# 'C32', 'C33', 'C331', 'C34', 'C41', 'C411', 'C42', 'CCAT', 'E11',
# 'E12', 'E121', 'E13', 'E131', 'E132', 'E14', 'E141', 'E142',
# 'E143', 'E21', 'E211', 'E212', 'E31', 'E311', 'E312', 'E313',
# 'E41', 'E411', 'E51', 'E511', 'E512', 'E513', 'E61', 'E71', 'ECAT',
# 'G15', 'G151', 'G152', 'G153', 'G154', 'G155', 'G156', 'G157',
# 'G158', 'G159', 'GCAT', 'GCRIM', 'GDEF', 'GDIP', 'GDIS', 'GENT',
# 'GENV', 'GFAS', 'GHEA', 'GJOB', 'GMIL', 'GOBIT', 'GODD', 'GPOL',
# 'GPRO', 'GREL', 'GSCI', 'GSPO', 'GTOUR', 'GVIO', 'GVOTE', 'GWEA',
# 'GWELF', 'M11', 'M12', 'M13', 'M131', 'M132', 'M14', 'M141',
# 'M142', 'M143', 'MCAT'], dtype=object)
def rcv1_info():
    """Print label prevalences and split sizes for the rcv1 targets.

    For each of CCAT, MCAT and GCAT this prints the (1 - p, p) prevalence pair
    of the target column over the first 23149 rows ("train") and the
    remaining rows ("test"). It also prints the train/validation/test sizes
    of the first sample produced by ``Dataset(name="rcv1", target=...)``.
    """
    dataset = fetch_rcv1()
    n_train = 23149

    def _column_mean(rows):
        # Average of the target's column over the given rows.
        return np.average(rows.toarray().flatten())

    targets = []
    for target in ["CCAT", "MCAT", "GCAT"]:
        column = np.where(dataset.target_names == target)[0]
        train_pos = _column_mean(dataset.target[:n_train, column])
        test_pos = _column_mean(dataset.target[n_train:, column])
        sample = Dataset(name="rcv1", target=target)()[0]
        stats = {
            "train": (1.0 - train_pos, train_pos),
            "test": (1.0 - test_pos, test_pos),
            "train_size": len(sample.train),
            "val_size": len(sample.validation),
            "test_size": len(sample.test),
        }
        targets.append((target, stats))

    for label, stats in targets:
        print(f"{label}:")
        for key, value in stats.items():
            if not isinstance(value, tuple):
                print(f"\t{key}: {value}")
                continue
            print(f"\t{key}: {value[0]:.4f}, {value[1]:.4f}")
def imdb_info():
    """Print class prevalences and split sizes for the IMDb reviews dataset.

    Prevalences come from quapy's original train/test split; sizes come from
    the first sample produced by ``Dataset(name="imdb")``.
    """
    train, test = qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test
    train_prev, test_prev = train.prevalence(), test.prevalence()
    sample = Dataset(name="imdb")()[0]
    stats = dict(
        train=(train_prev[0], train_prev[1]),
        test=(test_prev[0], test_prev[1]),
        train_size=len(sample.train),
        val_size=len(sample.validation),
        test_size=len(sample.test),
    )
    print("imdb:")
    for key, value in stats.items():
        if isinstance(value, tuple):
            line = f"\t{key}: {value[0]:.4f}, {value[1]:.4f}"
        else:
            line = f"\t{key}: {value}"
        print(line)
if __name__ == "__main__":
    # When run as a script, print the rcv1 summary and then the imdb summary.
    for _report in (rcv1_info, imdb_info):
        _report()