QuAcc/quacc/method/base.py

281 lines
8.6 KiB
Python
Raw Normal View History

2023-11-08 17:26:44 +01:00
from abc import abstractmethod
from copy import deepcopy
from typing import List
import numpy as np
2023-11-10 01:24:18 +01:00
import scipy.sparse as sp
2023-11-08 17:26:44 +01:00
from quapy.data import LabelledCollection
from quapy.method.aggregative import BaseQuantifier
from sklearn.base import BaseEstimator
import quacc.method.confidence as conf
2023-12-21 16:47:35 +01:00
from quacc.data import (
ExtBinPrev,
ExtendedCollection,
ExtendedData,
ExtendedPrev,
ExtensionPolicy,
ExtMulPrev,
)
2023-11-08 17:26:44 +01:00
class BaseAccuracyEstimator(BaseQuantifier):
    """Abstract base pairing a probabilistic classifier with a quantifier.

    The classifier supplies posterior probabilities that are attached to the
    data through ``ExtendedCollection`` / ``ExtendedData``; concrete
    subclasses implement ``fit`` and ``estimate``.
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseQuantifier,
        dense=False,
    ):
        """Validate and store the classifier, the quantifier and the policy.

        Args:
            classifier: estimator that must expose ``predict_proba``.
            quantifier: quantifier stored as ``self.quantifier``.
            dense: flag forwarded to ``ExtensionPolicy``.

        Raises:
            ValueError: if ``classifier`` has no ``predict_proba`` attribute.
        """
        self.__check_classifier(classifier)
        self.quantifier = quantifier
        self.extpol = ExtensionPolicy(dense=dense)

    def __check_classifier(self, classifier):
        """Store ``classifier`` once it is known to expose ``predict_proba``."""
        if hasattr(classifier, "predict_proba"):
            self.classifier = classifier
            return
        raise ValueError(
            f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities."
        )

    def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
        """Build an ``ExtendedCollection`` from ``coll``.

        Args:
            coll: labelled collection to extend.
            pred_proba: precomputed posteriors; when ``None`` they are
                obtained from ``self.classifier.predict_proba(coll.X)``.

        Returns:
            The result of ``ExtendedCollection.from_lc`` where the posteriors
            are passed both as ``pred_proba`` and as ``ext``.
        """
        posteriors = (
            self.classifier.predict_proba(coll.X) if pred_proba is None else pred_proba
        )
        return ExtendedCollection.from_lc(
            coll, pred_proba=posteriors, ext=posteriors, extpol=self.extpol
        )

    def _extend_instances(self, instances: np.ndarray | sp.csr_matrix):
        """Wrap raw ``instances`` in ``ExtendedData`` with their posteriors."""
        posteriors = self.classifier.predict_proba(instances)
        return ExtendedData(instances, pred_proba=posteriors, extpol=self.extpol)

    @abstractmethod
    def fit(self, train: LabelledCollection | ExtendedCollection):
        """Fit the estimator on ``train``; implemented by subclasses."""

    @abstractmethod
    def estimate(self, instances, ext=False) -> ExtendedPrev:
        """Estimate prevalences for ``instances``; implemented by subclasses."""

    @property
    def dense(self):
        """The ``dense`` flag of the current extension policy."""
        return self.extpol.dense
2023-11-10 01:24:18 +01:00
class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator):
    """Accuracy estimator whose extension can include confidence metrics.

    The columns produced by the configured confidence metrics are placed in
    front of the classifier posteriors to form the extension passed to
    ``ExtendedCollection`` / ``ExtendedData``.
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseQuantifier,
        confidence=None,
    ):
        """Store classifier and quantifier and normalise ``confidence``.

        Args:
            classifier: estimator exposing ``predict_proba``.
            quantifier: underlying quantifier.
            confidence: a metric name, a list of metric names, or any other
                value (e.g. ``None``), which is stored as ``None``.
        """
        super().__init__(
            classifier=classifier,
            quantifier=quantifier,
        )
        self.__check_confidence(confidence)
        self.calibrator = None

    def __check_confidence(self, confidence):
        """Normalise ``confidence`` into a list of names or ``None``."""
        if isinstance(confidence, str):
            self.confidence = [confidence]
        elif isinstance(confidence, list):
            self.confidence = confidence
        else:
            self.confidence = None

    def _fit_confidence(self, X, y, probas):
        """Resolve the confidence metrics and fit each one on ``(X, y, probas)``.

        Sets ``self.confidence_metrics`` to the value returned by
        ``conf.get_metrics(self.confidence)``; nothing is fit when that value
        is ``None``.
        """
        self.confidence_metrics = conf.get_metrics(self.confidence)
        if self.confidence_metrics is None:
            return

        for metric in self.confidence_metrics:
            metric.fit(X, y, probas)

    def _get_pred_ext(self, pred_proba: np.ndarray):
        """Return the prediction part of the extension (the posteriors as-is)."""
        return pred_proba

    def __get_ext(
        self, X: np.ndarray | sp.csr_matrix, pred_proba: np.ndarray
    ) -> np.ndarray:
        """Assemble the extension columns for ``X``.

        Returns ``pred_proba`` unchanged when no confidence metric is
        available; otherwise the metric outputs are concatenated column-wise
        and placed before the output of ``_get_pred_ext``.

        Note: reads ``self.confidence_metrics``, which is only created by
        ``_fit_confidence``.
        """
        if self.confidence_metrics is None or len(self.confidence_metrics) == 0:
            return pred_proba

        _conf_ext = np.concatenate(
            [m.conf(X, pred_proba) for m in self.confidence_metrics],
            axis=1,
        )
        _pred_ext = self._get_pred_ext(pred_proba)
        return np.concatenate([_conf_ext, _pred_ext], axis=1)

    def extend(
        self, coll: LabelledCollection, pred_proba=None, prefit=False
    ) -> ExtendedCollection:
        """Build an ``ExtendedCollection`` from ``coll``.

        Args:
            coll: labelled collection to extend.
            pred_proba: precomputed posteriors; when ``None`` they are
                obtained from ``self.classifier.predict_proba(coll.X)``.
            prefit: when true, the confidence metrics are (re)fit on ``coll``
                before the extension is computed.

        Returns:
            The result of ``ExtendedCollection.from_lc`` for ``coll``.

        Raises:
            AttributeError: if ``prefit`` is false and the confidence metrics
                have never been fit.
        """
        if pred_proba is None:
            pred_proba = self.classifier.predict_proba(coll.X)

        if prefit:
            self._fit_confidence(coll.X, coll.y, pred_proba)
        elif not hasattr(self, "confidence_metrics"):
            # The two literals are joined by implicit concatenation, so the
            # separating space has to be written explicitly.
            raise AttributeError(
                "Confidence metrics are not fit and cannot be computed. "
                "Consider setting prefit to True."
            )

        _ext = self.__get_ext(coll.X, pred_proba)
        return ExtendedCollection.from_lc(
            coll, pred_proba=pred_proba, ext=_ext, extpol=self.extpol
        )

    def _extend_instances(
        self,
        instances: np.ndarray | sp.csr_matrix,
    ) -> ExtendedData:
        """Wrap raw ``instances`` in ``ExtendedData`` with posteriors and extension."""
        pred_proba = self.classifier.predict_proba(instances)
        _ext = self.__get_ext(instances, pred_proba)
        return ExtendedData(
            instances, pred_proba=pred_proba, ext=_ext, extpol=self.extpol
        )
2023-11-08 17:26:44 +01:00
2023-11-10 01:24:18 +01:00
class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
    """Estimator that fits one quantifier on the whole extended training set."""

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseQuantifier,
        confidence: str = None,
        collapse_false=False,
        group_false=False,
        dense=False,
    ):
        """Configure the estimator and its extension policy.

        Args:
            classifier: estimator exposing ``predict_proba``.
            quantifier: quantifier fit on the extended training collection.
            confidence: confidence metric specification passed to the parent.
            collapse_false: flag forwarded to ``ExtensionPolicy``.
            group_false: flag forwarded to ``ExtensionPolicy``.
            dense: flag forwarded to ``ExtensionPolicy``.
        """
        super().__init__(
            classifier=classifier,
            quantifier=quantifier,
            confidence=confidence,
        )
        # Replaces the policy installed by the base constructor.
        policy_options = {
            "collapse_false": collapse_false,
            "group_false": group_false,
            "dense": dense,
        }
        self.extpol = ExtensionPolicy(**policy_options)
        self.e_train = None

    # Disabled alternative for the prediction part of the extension:
    # def _get_pred_ext(self, pred_proba: np.ndarray):
    #     return np.argmax(pred_proba, axis=1, keepdims=True)

    def fit(self, train: LabelledCollection):
        """Fit the confidence metrics and the quantifier on ``train``.

        The extended training collection is kept in ``self.e_train``.

        Returns:
            ``self``.
        """
        posteriors = self.classifier.predict_proba(train.X)
        self._fit_confidence(train.X, train.y, posteriors)
        self.e_train = self.extend(train, pred_proba=posteriors)
        self.quantifier.fit(self.e_train)
        return self

    def estimate(
        self, instances: ExtendedData | np.ndarray | sp.csr_matrix
    ) -> ExtendedPrev:
        """Quantify ``instances`` and wrap the result in ``ExtMulPrev``.

        Inputs that are not already ``ExtendedData`` are extended first via
        ``_extend_instances``.
        """
        if isinstance(instances, ExtendedData):
            e_inst = instances
        else:
            e_inst = self._extend_instances(instances)

        return ExtMulPrev(
            self.quantifier.quantify(e_inst.X),
            e_inst.nbcl,
            q_classes=self.quantifier.classes_,
            extpol=self.extpol,
        )

    @property
    def collapse_false(self):
        """The ``collapse_false`` flag of the current extension policy."""
        return self.extpol.collapse_false

    @property
    def group_false(self):
        """The ``group_false`` flag of the current extension policy."""
        return self.extpol.group_false
2023-11-08 17:26:44 +01:00
2023-11-10 01:24:18 +01:00
class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
    """Estimator that fits an independent copy of the quantifier per split.

    The extended training collection is partitioned with ``split_by_pred``
    (presumably one subset per predicted class; confirm against
    ``ExtendedCollection``) and a deep copy of ``quantifier`` is fit on each
    subset.
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseQuantifier,
        confidence: str = None,
        group_false: bool = False,
        dense: bool = False,
    ):
        """Configure the estimator and its extension policy.

        Args:
            classifier: estimator exposing ``predict_proba``.
            quantifier: template quantifier; it is deep-copied for each split
                during ``fit``.
            confidence: confidence metric specification passed to the parent.
            group_false: flag forwarded to ``ExtensionPolicy``.
            dense: flag forwarded to ``ExtensionPolicy``.
        """
        super().__init__(
            classifier=classifier,
            quantifier=quantifier,
            confidence=confidence,
        )
        self.quantifiers = []
        # Replaces the policy installed by the base constructor.
        self.extpol = ExtensionPolicy(
            group_false=group_false,
            dense=dense,
        )

    def fit(self, train: LabelledCollection | ExtendedCollection):
        """Fit the confidence metrics and one quantifier copy per split.

        Sets ``self.e_train``, ``self.n_classes`` and ``self.quantifiers``.

        Returns:
            ``self``.
        """
        pred_proba = self.classifier.predict_proba(train.X)
        self._fit_confidence(train.X, train.y, pred_proba)
        self.e_train = self.extend(train, pred_proba=pred_proba)

        self.n_classes = self.e_train.n_classes
        e_trains = self.e_train.split_by_pred()

        self.quantifiers = []
        # A distinct loop name keeps the ``train`` argument intact.
        for e_split in e_trains:
            quant = deepcopy(self.quantifier)
            quant.fit(e_split)
            self.quantifiers.append(quant)

        return self

    def estimate(
        self, instances: ExtendedData | np.ndarray | sp.csr_matrix
    ) -> ExtBinPrev:
        """Quantify each split of ``instances`` and wrap the results.

        Inputs that are not already ``ExtendedData`` are extended first via
        ``_extend_instances``. Each split's estimate is scaled by the share of
        instances that fall into that split.

        Returns:
            An ``ExtBinPrev`` built from the per-split estimates.
        """
        e_inst = instances
        if not isinstance(e_inst, ExtendedData):
            e_inst = self._extend_instances(instances)

        s_inst = e_inst.split_by_pred()
        total = len(e_inst)
        # With no instances every split is empty; use a zero weight instead of
        # dividing by zero (the helper already returns zeros for empty splits).
        norms = [s_i.shape[0] / total if total else 0.0 for s_i in s_inst]

        estim_prevs = self._quantify_helper(s_inst, norms)
        return ExtBinPrev(
            estim_prevs,
            e_inst.nbcl,
            q_classes=[quant.classes_ for quant in self.quantifiers],
            extpol=self.extpol,
        )

    def _quantify_helper(
        self,
        s_inst: List[np.ndarray | sp.csr_matrix],
        norms: List[float],
    ) -> List[np.ndarray]:
        """Return one weighted estimate per fitted quantifier.

        Args:
            s_inst: per-split instances, aligned with ``self.quantifiers``.
            norms: weight applied to each split's estimate.

        Returns:
            A list with, for every split, ``quant.quantify(inst) * norm`` or a
            zero vector of length ``len(quant.classes_)`` when the split has
            no rows.
        """
        estim_prevs = []
        for quant, inst, norm in zip(self.quantifiers, s_inst, norms):
            if inst.shape[0] > 0:
                estim_prevs.append(quant.quantify(inst) * norm)
            else:
                estim_prevs.append(np.zeros((len(quant.classes_),)))

        return estim_prevs

    @property
    def group_false(self):
        """The ``group_false`` flag of the current extension policy."""
        return self.extpol.group_false
2023-11-08 17:26:44 +01:00
# Short module-level aliases for the estimator classes defined in this file.
BAE = BaseAccuracyEstimator
MCAE = MultiClassAccuracyEstimator
BQAE = BinaryQuantifierAccuracyEstimator