forked from moreo/QuaPy
refactoring aggregative quantifiers
This commit is contained in:
parent
29db15ae25
commit
25f1cc29a3
|
@ -59,7 +59,7 @@ class RecalibratedProbabilisticClassifierBase(BaseEstimator, RecalibratedProbabi
|
|||
elif isinstance(k, float):
|
||||
if not (0 < k < 1):
|
||||
raise ValueError('wrong value for val_split: the proportion of validation documents must be in (0,1)')
|
||||
return self.fit_cv(X, y)
|
||||
return self.fit_tr_val(X, y)
|
||||
|
||||
def fit_cv(self, X, y):
|
||||
"""
|
||||
|
@ -94,7 +94,7 @@ class RecalibratedProbabilisticClassifierBase(BaseEstimator, RecalibratedProbabi
|
|||
self.classifier.fit(Xtr, ytr)
|
||||
posteriors = self.classifier.predict_proba(Xva)
|
||||
nclasses = len(np.unique(yva))
|
||||
self.calibrator = self.calibrator(posteriors, np.eye(nclasses)[yva], posterior_supplied=True)
|
||||
self.calibration_function = self.calibrator(posteriors, np.eye(nclasses)[yva], posterior_supplied=True)
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Union
|
||||
import numpy as np
|
||||
|
@ -19,25 +19,55 @@ from quapy.method.base import BaseQuantifier, BinaryQuantifier, OneVsAllGeneric
|
|||
# Abstract classes
|
||||
# ------------------------------------
|
||||
|
||||
class AggregativeQuantifier(BaseQuantifier):
|
||||
class AggregativeQuantifier(ABC, BaseQuantifier):
|
||||
"""
|
||||
Abstract class for quantification methods that base their estimations on the aggregation of classification
|
||||
results. Aggregative Quantifiers thus implement a :meth:`classify` method and maintain a :attr:`classifier`
|
||||
attribute. Subclasses of this abstract class must implement the method :meth:`aggregate` which computes the
|
||||
aggregation of label predictions. The method :meth:`quantify` comes with a default implementation based on
|
||||
:meth:`classify` and :meth:`aggregate`.
|
||||
results. Aggregative quantifiers implement a pipeline that consists of generating classification predictions
|
||||
and aggregating them. For this reason, the training phase is implemented by :meth:`classification_fit` followed
|
||||
by :meth:`aggregation_fit`, while the testing phase is implemented by :meth:`classify` followed by
|
||||
:meth:`aggregate`. Subclasses of this abstract class must provide implementations for these methods.
|
||||
Aggregative quantifiers also maintain a :attr:`classifier` attribute.
|
||||
|
||||
The method :meth:`fit` comes with a default implementation based on :meth:`classification_fit`
|
||||
and :meth:`aggregation_fit`.
|
||||
|
||||
The method :meth:`quantify` comes with a default implementation based on :meth:`classify`
|
||||
and :meth:`aggregate`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
"""
|
||||
Trains the aggregative quantifier
|
||||
Trains the aggregative quantifier. This comes down to training a classifier and an aggregation function.
|
||||
|
||||
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||
:param fit_classifier: whether or not to train the learner (default is True). Set to False if the
|
||||
:param fit_classifier: whether to train the learner (default is True). Set to False if the
|
||||
learner has been trained outside the quantifier.
|
||||
:return: self
|
||||
"""
|
||||
classif_predictions = self.classification_fit(data, fit_classifier)
|
||||
self.aggregation_fit(classif_predictions)
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
"""
|
||||
Trains the classifier if requested (`fit_classifier=True`) and generate the necessary predictions to
|
||||
train the aggregation function.
|
||||
|
||||
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||
:param fit_classifier: whether to train the learner (default is True). Set to False if the
|
||||
learner has been trained outside the quantifier.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def aggregation_fit(self, classif_predictions):
|
||||
"""
|
||||
Trains the aggregation function.
|
||||
|
||||
:param classif_predictions: typically an `ndarray` containing the label predictions, but could be a
|
||||
tuple containing any information needed for fitting the aggregation function
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
|
@ -101,7 +131,7 @@ class AggregativeQuantifier(BaseQuantifier):
|
|||
return self.classifier.classes_
|
||||
|
||||
|
||||
class AggregativeProbabilisticQuantifier(AggregativeQuantifier):
|
||||
class AggregativeProbabilisticQuantifier(AggregativeQuantifier, ABC):
|
||||
"""
|
||||
Abstract class for quantification methods that base their estimations on the aggregation of posterior probabilities
|
||||
as returned by a probabilistic classifier. Aggregative Probabilistic Quantifiers thus extend Aggregative
|
||||
|
@ -227,9 +257,9 @@ class CC(AggregativeQuantifier):
|
|||
def __init__(self, classifier: BaseEstimator):
|
||||
self.classifier = classifier
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
"""
|
||||
Trains the Classify & Count method unless `fit_classifier` is False, in which case, the classifier is assumed to
|
||||
Trains the classifier unless `fit_classifier` is False, in which case, the classifier is assumed to
|
||||
be already fit and there is nothing else to do.
|
||||
|
||||
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||
|
@ -237,7 +267,15 @@ class CC(AggregativeQuantifier):
|
|||
:return: self
|
||||
"""
|
||||
self.classifier, _ = _training_helper(self.classifier, data, fit_classifier)
|
||||
return self
|
||||
return None
|
||||
|
||||
def aggregation_fit(self, classif_predictions: np.ndarray):
|
||||
"""
|
||||
Nothing to do here!
|
||||
|
||||
:param classif_predictions: this is actually None
|
||||
"""
|
||||
pass
|
||||
|
||||
def aggregate(self, classif_predictions: np.ndarray):
|
||||
"""
|
||||
|
@ -269,9 +307,10 @@ class ACC(AggregativeQuantifier):
|
|||
self.val_split = val_split
|
||||
self.n_jobs = qp._get_njobs(n_jobs)
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
"""
|
||||
Trains a ACC quantifier.
|
||||
Trains the classifier and generates, optionally through a cross-validation procedure, the predictions
|
||||
needed for estimating the misclassification rates matrix.
|
||||
|
||||
:param data: the training set
|
||||
:param fit_classifier: set to False to bypass the training (the learner is assumed to be already fit)
|
||||
|
@ -281,18 +320,24 @@ class ACC(AggregativeQuantifier):
|
|||
cross validation to estimate the parameters
|
||||
:return: self
|
||||
"""
|
||||
|
||||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
||||
self.classifier, y, y_, classes, class_count = cross_generate_predictions(
|
||||
self.classifier, true_labels, pred_labels, classes, class_count = cross_generate_predictions(
|
||||
data, self.classifier, val_split, probabilistic=False, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
self.cc = CC(self.classifier)
|
||||
self.Pte_cond_estim_ = self.getPteCondEstim(self.classifier.classes_, y, y_)
|
||||
return (true_labels, pred_labels)
|
||||
|
||||
return self
|
||||
def aggregation_fit(self, classif_predictions):
|
||||
"""
|
||||
Nothing to do here!
|
||||
|
||||
:param classif_predictions: this is actually None
|
||||
"""
|
||||
true_labels, pred_labels = classif_predictions
|
||||
self.cc = CC(self.classifier)
|
||||
self.Pte_cond_estim_ = self.getPteCondEstim(self.classifier.classes_, true_labels, pred_labels)
|
||||
|
||||
@classmethod
|
||||
def getPteCondEstim(cls, classes, y, y_):
|
||||
|
@ -348,10 +393,18 @@ class PCC(AggregativeProbabilisticQuantifier):
|
|||
def __init__(self, classifier: BaseEstimator):
|
||||
self.classifier = classifier
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
self.classifier, _ = _training_helper(self.classifier, data, fit_classifier, ensure_probabilistic=True)
|
||||
return self
|
||||
|
||||
def aggregation_fit(self, classif_predictions: np.ndarray):
|
||||
"""
|
||||
Nothing to do here!
|
||||
|
||||
:param classif_predictions: this is actually None
|
||||
"""
|
||||
pass
|
||||
|
||||
def aggregate(self, classif_posteriors):
|
||||
return F.prevalence_from_probabilities(classif_posteriors, binarize=False)
|
||||
|
||||
|
@ -376,30 +429,37 @@ class PACC(AggregativeProbabilisticQuantifier):
|
|||
self.val_split = val_split
|
||||
self.n_jobs = qp._get_njobs(n_jobs)
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
"""
|
||||
Trains a PACC quantifier.
|
||||
Trains the soft classifier and generates, optionally through a cross-validation procedure, the posterior
|
||||
probabilities needed for estimating the misclassification rates matrix.
|
||||
|
||||
:param data: the training set
|
||||
:param fit_classifier: set to False to bypass the training (the learner is assumed to be already fit)
|
||||
:param val_split: either a float in (0,1) indicating the proportion of training instances to use for
|
||||
validation (e.g., 0.3 for using 30% of the training set as validation data), or a LabelledCollection
|
||||
indicating the validation set itself, or an int indicating the number k of folds to be used in kFCV
|
||||
to estimate the parameters
|
||||
validation (e.g., 0.3 for using 30% of the training set as validation data), or a LabelledCollection
|
||||
indicating the validation set itself, or an int indicating the number `k` of folds to be used in `k`-fold
|
||||
cross validation to estimate the parameters
|
||||
:return: self
|
||||
"""
|
||||
|
||||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
||||
self.classifier, y, y_, classes, class_count = cross_generate_predictions(
|
||||
self.classifier, true_labels, posteriors, classes, class_count = cross_generate_predictions(
|
||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
self.pcc = PCC(self.classifier)
|
||||
self.Pte_cond_estim_ = self.getPteCondEstim(classes, y, y_)
|
||||
return (true_labels, posteriors)
|
||||
|
||||
return self
|
||||
def aggregation_fit(self, classif_predictions):
|
||||
"""
|
||||
Nothing to do here!
|
||||
|
||||
:param classif_predictions: this is actually None
|
||||
"""
|
||||
true_labels, posteriors = classif_predictions
|
||||
self.pcc = PCC(self.classifier)
|
||||
self.Pte_cond_estim_ = self.getPteCondEstim(self.classifier.classes_, true_labels, posteriors)
|
||||
|
||||
@classmethod
|
||||
def getPteCondEstim(cls, classes, y, y_):
|
||||
|
@ -449,7 +509,13 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
|||
self.exact_train_prev = exact_train_prev
|
||||
self.recalib = recalib
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
self.classifier, true_labels, posteriors, classes, class_count = cross_generate_predictions(
|
||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
return (true_labels, posteriors)
|
||||
|
||||
if self.recalib is not None:
|
||||
if self.recalib == 'nbvs':
|
||||
self.classifier = NBVSCalibration(self.non_calibrated)
|
||||
|
@ -477,7 +543,15 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
|||
nfolds=3,
|
||||
random_state=0
|
||||
)
|
||||
return self
|
||||
return None
|
||||
|
||||
def aggregation_fit(self, classif_predictions: np.ndarray):
|
||||
"""
|
||||
Nothing to do here!
|
||||
|
||||
:param classif_predictions: this is actually None
|
||||
"""
|
||||
pass
|
||||
|
||||
def aggregate(self, classif_posteriors, epsilon=EPSILON):
|
||||
priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
|
||||
|
@ -768,7 +842,7 @@ class DMy(AggregativeProbabilisticQuantifier):
|
|||
distributions = np.cumsum(distributions, axis=1)
|
||||
return distributions
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
||||
"""
|
||||
Trains the classifier (if requested) and generates the validation distributions out of the training data.
|
||||
The validation distributions have shape `(n, ch, nbins)`, with `n` the number of classes, `ch` the number of
|
||||
|
@ -787,15 +861,19 @@ class DMy(AggregativeProbabilisticQuantifier):
|
|||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
||||
self.classifier, y, posteriors, classes, class_count = cross_generate_predictions(
|
||||
self.classifier, true_labels, posteriors, classes, class_count = cross_generate_predictions(
|
||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
self.validation_distribution = np.asarray(
|
||||
[self.__get_distributions(posteriors[y==cat]) for cat in range(data.n_classes)]
|
||||
)
|
||||
return (true_labels, posteriors)
|
||||
|
||||
return self
|
||||
def aggregation_fit(self, classif_predictions):
|
||||
true_labels, posteriors = classif_predictions
|
||||
n_classes = len(self.classifier.classes_)
|
||||
|
||||
self.validation_distribution = np.asarray(
|
||||
[self.__get_distributions(posteriors[true_labels == cat]) for cat in range(n_classes)]
|
||||
)
|
||||
|
||||
def aggregate(self, posteriors: np.ndarray):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue