refactoring the aggregative quantifiers
This commit is contained in:
parent
25f1cc29a3
commit
0a6185d908
|
@ -44,12 +44,11 @@ class AggregativeQuantifier(ABC, BaseQuantifier):
|
|||
learner has been trained outside the quantifier.
|
||||
:return: self
|
||||
"""
|
||||
classif_predictions = self.classification_fit(data, fit_classifier)
|
||||
classif_predictions = self.classifier_fit_predict(data, fit_classifier)
|
||||
self.aggregation_fit(classif_predictions)
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
def classifier_fit_predict(self, data: LabelledCollection, fit_classifier=True, predict_on=None):
|
||||
"""
|
||||
Trains the classifier if requested (`fit_classifier=True`) and generate the necessary predictions to
|
||||
train the aggregation function.
|
||||
|
@ -57,11 +56,62 @@ class AggregativeQuantifier(ABC, BaseQuantifier):
|
|||
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||
:param fit_classifier: whether to train the learner (default is True). Set to False if the
|
||||
learner has been trained outside the quantifier.
|
||||
:param predict_on: specifies the set on which predictions need to be issued. This parameter can
|
||||
be specified as None (default) to indicate no prediction is needed; a float in (0, 1) to
|
||||
indicate the proportion of instances to be used for predictions (the remainder is used for
|
||||
training); an integer >1 to indicate that the predictions must be generated via k-fold
|
||||
cross-validation, using this integer as k; or the data sample itself on which to generate
|
||||
the predictions.
|
||||
"""
|
||||
...
|
||||
assert isinstance(fit_classifier, bool), 'unexpected type for "fit_classifier", must be boolean'
|
||||
|
||||
self.__check_classifier()
|
||||
|
||||
if predict_on is None:
|
||||
if fit_classifier:
|
||||
self.classifier.fit(*data.Xy)
|
||||
predictions = None
|
||||
|
||||
elif isinstance(predict_on, float):
|
||||
if fit_classifier:
|
||||
if not (0. < predict_on < 1.):
|
||||
raise ValueError(f'proportion {predict_on=} out of range, must be in (0,1)')
|
||||
train, val = data.split_stratified(train_prop=(1 - predict_on))
|
||||
self.classifier.fit(*train.Xy)
|
||||
predictions = (self.classify(val.X), val.y)
|
||||
else:
|
||||
raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
|
||||
f'the set on which predictions have to be issued must be '
|
||||
f'explicitly indicated')
|
||||
|
||||
elif isinstance(predict_on, LabelledCollection):
|
||||
if fit_classifier:
|
||||
self.classifier.fit(*data.Xy)
|
||||
predictions = (self.classify(predict_on.X), predict_on.y)
|
||||
|
||||
elif isinstance(predict_on, int):
|
||||
if fit_classifier:
|
||||
if not predict_on > 1:
|
||||
raise ValueError(f'invalid value {predict_on} in fit. '
|
||||
f'Specify a integer >1 for kFCV estimation.')
|
||||
predictions = cross_val_predict(
|
||||
classifier, *data.Xy, cv=predict_on, n_jobs=self.n_jobs, method=self.__classifier_method())
|
||||
self.classifier.fit(*data.Xy)
|
||||
else:
|
||||
raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
|
||||
f'the set on which predictions have to be issued must be '
|
||||
f'explicitly indicated')
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f'error: param "predict_on" ({type(predict_on)}) not understood; '
|
||||
f'use either a float indicating the split proportion, or a '
|
||||
f'tuple (X,y) indicating the validation partition')
|
||||
|
||||
return predictions
|
||||
|
||||
@abstractmethod
|
||||
def aggregation_fit(self, classif_predictions):
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection):
|
||||
"""
|
||||
Trains the aggregation function.
|
||||
|
||||
|
@ -99,6 +149,13 @@ class AggregativeQuantifier(ABC, BaseQuantifier):
|
|||
"""
|
||||
return self.classifier.predict(instances)
|
||||
|
||||
@property
|
||||
def __classifier_method(self):
|
||||
return 'predict'
|
||||
|
||||
def __check_classifier(self, adapt_if_necessary=False):
|
||||
assert hasattr(self.classifier, 'predict')
|
||||
|
||||
def quantify(self, instances):
|
||||
"""
|
||||
Generate class prevalence estimates for the sample's instances by aggregating the label predictions generated
|
||||
|
@ -142,106 +199,20 @@ class AggregativeProbabilisticQuantifier(AggregativeQuantifier, ABC):
|
|||
def classify(self, instances):
|
||||
return self.classifier.predict_proba(instances)
|
||||
|
||||
@property
|
||||
def __classifier_method(self):
|
||||
return 'predict_proba'
|
||||
|
||||
# Helper
|
||||
# ------------------------------------
|
||||
def _ensure_probabilistic(classifier):
|
||||
if not hasattr(classifier, 'predict_proba'):
|
||||
print(f'The learner {classifier.__class__.__name__} does not seem to be probabilistic. '
|
||||
f'The learner will be calibrated.')
|
||||
classifier = CalibratedClassifierCV(classifier, cv=5)
|
||||
return classifier
|
||||
|
||||
|
||||
def _training_helper(classifier,
|
||||
data: LabelledCollection,
|
||||
fit_classifier: bool = True,
|
||||
ensure_probabilistic=False,
|
||||
val_split: Union[LabelledCollection, float] = None):
|
||||
"""
|
||||
Training procedure common to all Aggregative Quantifiers.
|
||||
|
||||
:param classifier: the learner to be fit
|
||||
:param data: the data on which to fit the learner. If requested, the data will be split before fitting the learner.
|
||||
:param fit_classifier: whether or not to fit the learner (if False, then bypasses any action)
|
||||
:param ensure_probabilistic: if True, guarantees that the resulting classifier implements predict_proba (if the
|
||||
learner is not probabilistic, then a CalibratedCV instance of it is trained)
|
||||
:param val_split: if specified as a float, indicates the proportion of training instances that will define the
|
||||
validation split (e.g., 0.3 for using 30% of the training set as validation data); if specified as a
|
||||
LabelledCollection, represents the validation split itself
|
||||
:return: the learner trained on the training set, and the unused data (a _LabelledCollection_ if train_val_split>0
|
||||
or None otherwise) to be used as a validation set for any subsequent parameter fitting
|
||||
"""
|
||||
if fit_classifier:
|
||||
if ensure_probabilistic:
|
||||
classifier = _ensure_probabilistic(classifier)
|
||||
if val_split is not None:
|
||||
if isinstance(val_split, float):
|
||||
if not (0 < val_split < 1):
|
||||
raise ValueError(f'train/val split {val_split} out of range, must be in (0,1)')
|
||||
train, unused = data.split_stratified(train_prop=1 - val_split)
|
||||
elif isinstance(val_split, LabelledCollection):
|
||||
train = data
|
||||
unused = val_split
|
||||
def __check_classifier(self, adapt_if_necessary=False):
|
||||
if not hasattr(self.classifier, 'predict_proba'):
|
||||
if adapt_if_necessary:
|
||||
print(f'warning: The learner {self.classifier.__class__.__name__} does not seem to be '
|
||||
f'probabilistic. The learner will be calibrated (using CalibratedClassifierCV).')
|
||||
self.classifier = CalibratedClassifierCV(self.classifier, cv=5)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'param "val_split" ({type(val_split)}) not understood; use either a float indicating the split '
|
||||
'proportion, or a LabelledCollection indicating the validation split')
|
||||
else:
|
||||
train, unused = data, None
|
||||
|
||||
if isinstance(classifier, BaseQuantifier):
|
||||
classifier.fit(train)
|
||||
else:
|
||||
classifier.fit(*train.Xy)
|
||||
else:
|
||||
if ensure_probabilistic:
|
||||
if not hasattr(classifier, 'predict_proba'):
|
||||
raise AssertionError('error: the learner cannot be calibrated since fit_classifier is set to False')
|
||||
unused = None
|
||||
if isinstance(val_split, LabelledCollection):
|
||||
unused = val_split
|
||||
|
||||
return classifier, unused
|
||||
|
||||
|
||||
def cross_generate_predictions(
|
||||
data,
|
||||
classifier,
|
||||
val_split,
|
||||
probabilistic,
|
||||
fit_classifier,
|
||||
n_jobs
|
||||
):
|
||||
|
||||
n_jobs = qp._get_njobs(n_jobs)
|
||||
|
||||
if isinstance(val_split, int):
|
||||
assert fit_classifier == True, \
|
||||
'the parameters for the adjustment cannot be estimated with kFCV with fit_classifier=False'
|
||||
|
||||
if probabilistic:
|
||||
classifier = _ensure_probabilistic(classifier)
|
||||
predict = 'predict_proba'
|
||||
else:
|
||||
predict = 'predict'
|
||||
y_pred = cross_val_predict(classifier, *data.Xy, cv=val_split, n_jobs=n_jobs, method=predict)
|
||||
class_count = data.counts()
|
||||
|
||||
# fit the learner on all data
|
||||
classifier.fit(*data.Xy)
|
||||
y = data.y
|
||||
classes = data.classes_
|
||||
else:
|
||||
classifier, val_data = _training_helper(
|
||||
classifier, data, fit_classifier, ensure_probabilistic=probabilistic, val_split=val_split
|
||||
)
|
||||
y_pred = classifier.predict_proba(val_data.instances) if probabilistic else classifier.predict(val_data.instances)
|
||||
y = val_data.labels
|
||||
classes = val_data.classes_
|
||||
class_count = val_data.counts()
|
||||
|
||||
return classifier, y, y_pred, classes, class_count
|
||||
raise AssertionError(f'error: The learner {self.classifier.__class__.__name__} does not '
|
||||
f'seem to be probabilistic. The learner cannot be calibrated since '
|
||||
f'fit_classifier is set to False')
|
||||
|
||||
|
||||
# Methods
|
||||
|
@ -257,19 +228,7 @@ class CC(AggregativeQuantifier):
|
|||
def __init__(self, classifier: BaseEstimator):
|
||||
self.classifier = classifier
|
||||
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
"""
|
||||
Trains the classifier unless `fit_classifier` is False, in which case, the classifier is assumed to
|
||||
be already fit and there is nothing else to do.
|
||||
|
||||
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||
:param fit_classifier: if False, the classifier is assumed to be fit
|
||||
:return: self
|
||||
"""
|
||||
self.classifier, _ = _training_helper(self.classifier, data, fit_classifier)
|
||||
return None
|
||||
|
||||
def aggregation_fit(self, classif_predictions: np.ndarray):
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection):
|
||||
"""
|
||||
Nothing to do here!
|
||||
|
||||
|
@ -307,33 +266,11 @@ class ACC(AggregativeQuantifier):
|
|||
self.val_split = val_split
|
||||
self.n_jobs = qp._get_njobs(n_jobs)
|
||||
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection):
|
||||
"""
|
||||
Trains the classifier and generates, optionally through a cross-validation procedure, the predictions
|
||||
needed for estimating the misclassification rates matrix.
|
||||
Estimates the misclassification rates.
|
||||
|
||||
:param data: the training set
|
||||
:param fit_classifier: set to False to bypass the training (the learner is assumed to be already fit)
|
||||
:param val_split: either a float in (0,1) indicating the proportion of training instances to use for
|
||||
validation (e.g., 0.3 for using 30% of the training set as validation data), or a LabelledCollection
|
||||
indicating the validation set itself, or an int indicating the number `k` of folds to be used in `k`-fold
|
||||
cross validation to estimate the parameters
|
||||
:return: self
|
||||
"""
|
||||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
||||
self.classifier, true_labels, pred_labels, classes, class_count = cross_generate_predictions(
|
||||
data, self.classifier, val_split, probabilistic=False, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
return (true_labels, pred_labels)
|
||||
|
||||
def aggregation_fit(self, classif_predictions):
|
||||
"""
|
||||
Nothing to do here!
|
||||
|
||||
:param classif_predictions: this is actually None
|
||||
:param classif_predictions: classifier predictions with true labels
|
||||
"""
|
||||
true_labels, pred_labels = classif_predictions
|
||||
self.cc = CC(self.classifier)
|
||||
|
@ -393,11 +330,7 @@ class PCC(AggregativeProbabilisticQuantifier):
|
|||
def __init__(self, classifier: BaseEstimator):
|
||||
self.classifier = classifier
|
||||
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
self.classifier, _ = _training_helper(self.classifier, data, fit_classifier, ensure_probabilistic=True)
|
||||
return self
|
||||
|
||||
def aggregation_fit(self, classif_predictions: np.ndarray):
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection):
|
||||
"""
|
||||
Nothing to do here!
|
||||
|
||||
|
@ -429,33 +362,11 @@ class PACC(AggregativeProbabilisticQuantifier):
|
|||
self.val_split = val_split
|
||||
self.n_jobs = qp._get_njobs(n_jobs)
|
||||
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection):
|
||||
"""
|
||||
Trains the soft classifier and generates, optionally through a cross-validation procedure, the posterior
|
||||
probabilities needed for estimating the misclassification rates matrix.
|
||||
Estimates the misclassification rates
|
||||
|
||||
:param data: the training set
|
||||
:param fit_classifier: set to False to bypass the training (the learner is assumed to be already fit)
|
||||
:param val_split: either a float in (0,1) indicating the proportion of training instances to use for
|
||||
validation (e.g., 0.3 for using 30% of the training set as validation data), or a LabelledCollection
|
||||
indicating the validation set itself, or an int indicating the number `k` of folds to be used in `k`-fold
|
||||
cross validation to estimate the parameters
|
||||
:return: self
|
||||
"""
|
||||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
||||
self.classifier, true_labels, posteriors, classes, class_count = cross_generate_predictions(
|
||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
return (true_labels, posteriors)
|
||||
|
||||
def aggregation_fit(self, classif_predictions):
|
||||
"""
|
||||
Nothing to do here!
|
||||
|
||||
:param classif_predictions: this is actually None
|
||||
:param classif_predictions: classifier predictions with true labels
|
||||
"""
|
||||
true_labels, posteriors = classif_predictions
|
||||
self.pcc = PCC(self.classifier)
|
||||
|
@ -509,7 +420,7 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
|||
self.exact_train_prev = exact_train_prev
|
||||
self.recalib = recalib
|
||||
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
def classifier_fit_predict(self, data: LabelledCollection, fit_classifier=True):
|
||||
self.classifier, true_labels, posteriors, classes, class_count = cross_generate_predictions(
|
||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
@ -842,7 +753,7 @@ class DMy(AggregativeProbabilisticQuantifier):
|
|||
distributions = np.cumsum(distributions, axis=1)
|
||||
return distributions
|
||||
|
||||
def classification_fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
||||
def classifier_fit_predict(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
||||
"""
|
||||
Trains the classifier (if requested) and generates the validation distributions out of the training data.
|
||||
The validation distributions have shape `(n, ch, nbins)`, with `n` the number of classes, `ch` the number of
|
||||
|
|
|
@ -63,7 +63,7 @@ def newOneVsAll(binary_quantifier, n_jobs=None):
|
|||
return OneVsAllGeneric(binary_quantifier, n_jobs)
|
||||
|
||||
|
||||
class OneVsAllGeneric(OneVsAll,BaseQuantifier):
|
||||
class OneVsAllGeneric(OneVsAll, BaseQuantifier):
|
||||
"""
|
||||
Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
|
||||
quantifier for each class, and then l1-normalizes the outputs so that the class prevelence values sum up to 1.
|
||||
|
|
Loading…
Reference in New Issue