Review note (text before the first "diff --git" header is ignored by patch tools).

The patch below was re-flowed into a standard unified diff; hunk positions and
sizes are unchanged. The "index" hashes are those of the original commit and do
not reflect the corrections listed here. Added ("+") lines corrected in review:

  * quacc/data.py, ExtendedData.__extend_instances: sp.hstack(..., format="csr")
    so the csr_matrix check in split_by_pred keeps matching the stacked matrix.
  * quacc/data.py, ExtendedData.__split_index_by_pred: argmax over axis=1 and
    class range from shape[1]; pred_proba_ is indexed per row, as in
    ExtendedCollection.__extend_collection (argmax over axis=-1).
  * quacc/data.py, ExtendedLabels.__init__: ncl annotated as int.
  * quacc/data.py, ExtendedCollection.from_lc: parameter renamed to pred_proba,
    the keyword both callers in quacc/method/base.py already pass.
  * quacc/data.py, ExtendedCollection.split_by_pred: class count taken from
    pred_proba.shape[1] instead of len(pred_proba) (the number of samples).
  * quacc/method/base.py, BaseAccuracyEstimator.estimate: abstract signature
    no longer declares ext=False, matching every override in this patch.
  * quacc/method/base.py, __get_confidence: unknown confidence names are
    skipped, so __get_ext returns None instead of concatenating an empty list.

diff --git a/quacc/data.py b/quacc/data.py
index 3103d74..75cd8b9 100644
--- a/quacc/data.py
+++ b/quacc/data.py
@@ -1,11 +1,9 @@
-import math
-from typing import List, Optional
+from typing import List, Tuple
 
 import numpy as np
 import scipy.sparse as sp
 from quapy.data import LabelledCollection
 
-
 # Extended classes
 #
 # 0 ~ True 0
@@ -20,32 +18,54 @@ from quapy.data import LabelledCollection
 # | False 0 | True 1 |
 # |__________|__________|
 #
-class ExClassManager:
-    @staticmethod
-    def get_ex(n_classes: int, true_class: int, pred_class: int) -> int:
-        return true_class * n_classes + pred_class
-
-    @staticmethod
-    def get_pred(n_classes: int, ex_class: int) -> int:
-        return ex_class % n_classes
-
-    @staticmethod
-    def get_true(n_classes: int, ex_class: int) -> int:
-        return ex_class // n_classes
 
 
-class ExtendedCollection(LabelledCollection):
+class ExtendedData:
     def __init__(
         self,
         instances: np.ndarray | sp.csr_matrix,
-        labels: np.ndarray,
-        classes: Optional[List] = None,
+        pred_proba: np.ndarray,
+        ext: np.ndarray = None,
     ):
-        super().__init__(instances, labels, classes=classes)
+        self.b_instances_ = instances
+        self.pred_proba_ = pred_proba
+        self.ext_ = ext
+        self.instances = self.__extend_instances(instances, pred_proba, ext=ext)
 
-    def split_by_pred(self):
-        _ncl = int(math.sqrt(self.n_classes))
-        _indexes = ExtendedCollection._split_index_by_pred(_ncl, self.instances)
+    def __extend_instances(
+        self,
+        instances: np.ndarray | sp.csr_matrix,
+        pred_proba: np.ndarray,
+        ext: np.ndarray = None,
+    ) -> np.ndarray | sp.csr_matrix:
+        to_append = pred_proba
+        if ext is not None:
+            to_append = np.concatenate([ext, pred_proba], axis=1)
+
+        if isinstance(instances, sp.csr_matrix):
+            _to_append = sp.csr_matrix(to_append)
+            n_x = sp.hstack([instances, _to_append], format="csr")
+        elif isinstance(instances, np.ndarray):
+            n_x = np.concatenate((instances, to_append), axis=1)
+        else:
+            raise ValueError("Unsupported matrix format")
+
+        return n_x
+
+    @property
+    def X(self):
+        return self.instances
+
+    def __split_index_by_pred(self) -> List[np.ndarray]:
+        _pred_label = np.argmax(self.pred_proba_, axis=1)
+
+        return [
+            (_pred_label == cl).nonzero()[0]
+            for cl in np.arange(self.pred_proba_.shape[1])
+        ]
+
+    def split_by_pred(self, return_indexes=False):
+        _indexes = self.__split_index_by_pred()
         if isinstance(self.instances, np.ndarray):
             _instances = [
                 self.instances[ind] if ind.shape[0] > 0 else np.asarray([], dtype=int)
@@ -58,93 +78,95 @@ class ExtendedCollection(LabelledCollection):
                 else sp.csr_matrix(np.empty((0, 0), dtype=int))
                 for ind in _indexes
             ]
-        _labels = [
-            np.asarray(
-                [
-                    ExClassManager.get_true(_ncl, lbl)
-                    for lbl in (self.labels[ind] if len(ind) > 0 else [])
-                ],
-                dtype=int,
-            )
-            for ind in _indexes
-        ]
-        return [
-            ExtendedCollection(inst, lbl, classes=range(0, _ncl))
-            for (inst, lbl) in zip(_instances, _labels)
-        ]
 
-    @classmethod
-    def split_inst_by_pred(
-        cls, n_classes: int, instances: np.ndarray | sp.csr_matrix
-    ) -> (List[np.ndarray | sp.csr_matrix], List[float]):
-        _indexes = cls._split_index_by_pred(n_classes, instances)
-        if isinstance(instances, np.ndarray):
-            _instances = [
-                instances[ind] if ind.shape[0] > 0 else np.asarray([], dtype=int)
-                for ind in _indexes
-            ]
-        elif isinstance(instances, sp.csr_matrix):
-            _instances = [
-                instances[ind]
-                if ind.shape[0] > 0
-                else sp.csr_matrix(np.empty((0, 0), dtype=int))
-                for ind in _indexes
-            ]
-        norms = [inst.shape[0] / instances.shape[0] for inst in _instances]
-        return _instances, norms
+        if return_indexes:
+            return _instances, _indexes
 
-    @classmethod
-    def _split_index_by_pred(
-        cls, n_classes: int, instances: np.ndarray | sp.csr_matrix
-    ) -> List[np.ndarray]:
-        if isinstance(instances, np.ndarray):
-            _pred_label = [np.argmax(inst[-n_classes:], axis=0) for inst in instances]
-        elif isinstance(instances, sp.csr_matrix):
-            _pred_label = [
-                np.argmax(inst[:, -n_classes:].toarray().flatten(), axis=0)
-                for inst in instances
-            ]
-        else:
-            raise ValueError("Unsupported matrix format")
+        return _instances
 
-        return [
-            np.asarray([j for (j, x) in enumerate(_pred_label) if x == i], dtype=int)
-            for i in range(0, n_classes)
-        ]
+    def __len__(self):
+        return self.instances.shape[0]
 
-    @classmethod
-    def extend_instances(
-        cls, instances: np.ndarray | sp.csr_matrix, pred_proba: np.ndarray
-    ) -> np.ndarray | sp.csr_matrix:
-        if isinstance(instances, sp.csr_matrix):
-            _pred_proba = sp.csr_matrix(pred_proba)
-            n_x = sp.hstack([instances, _pred_proba])
-        elif isinstance(instances, np.ndarray):
-            n_x = np.concatenate((instances, pred_proba), axis=1)
-        else:
-            raise ValueError("Unsupported matrix format")
 
-        return n_x
+class ExtendedLabels:
+    def __init__(self, true: np.ndarray, pred: np.ndarray, ncl: int):
+        self.true = true
+        self.pred = pred
+        self.ncl = ncl
 
-    @classmethod
-    def extend_collection(
-        cls,
-        base: LabelledCollection,
-        pred_proba: np.ndarray,
+    @property
+    def y(self):
+        return self.true * self.ncl + self.pred
+
+    def __getitem__(self, idx):
+        return ExtendedLabels(self.true[idx], self.pred[idx], self.ncl)
+
+
+class ExtendedCollection(LabelledCollection):
+    def __init__(
+        self,
+        instances: np.ndarray | sp.csr_matrix,
+        labels: np.ndarray,
+        pred_proba: np.ndarray = None,
+        ext: np.ndarray = None,
     ):
-        n_classes = base.n_classes
+        e_data, e_labels, _classes = self.__extend_collection(
+            instances=instances,
+            labels=labels,
+            pred_proba=pred_proba,
+            ext=ext,
+        )
+        self.e_data_ = e_data
+        self.e_labels_ = e_labels
+        super().__init__(e_data.X, e_labels.y, classes=_classes)
 
+    @classmethod
+    def from_lc(
+        cls,
+        lc: LabelledCollection,
+        pred_proba: np.ndarray,
+        ext: np.ndarray = None,
+    ):
+        return ExtendedCollection(lc.X, lc.y, pred_proba=pred_proba, ext=ext)
+
+    @property
+    def pred_proba(self):
+        return self.e_data_.pred_proba_
+
+    @property
+    def ext(self):
+        return self.e_data_.ext_
+
+    @property
+    def eX(self):
+        return self.e_data_
+
+    @property
+    def ey(self):
+        return self.e_labels_
+
+    def split_by_pred(self):
+        _ncl = self.pred_proba.shape[1]
+        _instances, _indexes = self.e_data_.split_by_pred(return_indexes=True)
+        _labels = [self.ey[ind] for ind in _indexes]
+        return [
+            LabelledCollection(inst, lbl.true, classes=range(0, _ncl))
+            for inst, lbl in zip(_instances, _labels)
+        ]
+
+    def __extend_collection(
+        self,
+        instances: sp.csr_matrix | np.ndarray,
+        labels: np.ndarray,
+        pred_proba: np.ndarray,
+        ext: np.ndarray = None,
+    ) -> Tuple[ExtendedData, ExtendedLabels, np.ndarray]:
+        n_classes = np.unique(labels).shape[0]
         # n_X = [ X | predicted probs. ]
-        n_x = cls.extend_instances(base.X, pred_proba)
+        e_instances = ExtendedData(instances, pred_proba, ext=ext)
 
         # n_y = (exptected y, predicted y)
-        pred_proba = pred_proba[:, -n_classes:]
         preds = np.argmax(pred_proba, axis=-1)
-        n_y = np.asarray(
-            [
-                ExClassManager.get_ex(n_classes, true_class, pred_class)
-                for (true_class, pred_class) in zip(base.y, preds)
-            ]
-        )
+        e_labels = ExtendedLabels(labels, preds, n_classes)
 
-        return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])
+        return e_instances, e_labels, np.arange(n_classes**2)
diff --git a/quacc/evaluation/__init__.py b/quacc/evaluation/__init__.py
index aa393ad..5ca7300 100644
--- a/quacc/evaluation/__init__.py
+++ b/quacc/evaluation/__init__.py
@@ -22,7 +22,7 @@ def evaluate(
     estim_prevs, true_prevs = [], []
     for sample in protocol():
         e_sample = estimator.extend(sample)
-        estim_prev = estimator.estimate(e_sample.X, ext=True)
+        estim_prev = estimator.estimate(e_sample.eX)
         estim_prevs.append(estim_prev)
         true_prevs.append(e_sample.prevalence())
 
diff --git a/quacc/method/base.py b/quacc/method/base.py
index 89ed701..36411b7 100644
--- a/quacc/method/base.py
+++ b/quacc/method/base.py
@@ -1,15 +1,14 @@
-import math
 from abc import abstractmethod
 from copy import deepcopy
 from typing import List
 
 import numpy as np
+import scipy.sparse as sp
 from quapy.data import LabelledCollection
 from quapy.method.aggregative import BaseQuantifier
-from scipy.sparse import csr_matrix
 from sklearn.base import BaseEstimator
 
-from quacc.data import ExtendedCollection
+from quacc.data import ExtendedCollection, ExtendedData
 
 
 class BaseAccuracyEstimator(BaseQuantifier):
@@ -17,11 +16,9 @@ class BaseAccuracyEstimator(BaseQuantifier):
         self,
         classifier: BaseEstimator,
         quantifier: BaseQuantifier,
-        confidence=None,
     ):
         self.__check_classifier(classifier)
         self.quantifier = quantifier
-        self.confidence = confidence
 
     def __check_classifier(self, classifier):
         if not hasattr(classifier, "predict_proba"):
@@ -30,6 +27,45 @@ class BaseAccuracyEstimator(BaseQuantifier):
             )
         self.classifier = classifier
 
+    def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
+        if pred_proba is None:
+            pred_proba = self.classifier.predict_proba(coll.X)
+
+        return ExtendedCollection.from_lc(coll, pred_proba=pred_proba)
+
+    def _extend_instances(self, instances: np.ndarray | sp.csr_matrix, pred_proba=None):
+        if pred_proba is None:
+            pred_proba = self.classifier.predict_proba(instances)
+
+        return ExtendedData(instances, pred_proba=pred_proba)
+
+    @abstractmethod
+    def fit(self, train: LabelledCollection | ExtendedCollection):
+        ...
+
+    @abstractmethod
+    def estimate(self, instances) -> np.ndarray:
+        ...
+
+
+class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator):
+    def __init__(
+        self,
+        classifier: BaseEstimator,
+        quantifier: BaseQuantifier,
+        confidence=None,
+    ):
+        super().__init__(classifier, quantifier)
+        self.__check_confidence(confidence)
+
+    def __check_confidence(self, confidence):
+        if isinstance(confidence, str):
+            self.confidence = [confidence]
+        elif isinstance(confidence, list):
+            self.confidence = confidence
+        else:
+            self.confidence = None
+
     def __get_confidence(self):
         def max_conf(probas):
             _mc = np.max(probas, axis=-1)
@@ -42,47 +78,49 @@ class BaseAccuracyEstimator(BaseQuantifier):
             return _ent
 
         if self.confidence is None:
-            return None
+            return []
 
         __confs = {
             "max_conf": max_conf,
             "entropy": entropy,
         }
-        return __confs.get(self.confidence, None)
+        return [__confs[c] for c in self.confidence if c in __confs]
 
-    def __get_ext(self, pred_proba):
-        _ext = pred_proba
-        _f_conf = self.__get_confidence()
-        if _f_conf is not None:
-            _confs = _f_conf(pred_proba).reshape((len(pred_proba), 1))
-            _ext = np.concatenate((_confs, pred_proba), axis=1)
+    def __get_ext(self, pred_proba: np.ndarray) -> np.ndarray:
+        __confidence = self.__get_confidence()
 
-        return _ext
+        if __confidence is None or len(__confidence) == 0:
+            return None
+
+        return np.concatenate(
+            [
+                _f_conf(pred_proba).reshape((len(pred_proba), 1))
+                for _f_conf in __confidence
+                if _f_conf is not None
+            ],
+            axis=1,
+        )
 
     def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
         if pred_proba is None:
             pred_proba = self.classifier.predict_proba(coll.X)
 
         _ext = self.__get_ext(pred_proba)
-        return ExtendedCollection.extend_collection(coll, pred_proba=_ext)
+        return ExtendedCollection.from_lc(coll, pred_proba=pred_proba, ext=_ext)
 
-    def _extend_instances(self, instances: np.ndarray | csr_matrix, pred_proba=None):
+    def _extend_instances(
+        self,
+        instances: np.ndarray | sp.csr_matrix,
+        pred_proba=None,
+    ) -> ExtendedData:
         if pred_proba is None:
             pred_proba = self.classifier.predict_proba(instances)
 
         _ext = self.__get_ext(pred_proba)
-        return ExtendedCollection.extend_instances(instances, _ext)
-
-    @abstractmethod
-    def fit(self, train: LabelledCollection | ExtendedCollection):
-        ...
-
-    @abstractmethod
-    def estimate(self, instances, ext=False) -> np.ndarray:
-        ...
+        return ExtendedData(instances, pred_proba=pred_proba, ext=_ext)
 
 
-class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
+class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
     def __init__(
         self,
         classifier: BaseEstimator,
@@ -103,10 +141,14 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
 
         return self
 
-    def estimate(self, instances, ext=False) -> np.ndarray:
-        e_inst = instances if ext else self._extend_instances(instances)
+    def estimate(
+        self, instances: ExtendedData | np.ndarray | sp.csr_matrix
+    ) -> np.ndarray:
+        e_inst = instances
+        if not isinstance(e_inst, ExtendedData):
+            e_inst = self._extend_instances(instances)
 
-        estim_prev = self.quantifier.quantify(e_inst)
+        estim_prev = self.quantifier.quantify(e_inst.X)
         return self._check_prevalence_classes(estim_prev, self.quantifier.classes_)
 
     def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray:
@@ -117,7 +159,7 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
         return estim_prev
 
 
-class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
+class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
     def __init__(
         self,
         classifier: BaseEstimator,
@@ -130,28 +172,30 @@ class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
             confidence=confidence,
         )
         self.quantifiers = []
-        self.e_trains = []
 
     def fit(self, train: LabelledCollection | ExtendedCollection):
         self.e_train = self.extend(train)
 
         self.n_classes = self.e_train.n_classes
-        self.e_trains = self.e_train.split_by_pred()
+        e_trains = self.e_train.split_by_pred()
 
         self.quantifiers = []
-        for train in self.e_trains:
+        for train in e_trains:
             quant = deepcopy(self.quantifier)
             quant.fit(train)
             self.quantifiers.append(quant)
 
         return self
 
-    def estimate(self, instances, ext=False):
-        # TODO: test
-        e_inst = instances if ext else self._extend_instances(instances)
+    def estimate(
+        self, instances: ExtendedData | np.ndarray | sp.csr_matrix
+    ) -> np.ndarray:
+        e_inst = instances
+        if not isinstance(e_inst, ExtendedData):
+            e_inst = self._extend_instances(instances)
 
-        _ncl = int(math.sqrt(self.n_classes))
-        s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
+        s_inst = e_inst.split_by_pred()
+        norms = [s_i.shape[0] / len(e_inst) for s_i in s_inst]
         estim_prevs = self._quantify_helper(s_inst, norms)
 
         estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten()
@@ -159,7 +203,7 @@ class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
 
     def _quantify_helper(
         self,
-        s_inst: List[np.ndarray | csr_matrix],
+        s_inst: List[np.ndarray | sp.csr_matrix],
         norms: List[float],
     ):
         estim_prevs = []
diff --git a/quacc/method/model_selection.py b/quacc/method/model_selection.py
index 4e4df34..f0262c1 100644
--- a/quacc/method/model_selection.py
+++ b/quacc/method/model_selection.py
@@ -2,8 +2,8 @@ import itertools
 from copy import deepcopy
 from time import time
 from typing import Callable, Union
-import numpy as np
 
+import numpy as np
 import quapy as qp
 from quapy.data import LabelledCollection
 from quapy.model_selection import GridSearchQ
@@ -12,7 +12,7 @@ from sklearn.base import BaseEstimator
 
 import quacc as qc
 import quacc.error
-from quacc.data import ExtendedCollection
+from quacc.data import ExtendedCollection, ExtendedData
 from quacc.evaluation import evaluate
 from quacc.logger import SubLogger
 from quacc.method.base import (
@@ -182,7 +182,7 @@ class GridSearchAE(BaseAccuracyEstimator):
         assert hasattr(self, "best_model_"), "quantify called before fit"
         return self.best_model().extend(coll, pred_proba=pred_proba)
 
-    def estimate(self, instances, ext=False):
+    def estimate(self, instances):
         """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
 
         :param instances: sample contanining the instances
@@ -191,7 +191,7 @@ class GridSearchAE(BaseAccuracyEstimator):
         """
 
         assert hasattr(self, "best_model_"), "estimate called before fit"
-        return self.best_model().estimate(instances, ext=ext)
+        return self.best_model().estimate(instances)
 
     def set_params(self, **parameters):
         """Sets the hyper-parameters to explore.
@@ -220,7 +220,6 @@ class GridSearchAE(BaseAccuracyEstimator):
         raise ValueError("best_model called before fit")
 
 
-
 class MCAEgsq(MultiClassAccuracyEstimator):
     def __init__(
         self,
@@ -257,10 +256,15 @@ class MCAEgsq(MultiClassAccuracyEstimator):
 
         return self
 
-    def estimate(self, instances, ext=False) -> np.ndarray:
-        e_inst = instances if ext else self._extend_instances(instances)
-        estim_prev = self.quantifier.quantify(e_inst)
-        return self._check_prevalence_classes(estim_prev, self.quantifier.best_model().classes_)
+    def estimate(self, instances) -> np.ndarray:
+        e_inst = instances
+        if not isinstance(e_inst, ExtendedData):
+            e_inst = self._extend_instances(instances)
+
+        estim_prev = self.quantifier.quantify(e_inst.X)
+        return self._check_prevalence_classes(
+            estim_prev, self.quantifier.best_model().classes_
+        )
 
 
 class BQAEgsq(BinaryQuantifierAccuracyEstimator):
diff --git a/tests/test_data.py b/tests/test_data.py
index 53c6d30..c5a383f 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -1,48 +1,8 @@
-import pytest
-from quacc.data import ExClassManager as ECM, ExtendedCollection
 import numpy as np
+import pytest
 import scipy.sparse as sp
 
-
-class TestExClassManager:
-    @pytest.mark.parametrize(
-        "true_class,pred_class,result",
-        [
-            (0, 0, 0),
-            (0, 1, 1),
-            (1, 0, 2),
-            (1, 1, 3),
-        ],
-    )
-    def test_get_ex(self, true_class, pred_class, result):
-        ncl = 2
-        assert ECM.get_ex(ncl, true_class, pred_class) == result
-
-    @pytest.mark.parametrize(
-        "ex_class,result",
-        [
-            (0, 0),
-            (1, 1),
-            (2, 0),
-            (3, 1),
-        ],
-    )
-    def test_get_pred(self, ex_class, result):
-        ncl = 2
-        assert ECM.get_pred(ncl, ex_class) == result
-
-    @pytest.mark.parametrize(
-        "ex_class,result",
-        [
-            (0, 0),
-            (1, 0),
-            (2, 1),
-            (3, 1),
-        ],
-    )
-    def test_get_true(self, ex_class, result):
-        ncl = 2
-        assert ECM.get_true(ncl, ex_class) == result
+from quacc.data import ExtendedCollection
 
 
 class TestExtendedCollection: