diff --git a/quacc/data.py b/quacc/data.py
new file mode 100644
index 0000000..7d511b8
--- /dev/null
+++ b/quacc/data.py
@@ -0,0 +1,19 @@
+import numpy as np
+import scipy.sparse as sp
+from quapy.data import LabelledCollection
+from typing import List, Optional
+
+
+class ExtendedCollection(LabelledCollection):
+    """LabelledCollection created from a base collection plus extended data."""
+
+    def __init__(
+        self,
+        b_coll: LabelledCollection,
+        instances: np.ndarray | sp.csr_matrix,
+        labels: np.ndarray,
+        classes: Optional[List] = None,
+    ):
+        # NOTE(review): b_coll is accepted but neither stored nor used here.
+        super().__init__(instances, labels, classes=classes)
+
diff --git a/quacc/estimator.py b/quacc/estimator.py
new file mode 100644
index 0000000..e0f8520
--- /dev/null
+++ b/quacc/estimator.py
@@ -0,0 +1,107 @@
+import numpy as np
+import scipy.sparse as sp
+from quapy.data import LabelledCollection
+from quapy.method.base import BaseQuantifier
+from sklearn.base import BaseEstimator
+from sklearn.model_selection import cross_val_predict
+
+import quacc as qc
+from .data import ExtendedCollection
+
+
+def _check_prevalence_classes(true_classes, estim_classes, estim_prev):
+    """Insert a 0.0 prevalence for each true class missing from estim_classes."""
+    for _cls in true_classes:
+        if _cls not in estim_classes:
+            # NOTE(review): _cls doubles as the insertion index, so this
+            # presumably expects integer class labels in ascending order.
+            estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
+    return estim_prev
+
+
+def _get_ex_class(classes, true_class, pred_class):
+    """Encode a (true, predicted) class pair as true * classes + predicted."""
+    return true_class * classes + pred_class
+
+
+def _extend_instances(instances, pred_proba):
+    """Append pred_proba as extra columns to a csr_matrix or ndarray."""
+    if isinstance(instances, sp.csr_matrix):
+        _pred_proba = sp.csr_matrix(pred_proba)
+        n_x = sp.hstack([instances, _pred_proba])
+    elif isinstance(instances, np.ndarray):
+        n_x = np.concatenate((instances, pred_proba), axis=1)
+    else:
+        raise ValueError("Unsupported matrix format")
+
+    return n_x
+
+
+def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection:
+    """Build the extended collection for base from its predicted probabilities."""
+    n_classes = base.n_classes
+
+    # n_X = [ X | predicted probs. ]
+    n_x = _extend_instances(base.X, pred_proba)
+
+    # n_y = (expected y, predicted y)
+    pred = np.asarray([prob.argmax(axis=0) for prob in pred_proba])
+    n_y = np.asarray(
+        [
+            _get_ex_class(n_classes, true_class, pred_class)
+            for (true_class, pred_class) in zip(base.y, pred)
+        ]
+    )
+
+    # ExtendedCollection takes the base collection as its first argument.
+    return ExtendedCollection(
+        base, n_x, n_y, classes=[*range(0, n_classes * n_classes)]
+    )
+
+
+class AccuracyEstimator:
+    """Pair a classifier with a quantifier fitted on extended collections."""
+
+    def __init__(self, model: BaseEstimator, q_model: BaseQuantifier):
+        self.model = model
+        self.q_model = q_model
+        self.e_train = None
+
+    def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
+        """Extend base, computing pred_proba with self.model when not given."""
+        # `not pred_proba` is ambiguous for multi-element arrays; test for None.
+        if pred_proba is None:
+            pred_proba = self.model.predict_proba(base.X)
+        return _extend_collection(base, pred_proba)
+
+    def fit(self, train: LabelledCollection | ExtendedCollection):
+        """Fit q_model on train, extending it first if it is not extended yet."""
+        # check if model is fit
+        # self.model.fit(*train.Xy)
+        # ExtendedCollection subclasses LabelledCollection, so it must be
+        # tested first or it would be extended a second time.
+        if isinstance(train, ExtendedCollection):
+            self.e_train = train
+        else:
+            pred_prob_train = cross_val_predict(
+                self.model, *train.Xy, method="predict_proba"
+            )
+
+            self.e_train = _extend_collection(train, pred_prob_train)
+
+        self.q_model.fit(self.e_train)
+
+    def estimate(self, instances, ext=False):
+        """Quantify instances, extending them first unless ext is true."""
+        if not ext:
+            pred_prob = self.model.predict_proba(instances)
+            e_inst = _extend_instances(instances, pred_prob)
+        else:
+            e_inst = instances
+
+        estim_prev = self.q_model.quantify(e_inst)
+
+        # The instance matrix has no classes_; use the fitted collection's.
+        return _check_prevalence_classes(
+            self.e_train.classes_, self.q_model.classes_, estim_prev
+        )
diff --git a/quacc/evaluation.py b/quacc/evaluation.py
new file mode 100644
index 0000000..336093e
--- /dev/null
+++ b/quacc/evaluation.py
@@ -0,0 +1,25 @@
+from quapy.method.base import BaseQuantifier
+from quapy.protocol import OnLabelledCollectionProtocol, AbstractStochasticSeededProtocol
+
+from .estimator import AccuracyEstimator, _extend_collection
+
+
+def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
+    """Collect base, true and estimated prevalences for each protocol sample."""
+
+    # ensure that the protocol returns a LabelledCollection for each iteration
+    protocol.collator = OnLabelledCollectionProtocol.get_collator('labelled_collection')
+
+    base_prevs, true_prevs, estim_prevs = [], [], []
+    for sample in protocol():
+        e_sample = estimator.extend(sample)
+        estim_prev = estimator.estimate(e_sample.X, ext=True)
+        base_prevs.append(sample.prevalence())
+        true_prevs.append(e_sample.prevalence())
+        estim_prevs.append(estim_prev)
+
+    return base_prevs, true_prevs, estim_prevs
+
+
+def evaluate():
+    pass
diff --git a/quacc/main.py b/quacc/main.py
index 8fd148d..0d2423b 100644
--- a/quacc/main.py
+++ b/quacc/main.py
@@ -95,9 +95,9 @@ def extend_and_quantify(
         return _test.prevalence(), _estim_prev
 
     if isinstance(test, LabelledCollection):
-        _orig_prev, _true_prev, _estim_prev = quantify_extended(test)
+        _true_prev, _estim_prev = quantify_extended(test)
         _errors = compute_errors(_true_prev, _estim_prev, test.X.shape[0])
-        return ([_orig_prev], [_true_prev], [_estim_prev], [_errors])
+        return ([test.prevalence()], [_true_prev], [_estim_prev], [_errors])
 
     elif isinstance(test, AbstractStochasticSeededProtocol):
         orig_prevs, true_prevs, estim_prevs, errors = [], [], [], []
diff --git a/quacc/quantifier.py b/quacc/quantifier.py
deleted file mode 100644
index 20bc151..0000000
--- a/quacc/quantifier.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import numpy as np
-import scipy.sparse as sp
-from quapy.data import LabelledCollection
-from quapy.method.base import BaseQuantifier
-from sklearn.base import BaseEstimator
-from sklearn.model_selection import cross_val_predict
-
-
-def _get_ex_class(classes, true_class, pred_class):
-    return true_class * classes + pred_class
-
-
-def _extend_collection(coll, pred_prob):
-    n_classes = coll.n_classes
-
-    # n_X = [ X | predicted probs. ]
-    if isinstance(coll.X, sp.csr_matrix):
-        pred_prob_csr = sp.csr_matrix(pred_prob)
-        n_x = sp.hstack([coll.X, pred_prob_csr])
-    elif isinstance(coll.X, np.ndarray):
-        n_x = np.concatenate((coll.X, pred_prob), axis=1)
-    else:
-        raise ValueError("Unsupported matrix format")
-
-    # n_y = (exptected y, predicted y)
-    n_y = []
-    for i, true_class in enumerate(coll.y):
-        pred_class = pred_prob[i].argmax(axis=0)
-        n_y.append(_get_ex_class(n_classes, true_class, pred_class))
-
-    return LabelledCollection(n_x, np.asarray(n_y), [*range(0, n_classes * n_classes)])
-
-
-class AccuracyQuantifier:
-    def __init__(self, model: BaseEstimator, q_model: BaseQuantifier):
-        self.model = model
-        self.q_model = q_model
-
-    def fit(self, train: LabelledCollection):
-        self._train = train
-        self.model.fit(*self._train.Xy)
-        self._pred_prob_train = cross_val_predict(
-            self.model, *self._train.Xy, method="predict_proba"
-        )
-        self._e_train = _extend_collection(self._train, self._pred_prob_train)
-
-        self.q_model.fit(self._e_train)