import math
import os
import pickle
import tarfile
from typing import List, Tuple

import numpy as np
import quapy as qp
from quapy.data.base import LabelledCollection
from sklearn.datasets import fetch_rcv1
from sklearn.utils import Bunch

from quacc import utils
from quacc.environment import env

TRAIN_VAL_PROP = 0.5


def fetch_cifar10() -> Bunch:
    URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    data_home = utils.get_quacc_home()
    unzipped_path = data_home / "cifar-10-batches-py"
    if not unzipped_path.exists():
        downloaded_path = data_home / URL.split("/")[-1]
        utils.download_file(URL, downloaded_path)
        with tarfile.open(downloaded_path) as f:
            f.extractall(data_home)
        os.remove(downloaded_path)

    datas = []
    data_names = sorted([f for f in os.listdir(unzipped_path) if f.startswith("data")])
    for f in data_names:
        with open(unzipped_path / f, "rb") as file:
            datas.append(pickle.load(file, encoding="bytes"))

    tests = []
    test_names = sorted([f for f in os.listdir(unzipped_path) if f.startswith("test")])
    for f in test_names:
        with open(unzipped_path / f, "rb") as file:
            tests.append(pickle.load(file, encoding="bytes"))

    with open(unzipped_path / "batches.meta", "rb") as file:
        meta = pickle.load(file, encoding="bytes")

    return Bunch(
        train=Bunch(
            data=np.concatenate([d[b"data"] for d in datas], axis=0),
            labels=np.concatenate([d[b"labels"] for d in datas]),
        ),
        test=Bunch(
            data=np.concatenate([d[b"data"] for d in tests], axis=0),
            labels=np.concatenate([d[b"labels"] for d in tests]),
        ),
        label_names=[cs.decode("utf-8") for cs in meta[b"label_names"]],
    )


def fetch_cifar100() -> Bunch:
    URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    data_home = utils.get_quacc_home()
    unzipped_path = data_home / "cifar-100-python"
    if not unzipped_path.exists():
        downloaded_path = data_home / URL.split("/")[-1]
        utils.download_file(URL, downloaded_path)
        with tarfile.open(downloaded_path) as f:
            f.extractall(data_home)
        os.remove(downloaded_path)

    with open(unzipped_path / "train", "rb") as file:
        train_d = pickle.load(file, encoding="bytes")
    with open(unzipped_path / "test", "rb") as file:
        test_d = pickle.load(file, encoding="bytes")
    with open(unzipped_path / "meta", "rb") as file:
        meta_d = pickle.load(file, encoding="bytes")

    train_bunch = Bunch(
        data=train_d[b"data"],
        fine_labels=np.array(train_d[b"fine_labels"]),
        coarse_labels=np.array(train_d[b"coarse_labels"]),
    )
    test_bunch = Bunch(
        data=test_d[b"data"],
        fine_labels=np.array(test_d[b"fine_labels"]),
        coarse_labels=np.array(test_d[b"coarse_labels"]),
    )
    return Bunch(
        train=train_bunch,
        test=test_bunch,
        fine_label_names=meta_d[b"fine_label_names"],
        coarse_label_names=meta_d[b"coarse_label_names"],
    )


class DatasetSample:
    def __init__(
        self,
        train: LabelledCollection,
        validation: LabelledCollection,
        test: LabelledCollection,
    ):
        self.train = train
        self.validation = validation
        self.test = test

    @property
    def train_prev(self):
        return self.train.prevalence()

    @property
    def validation_prev(self):
        return self.validation.prevalence()

    @property
    def prevs(self):
        return {"train": self.train_prev, "validation": self.validation_prev}


class DatasetProvider:
    def __spambase(self, **kwargs):
        return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test

    # try min_df=5
    def __imdb(self, **kwargs):
        return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test

    def __rcv1(self, target, **kwargs):
        n_train = 23149
        available_targets = ["CCAT", "GCAT", "MCAT"]

        if target is None or target not in available_targets:
            raise ValueError(f"Invalid target {target}")
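
        # scikit-learn returns RCV1 in the chronological (LYRL2004) ordering,
        # in which the first 23,149 documents form the standard training
        # split and the remaining documents the test split.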
{target}") dataset = fetch_rcv1() target_index = np.where(dataset.target_names == target)[0] all_train_d = dataset.data[:n_train, :] test_d = dataset.data[n_train:, :] labels = dataset.target[:, target_index].toarray().flatten() all_train_l, test_l = labels[:n_train], labels[n_train:] all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1]) test = LabelledCollection(test_d, test_l, classes=[0, 1]) return all_train, test def __cifar10(self, target, **kwargs): dataset = fetch_cifar10() available_targets: list = dataset.label_names if target is None or self._target not in available_targets: raise ValueError(f"Invalid target {target}") target_index = available_targets.index(target) all_train_d = dataset.train.data all_train_l = (dataset.train.labels == target_index).astype(int) test_d = dataset.test.data test_l = (dataset.test.labels == target_index).astype(int) all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1]) test = LabelledCollection(test_d, test_l, classes=[0, 1]) return all_train, test def __cifar100(self, target, **kwargs): dataset = fetch_cifar100() available_targets: list = dataset.coarse_label_names if target is None or target not in available_targets: raise ValueError(f"Invalid target {target}") target_index = available_targets.index(target) all_train_d = dataset.train.data all_train_l = (dataset.train.coarse_labels == target_index).astype(int) test_d = dataset.test.data test_l = (dataset.test.coarse_labels == target_index).astype(int) all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1]) test = LabelledCollection(test_d, test_l, classes=[0, 1]) return all_train, test def __twitter_gasp(self, **kwargs): return qp.datasets.fetch_twitter("gasp", min_df=3).train_test def alltrain_test( self, name: str, target: str | None ) -> Tuple[LabelledCollection, LabelledCollection]: all_train, test = { "spambase": self.__spambase, "imdb": self.__imdb, "rcv1": self.__rcv1, "cifar10": self.__cifar10, "cifar100": self.__cifar100, "twitter_gasp": self.__twitter_gasp, }[name](target=target) return all_train, test class Dataset(DatasetProvider): def __init__(self, name, n_prevalences=9, prevs=None, target=None): self._name = name self._target = target self.all_train, self.test = self.alltrain_test(self._name, self._target) self.__resample_all_train() self.prevs = None self._n_prevs = n_prevalences self.__check_prevs(prevs) self.prevs = self.__build_prevs() def __resample_all_train(self): tr_counts, tr_ncl = self.all_train.counts(), self.all_train.n_classes _resample_prevs = np.full((tr_ncl,), fill_value=1.0 / tr_ncl) self.all_train = self.all_train.sampling( np.min(tr_counts) * tr_ncl, *_resample_prevs.tolist(), random_state=env._R_SEED, ) def __check_prevs(self, prevs): try: iter(prevs) except TypeError: return if prevs is None or len(prevs) == 0: return def is_float_iterable(obj): try: it = iter(obj) return all([isinstance(o, float) for o in it]) except TypeError: return False if not all([is_float_iterable(p) for p in prevs]): return if not all([len(p) == self.all_train.n_classes for p in prevs]): return if not all([sum(p) == 1.0 for p in prevs]): return self.prevs = np.unique(prevs, axis=0) def __build_prevs(self): if self.prevs is not None: return self.prevs dim = self.all_train.n_classes lspace = np.linspace(0.0, 1.0, num=self._n_prevs + 1, endpoint=False)[1:] mesh = np.array(np.meshgrid(*[lspace for _ in range(dim)])).T.reshape(-1, dim) mesh = mesh[np.where(mesh.sum(axis=1) == 1.0)] return np.around(np.unique(mesh, axis=0), decimals=4) def 

    def __build_sample(
        self,
        p: np.ndarray,
        at_size: int,
    ):
        all_train_sampled = self.all_train.sampling(
            at_size, *(p[:-1]), random_state=env._R_SEED
        )
        train, validation = all_train_sampled.split_stratified(
            train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED
        )
        return DatasetSample(train, validation, self.test)

    def get(self) -> List[DatasetSample]:
        # the largest sample size that every requested prevalence can support
        at_size = min(
            math.floor(len(self.all_train) * (1.0 / self.all_train.n_classes) / p)
            for _prev in self.prevs
            for p in _prev
        )
        return [self.__build_sample(p, at_size) for p in self.prevs]

    def __call__(self):
        return self.get()

    @property
    def name(self):
        match (self._name, self._n_prevs):
            case (("rcv1" | "cifar10" | "cifar100"), 9):
                return f"{self._name}_{self._target}"
            case (("rcv1" | "cifar10" | "cifar100"), _):
                return f"{self._name}_{self._target}_{self._n_prevs}prevs"
            case (_, 9):
                return f"{self._name}"
            case (_, _):
                return f"{self._name}_{self._n_prevs}prevs"

    @property
    def nprevs(self):
        return self.prevs.shape[0]


# >>> fetch_rcv1().target_names
# array(['C11', 'C12', 'C13', 'C14', 'C15', 'C151', 'C1511', 'C152', 'C16',
#        'C17', 'C171', 'C172', 'C173', 'C174', 'C18', 'C181', 'C182',
#        'C183', 'C21', 'C22', 'C23', 'C24', 'C31', 'C311', 'C312', 'C313',
#        'C32', 'C33', 'C331', 'C34', 'C41', 'C411', 'C42', 'CCAT', 'E11',
#        'E12', 'E121', 'E13', 'E131', 'E132', 'E14', 'E141', 'E142',
#        'E143', 'E21', 'E211', 'E212', 'E31', 'E311', 'E312', 'E313',
#        'E41', 'E411', 'E51', 'E511', 'E512', 'E513', 'E61', 'E71', 'ECAT',
#        'G15', 'G151', 'G152', 'G153', 'G154', 'G155', 'G156', 'G157',
#        'G158', 'G159', 'GCAT', 'GCRIM', 'GDEF', 'GDIP', 'GDIS', 'GENT',
#        'GENV', 'GFAS', 'GHEA', 'GJOB', 'GMIL', 'GOBIT', 'GODD', 'GPOL',
#        'GPRO', 'GREL', 'GSCI', 'GSPO', 'GTOUR', 'GVIO', 'GVOTE', 'GWEA',
#        'GWELF', 'M11', 'M12', 'M13', 'M131', 'M132', 'M14', 'M141',
#        'M142', 'M143', 'MCAT'], dtype=object)


def rcv1_info():
    # print the binary (negative, positive) prevalence of each of the 103
    # RCV1 targets in the standard train/test split, sorted by training
    # prevalence
    dataset = fetch_rcv1()
    n_train = 23149

    targets = []
    for target in range(103):
        train_t_prev = np.average(dataset.target[:n_train, target].toarray().flatten())
        test_t_prev = np.average(dataset.target[n_train:, target].toarray().flatten())
        targets.append(
            (
                dataset.target_names[target],
                {
                    "train": (1.0 - train_t_prev, train_t_prev),
                    "test": (1.0 - test_t_prev, test_t_prev),
                },
            )
        )

    targets.sort(key=lambda t: t[1]["train"][1])

    for n, d in targets:
        print(f"{n}:")
        for k, (fp, tp) in d.items():
            print(f"\t{k}: {fp:.4f}, {tp:.4f}")


if __name__ == "__main__":
    fetch_cifar100()
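
# Example usage (a minimal sketch; the dataset name "rcv1" and target "CCAT"
# are among those registered in DatasetProvider.alltrain_test above):
#
#   dataset = Dataset("rcv1", n_prevalences=9, target="CCAT")
#   for sample in dataset.get():
#       print(sample.prevs, len(sample.train), len(sample.validation))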