forked from moreo/QuaPy
fixing hyperparameters with prefixes, and replacing learner with classifier in aggregative quantifiers
This commit is contained in:
parent
adf799c8ec
commit
f9a199d859
6
TODO.txt
6
TODO.txt
|
@ -6,6 +6,12 @@ merge with master, because I had to fix some problems with QuaNet due to an issu
|
|||
added cross_val_predict in qp.model_selection (i.e., a cross_val_predict for quantification) --would be nice to have
|
||||
it parallelized
|
||||
|
||||
check the OneVsAll module(s)
|
||||
|
||||
check the set_params de neural.py, because the separation of estimator__<param> is not implemented; see also
|
||||
__check_params_colision
|
||||
|
||||
HDy can be customized so that the number of bins is specified, instead of explored within the fit method
|
||||
|
||||
Packaging:
|
||||
==========================================
|
||||
|
|
|
@ -17,7 +17,7 @@ training, val_generator, test_generator = fetch_lequa2022(task=task)
|
|||
|
||||
# define the quantifier
|
||||
learner = CalibratedClassifierCV(LogisticRegression())
|
||||
quantifier = EMQ(learner=learner)
|
||||
quantifier = EMQ(classifier=learner)
|
||||
|
||||
# model selection
|
||||
param_grid = {'C': np.logspace(-3, 3, 7), 'class_weight': ['balanced', None]}
|
||||
|
|
|
@ -4,7 +4,7 @@ from sklearn.calibration import CalibratedClassifierCV
|
|||
from sklearn.linear_model import LogisticRegression
|
||||
import quapy as qp
|
||||
import quapy.functional as F
|
||||
from classification.calibration import RecalibratedClassifierBase, NBVSCalibration, \
|
||||
from classification.calibration import RecalibratedProbabilisticClassifierBase, NBVSCalibration, \
|
||||
BCTSCalibration
|
||||
from data.datasets import LEQUA2022_SAMPLE_SIZE, fetch_lequa2022
|
||||
from evaluation import evaluation_report
|
||||
|
@ -13,7 +13,6 @@ from model_selection import GridSearchQ
|
|||
import pandas as pd
|
||||
|
||||
for task in ['T1A', 'T1B']:
|
||||
for calib in ['NoCal', 'TS', 'VS', 'NBVS', 'NBTS']:
|
||||
|
||||
# calibration = TempScaling(verbose=False, bias_positions='all')
|
||||
|
||||
|
@ -24,31 +23,36 @@ for task in ['T1A', 'T1B']:
|
|||
# learner = BCTSCalibration(LogisticRegression(), n_jobs=-1)
|
||||
# learner = CalibratedClassifierCV(LogisticRegression())
|
||||
learner = LogisticRegression()
|
||||
quantifier = EMQ(learner=learner, exact_train_prev=False, recalib=calib.lower() if calib != 'NoCal' else None)
|
||||
quantifier = EMQ(classifier=learner)
|
||||
|
||||
# model selection
|
||||
param_grid = {'C': np.logspace(-3, 3, 7), 'class_weight': ['balanced', None]}
|
||||
param_grid = {
|
||||
'classifier__C': np.logspace(-3, 3, 7),
|
||||
'classifier__class_weight': ['balanced', None],
|
||||
'recalib': ['platt', 'ts', 'vs', 'nbvs', 'bcts', None],
|
||||
'exact_train_prev': [False, True]
|
||||
}
|
||||
model_selection = GridSearchQ(quantifier, param_grid, protocol=val_generator, error='mrae', n_jobs=-1, refit=False, verbose=True)
|
||||
quantifier = model_selection.fit(training)
|
||||
|
||||
# evaluation
|
||||
report = evaluation_report(quantifier, protocol=test_generator, error_metrics=['mae', 'mrae', 'mkld'], verbose=True)
|
||||
|
||||
import os
|
||||
os.makedirs(f'./predictions/{task}', exist_ok=True)
|
||||
with open(f'./predictions/{task}/{calib}-EMQ.csv', 'wt') as foo:
|
||||
estim_prev = report['estim-prev'].values
|
||||
nclasses = len(estim_prev[0])
|
||||
foo.write(f'id,'+','.join([str(x) for x in range(nclasses)])+'\n')
|
||||
for id, prev in enumerate(estim_prev):
|
||||
foo.write(f'{id},'+','.join([f'{p:.5f}' for p in prev])+'\n')
|
||||
|
||||
os.makedirs(f'./errors/{task}', exist_ok=True)
|
||||
with open(f'./errors/{task}/{calib}-EMQ.csv', 'wt') as foo:
|
||||
maes, mraes = report['mae'].values, report['mrae'].values
|
||||
foo.write(f'id,AE,RAE\n')
|
||||
for id, (ae_i, rae_i) in enumerate(zip(maes, mraes)):
|
||||
foo.write(f'{id},{ae_i:.5f},{rae_i:.5f}\n')
|
||||
# import os
|
||||
# os.makedirs(f'./out', exist_ok=True)
|
||||
# with open(f'./out/EMQ_{calib}_{task}.txt', 'wt') as foo:
|
||||
# estim_prev = report['estim-prev'].values
|
||||
# nclasses = len(estim_prev[0])
|
||||
# foo.write(f'id,'+','.join([str(x) for x in range(nclasses)])+'\n')
|
||||
# for id, prev in enumerate(estim_prev):
|
||||
# foo.write(f'{id},'+','.join([f'{p:.5f}' for p in prev])+'\n')
|
||||
#
|
||||
# #os.makedirs(f'./errors/{task}', exist_ok=True)
|
||||
# with open(f'./out/EMQ_{calib}_{task}_errors.txt', 'wt') as foo:
|
||||
# maes, mraes = report['mae'].values, report['mrae'].values
|
||||
# foo.write(f'id,AE,RAE\n')
|
||||
# for id, (ae_i, rae_i) in enumerate(zip(maes, mraes)):
|
||||
# foo.write(f'{id},{ae_i:.5f},{rae_i:.5f}\n')
|
||||
|
||||
# printing results
|
||||
pd.set_option('display.expand_frame_repr', False)
|
||||
|
|
|
@ -37,6 +37,12 @@
|
|||
- new dependency "abstention" (to add to the project requirements and setup). Calibration methods from
|
||||
https://github.com/kundajelab/abstention added.
|
||||
|
||||
- the internal classifier of aggregative methods is now called "classifier" instead of "learner"
|
||||
|
||||
- when optimizing the hyperparameters of an aggregative quantifier, the classifier's specific hyperparameters
|
||||
should be marked with a "classifier__" prefix (just like in scikit-learn), while the quantifier's specific
|
||||
hyperparameters are named directly. For example, PCC(LogisticRegression()) quantifier has
|
||||
|
||||
Things to fix:
|
||||
- calibration with recalibration methods has to be fixed for exact_train_prev in EMQ (conflicts with clone, deepcopy, etc.)
|
||||
- clean functions like binary, aggregative, probabilistic, etc; those should be resolved via isinstance():
|
||||
|
|
|
@ -11,27 +11,27 @@ import numpy as np
|
|||
# see https://github.com/kundajelab/abstention
|
||||
|
||||
|
||||
class RecalibratedClassifier:
|
||||
class RecalibratedProbabilisticClassifier:
|
||||
pass
|
||||
|
||||
|
||||
class RecalibratedClassifierBase(BaseEstimator, RecalibratedClassifier):
|
||||
class RecalibratedProbabilisticClassifierBase(BaseEstimator, RecalibratedProbabilisticClassifier):
|
||||
"""
|
||||
Applies a (re)calibration method from abstention.calibration, as defined in
|
||||
`Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_:
|
||||
|
||||
:param estimator: a scikit-learn probabilistic classifier
|
||||
:param classifier: a scikit-learn probabilistic classifier
|
||||
:param calibrator: the calibration object (an instance of abstention.calibration.CalibratorFactory)
|
||||
:param val_split: indicate an integer k for performing kFCV to obtain the posterior prevalences, or a float p
|
||||
:param val_split: indicate an integer k for performing kFCV to obtain the posterior probabilities, or a float p
|
||||
in (0,1) to indicate that the posteriors are obtained in a stratified validation split containing p% of the
|
||||
training instances (the rest is used for training). In any case, the classifier is retrained in the whole
|
||||
training set afterwards.
|
||||
:param n_jobs: indicate the number of parallel workers (only when val_split is an integer)
|
||||
:param n_jobs: indicate the number of parallel workers (only when val_split is an integer); default=None
|
||||
:param verbose: whether or not to display information in the standard output
|
||||
"""
|
||||
|
||||
def __init__(self, estimator, calibrator, val_split=5, n_jobs=1, verbose=False):
|
||||
self.estimator = estimator
|
||||
def __init__(self, classifier, calibrator, val_split=5, n_jobs=None, verbose=False):
|
||||
self.classifier = classifier
|
||||
self.calibrator = calibrator
|
||||
self.val_split = val_split
|
||||
self.n_jobs = n_jobs
|
||||
|
@ -50,39 +50,39 @@ class RecalibratedClassifierBase(BaseEstimator, RecalibratedClassifier):
|
|||
|
||||
def fit_cv(self, X, y):
|
||||
posteriors = cross_val_predict(
|
||||
self.estimator, X, y, cv=self.val_split, n_jobs=self.n_jobs, verbose=self.verbose, method="predict_proba"
|
||||
self.classifier, X, y, cv=self.val_split, n_jobs=self.n_jobs, verbose=self.verbose, method='predict_proba'
|
||||
)
|
||||
self.estimator.fit(X, y)
|
||||
self.classifier.fit(X, y)
|
||||
nclasses = len(np.unique(y))
|
||||
self.calibration_function = self.calibrator(posteriors, np.eye(nclasses)[y], posterior_supplied=True)
|
||||
return self
|
||||
|
||||
def fit_tr_val(self, X, y):
|
||||
Xtr, Xva, ytr, yva = train_test_split(X, y, test_size=self.val_split, stratify=y)
|
||||
self.estimator.fit(Xtr, ytr)
|
||||
posteriors = self.estimator.predict_proba(Xva)
|
||||
self.classifier.fit(Xtr, ytr)
|
||||
posteriors = self.classifier.predict_proba(Xva)
|
||||
nclasses = len(np.unique(yva))
|
||||
self.calibrator = self.calibrator(posteriors, np.eye(nclasses)[yva], posterior_supplied=True)
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
return self.estimator.predict(X)
|
||||
return self.classifier.predict(X)
|
||||
|
||||
def predict_proba(self, X):
|
||||
posteriors = self.estimator.predict_proba(X)
|
||||
posteriors = self.classifier.predict_proba(X)
|
||||
return self.calibration_function(posteriors)
|
||||
|
||||
@property
|
||||
def classes_(self):
|
||||
return self.estimator.classes_
|
||||
return self.classifier.classes_
|
||||
|
||||
|
||||
class NBVSCalibration(RecalibratedClassifierBase):
|
||||
class NBVSCalibration(RecalibratedProbabilisticClassifierBase):
|
||||
"""
|
||||
Applies the No-Bias Vector Scaling (NBVS) calibration method from abstention.calibration, as defined in
|
||||
`Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_:
|
||||
|
||||
:param estimator: a scikit-learn probabilistic classifier
|
||||
:param classifier: a scikit-learn probabilistic classifier
|
||||
:param val_split: indicate an integer k for performing kFCV to obtain the posterior prevalences, or a float p
|
||||
in (0,1) to indicate that the posteriors are obtained in a stratified validation split containing p% of the
|
||||
training instances (the rest is used for training). In any case, the classifier is retrained in the whole
|
||||
|
@ -91,20 +91,20 @@ class NBVSCalibration(RecalibratedClassifierBase):
|
|||
:param verbose: whether or not to display information in the standard output
|
||||
"""
|
||||
|
||||
def __init__(self, estimator, val_split=5, n_jobs=1, verbose=False):
|
||||
self.estimator = estimator
|
||||
def __init__(self, classifier, val_split=5, n_jobs=1, verbose=False):
|
||||
self.classifier = classifier
|
||||
self.calibrator = NoBiasVectorScaling(verbose=verbose)
|
||||
self.val_split = val_split
|
||||
self.n_jobs = n_jobs
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
class BCTSCalibration(RecalibratedClassifierBase):
|
||||
class BCTSCalibration(RecalibratedProbabilisticClassifierBase):
|
||||
"""
|
||||
Applies the Bias-Corrected Temperature Scaling (BCTS) calibration method from abstention.calibration, as defined in
|
||||
`Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_:
|
||||
|
||||
:param estimator: a scikit-learn probabilistic classifier
|
||||
:param classifier: a scikit-learn probabilistic classifier
|
||||
:param val_split: indicate an integer k for performing kFCV to obtain the posterior prevalences, or a float p
|
||||
in (0,1) to indicate that the posteriors are obtained in a stratified validation split containing p% of the
|
||||
training instances (the rest is used for training). In any case, the classifier is retrained in the whole
|
||||
|
@ -113,20 +113,20 @@ class BCTSCalibration(RecalibratedClassifierBase):
|
|||
:param verbose: whether or not to display information in the standard output
|
||||
"""
|
||||
|
||||
def __init__(self, estimator, val_split=5, n_jobs=1, verbose=False):
|
||||
self.estimator = estimator
|
||||
def __init__(self, classifier, val_split=5, n_jobs=1, verbose=False):
|
||||
self.classifier = classifier
|
||||
self.calibrator = TempScaling(verbose=verbose, bias_positions='all')
|
||||
self.val_split = val_split
|
||||
self.n_jobs = n_jobs
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
class TSCalibration(RecalibratedClassifierBase):
|
||||
class TSCalibration(RecalibratedProbabilisticClassifierBase):
|
||||
"""
|
||||
Applies the Temperature Scaling (TS) calibration method from abstention.calibration, as defined in
|
||||
`Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_:
|
||||
|
||||
:param estimator: a scikit-learn probabilistic classifier
|
||||
:param classifier: a scikit-learn probabilistic classifier
|
||||
:param val_split: indicate an integer k for performing kFCV to obtain the posterior prevalences, or a float p
|
||||
in (0,1) to indicate that the posteriors are obtained in a stratified validation split containing p% of the
|
||||
training instances (the rest is used for training). In any case, the classifier is retrained in the whole
|
||||
|
@ -135,20 +135,20 @@ class TSCalibration(RecalibratedClassifierBase):
|
|||
:param verbose: whether or not to display information in the standard output
|
||||
"""
|
||||
|
||||
def __init__(self, estimator, val_split=5, n_jobs=1, verbose=False):
|
||||
self.estimator = estimator
|
||||
def __init__(self, classifier, val_split=5, n_jobs=1, verbose=False):
|
||||
self.classifier = classifier
|
||||
self.calibrator = TempScaling(verbose=verbose)
|
||||
self.val_split = val_split
|
||||
self.n_jobs = n_jobs
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
class VSCalibration(RecalibratedClassifierBase):
|
||||
class VSCalibration(RecalibratedProbabilisticClassifierBase):
|
||||
"""
|
||||
Applies the Vector Scaling (VS) calibration method from abstention.calibration, as defined in
|
||||
`Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_:
|
||||
|
||||
:param estimator: a scikit-learn probabilistic classifier
|
||||
:param classifier: a scikit-learn probabilistic classifier
|
||||
:param val_split: indicate an integer k for performing kFCV to obtain the posterior prevalences, or a float p
|
||||
in (0,1) to indicate that the posteriors are obtained in a stratified validation split containing p% of the
|
||||
training instances (the rest is used for training). In any case, the classifier is retrained in the whole
|
||||
|
@ -157,8 +157,8 @@ class VSCalibration(RecalibratedClassifierBase):
|
|||
:param verbose: whether or not to display information in the standard output
|
||||
"""
|
||||
|
||||
def __init__(self, estimator, val_split=5, n_jobs=1, verbose=False):
|
||||
self.estimator = estimator
|
||||
def __init__(self, classifier, val_split=5, n_jobs=1, verbose=False):
|
||||
self.classifier = classifier
|
||||
self.calibrator = VectorScaling(verbose=verbose)
|
||||
self.val_split = val_split
|
||||
self.n_jobs = n_jobs
|
||||
|
|
|
@ -10,7 +10,7 @@ from sklearn.model_selection import StratifiedKFold, cross_val_predict
|
|||
from tqdm import tqdm
|
||||
import quapy as qp
|
||||
import quapy.functional as F
|
||||
from classification.calibration import RecalibratedClassifier, NBVSCalibration, BCTSCalibration, TSCalibration, \
|
||||
from classification.calibration import RecalibratedProbabilisticClassifier, NBVSCalibration, BCTSCalibration, TSCalibration, \
|
||||
VSCalibration
|
||||
from quapy.classification.svmperf import SVMperf
|
||||
from quapy.data import LabelledCollection
|
||||
|
@ -23,41 +23,41 @@ from quapy.method.base import BaseQuantifier, BinaryQuantifier
|
|||
class AggregativeQuantifier(BaseQuantifier):
|
||||
"""
|
||||
Abstract class for quantification methods that base their estimations on the aggregation of classification
|
||||
results. Aggregative Quantifiers thus implement a :meth:`classify` method and maintain a :attr:`learner` attribute.
|
||||
Subclasses of this abstract class must implement the method :meth:`aggregate` which computes the aggregation
|
||||
of label predictions. The method :meth:`quantify` comes with a default implementation based on
|
||||
:meth:`classify` and :meth:`aggregate`.
|
||||
results. Aggregative Quantifiers thus implement a :meth:`classify` method and maintain a :attr:`classifier`
|
||||
attribute. Subclasses of this abstract class must implement the method :meth:`aggregate` which computes the
|
||||
aggregation of label predictions. The method :meth:`quantify` comes with a default implementation based on
|
||||
:meth:`classify` and :meth:`aggregate`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
"""
|
||||
Trains the aggregative quantifier
|
||||
|
||||
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||
:param fit_learner: whether or not to train the learner (default is True). Set to False if the
|
||||
:param fit_classifier: whether or not to train the learner (default is True). Set to False if the
|
||||
learner has been trained outside the quantifier.
|
||||
:return: self
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def learner(self):
|
||||
def classifier(self):
|
||||
"""
|
||||
Gives access to the classifier
|
||||
|
||||
:return: the classifier (typically an sklearn's Estimator)
|
||||
"""
|
||||
return self.learner_
|
||||
return self.classifier_
|
||||
|
||||
@learner.setter
|
||||
def learner(self, classifier):
|
||||
@classifier.setter
|
||||
def classifier(self, classifier):
|
||||
"""
|
||||
Setter for the classifier
|
||||
|
||||
:param classifier: the classifier
|
||||
"""
|
||||
self.learner_ = classifier
|
||||
self.classifier_ = classifier
|
||||
|
||||
def classify(self, instances):
|
||||
"""
|
||||
|
@ -68,7 +68,7 @@ class AggregativeQuantifier(BaseQuantifier):
|
|||
:param instances: array-like
|
||||
:return: np.ndarray of shape `(n_instances,)` with label predictions
|
||||
"""
|
||||
return self.learner.predict(instances)
|
||||
return self.classifier.predict(instances)
|
||||
|
||||
def quantify(self, instances):
|
||||
"""
|
||||
|
@ -91,24 +91,24 @@ class AggregativeQuantifier(BaseQuantifier):
|
|||
"""
|
||||
...
|
||||
|
||||
def get_params(self, deep=True):
|
||||
"""
|
||||
Return the current parameters of the quantifier.
|
||||
# def get_params(self, deep=True):
|
||||
# """
|
||||
# Return the current parameters of the quantifier.
|
||||
#
|
||||
# :param deep: for compatibility with sklearn
|
||||
# :return: a dictionary of param-value pairs
|
||||
# """
|
||||
#
|
||||
# return self.learner.get_params()
|
||||
|
||||
:param deep: for compatibility with sklearn
|
||||
:return: a dictionary of param-value pairs
|
||||
"""
|
||||
|
||||
return self.learner.get_params()
|
||||
|
||||
def set_params(self, **parameters):
|
||||
"""
|
||||
Set the parameters of the quantifier.
|
||||
|
||||
:param parameters: dictionary of param-value pairs
|
||||
"""
|
||||
|
||||
self.learner.set_params(**parameters)
|
||||
# def set_params(self, **parameters):
|
||||
# """
|
||||
# Set the parameters of the quantifier.
|
||||
#
|
||||
# :param parameters: dictionary of param-value pairs
|
||||
# """
|
||||
#
|
||||
# self.learner.set_params(**parameters)
|
||||
|
||||
@property
|
||||
def classes_(self):
|
||||
|
@ -118,7 +118,7 @@ class AggregativeQuantifier(BaseQuantifier):
|
|||
|
||||
:return: array-like
|
||||
"""
|
||||
return self.learner.classes_
|
||||
return self.classifier.classes_
|
||||
|
||||
|
||||
class AggregativeProbabilisticQuantifier(AggregativeQuantifier):
|
||||
|
@ -130,43 +130,43 @@ class AggregativeProbabilisticQuantifier(AggregativeQuantifier):
|
|||
"""
|
||||
|
||||
def classify(self, instances):
|
||||
return self.learner.predict_proba(instances)
|
||||
return self.classifier.predict_proba(instances)
|
||||
|
||||
def set_params(self, **parameters):
|
||||
if isinstance(self.learner, CalibratedClassifierCV):
|
||||
if self.learner.get_params().get('base_estimator') == 'deprecated':
|
||||
key_prefix = 'estimator__' # this has changed in the newer versions of sklearn
|
||||
else:
|
||||
key_prefix = 'base_estimator__'
|
||||
parameters = {key_prefix + k: v for k, v in parameters.items()}
|
||||
elif isinstance(self.learner, RecalibratedClassifier):
|
||||
parameters = {'estimator__' + k: v for k, v in parameters.items()}
|
||||
|
||||
self.learner.set_params(**parameters)
|
||||
return self
|
||||
# def set_params(self, **parameters):
|
||||
# if isinstance(self.classifier, CalibratedClassifierCV):
|
||||
# if self.classifier.get_params().get('base_estimator') == 'deprecated':
|
||||
# key_prefix = 'estimator__' # this has changed in the newer versions of sklearn
|
||||
# else:
|
||||
# key_prefix = 'base_estimator__'
|
||||
# parameters = {key_prefix + k: v for k, v in parameters.items()}
|
||||
# elif isinstance(self.classifier, RecalibratedClassifier):
|
||||
# parameters = {'estimator__' + k: v for k, v in parameters.items()}
|
||||
#
|
||||
# self.classifier.set_params(**parameters)
|
||||
# return self
|
||||
|
||||
|
||||
# Helper
|
||||
# ------------------------------------
|
||||
def _ensure_probabilistic(learner):
|
||||
if not hasattr(learner, 'predict_proba'):
|
||||
print(f'The learner {learner.__class__.__name__} does not seem to be probabilistic. '
|
||||
def _ensure_probabilistic(classifier):
|
||||
if not hasattr(classifier, 'predict_proba'):
|
||||
print(f'The learner {classifier.__class__.__name__} does not seem to be probabilistic. '
|
||||
f'The learner will be calibrated.')
|
||||
learner = CalibratedClassifierCV(learner, cv=5)
|
||||
return learner
|
||||
classifier = CalibratedClassifierCV(classifier, cv=5)
|
||||
return classifier
|
||||
|
||||
|
||||
def _training_helper(learner,
|
||||
def _training_helper(classifier,
|
||||
data: LabelledCollection,
|
||||
fit_learner: bool = True,
|
||||
fit_classifier: bool = True,
|
||||
ensure_probabilistic=False,
|
||||
val_split: Union[LabelledCollection, float] = None):
|
||||
"""
|
||||
Training procedure common to all Aggregative Quantifiers.
|
||||
|
||||
:param learner: the learner to be fit
|
||||
:param classifier: the learner to be fit
|
||||
:param data: the data on which to fit the learner. If requested, the data will be split before fitting the learner.
|
||||
:param fit_learner: whether or not to fit the learner (if False, then bypasses any action)
|
||||
:param fit_classifier: whether or not to fit the learner (if False, then bypasses any action)
|
||||
:param ensure_probabilistic: if True, guarantees that the resulting classifier implements predict_proba (if the
|
||||
learner is not probabilistic, then a CalibratedCV instance of it is trained)
|
||||
:param val_split: if specified as a float, indicates the proportion of training instances that will define the
|
||||
|
@ -175,9 +175,9 @@ def _training_helper(learner,
|
|||
:return: the learner trained on the training set, and the unused data (a _LabelledCollection_ if train_val_split>0
|
||||
or None otherwise) to be used as a validation set for any subsequent parameter fitting
|
||||
"""
|
||||
if fit_learner:
|
||||
if fit_classifier:
|
||||
if ensure_probabilistic:
|
||||
learner = _ensure_probabilistic(learner)
|
||||
classifier = _ensure_probabilistic(classifier)
|
||||
if val_split is not None:
|
||||
if isinstance(val_split, float):
|
||||
if not (0 < val_split < 1):
|
||||
|
@ -193,72 +193,72 @@ def _training_helper(learner,
|
|||
else:
|
||||
train, unused = data, None
|
||||
|
||||
if isinstance(learner, BaseQuantifier):
|
||||
learner.fit(train)
|
||||
if isinstance(classifier, BaseQuantifier):
|
||||
classifier.fit(train)
|
||||
else:
|
||||
learner.fit(*train.Xy)
|
||||
classifier.fit(*train.Xy)
|
||||
else:
|
||||
if ensure_probabilistic:
|
||||
if not hasattr(learner, 'predict_proba'):
|
||||
raise AssertionError('error: the learner cannot be calibrated since fit_learner is set to False')
|
||||
if not hasattr(classifier, 'predict_proba'):
|
||||
raise AssertionError('error: the learner cannot be calibrated since fit_classifier is set to False')
|
||||
unused = None
|
||||
if isinstance(val_split, LabelledCollection):
|
||||
unused = val_split
|
||||
|
||||
return learner, unused
|
||||
return classifier, unused
|
||||
|
||||
|
||||
def cross_generate_predictions(
|
||||
data,
|
||||
learner,
|
||||
classifier,
|
||||
val_split,
|
||||
probabilistic,
|
||||
fit_learner,
|
||||
fit_classifier,
|
||||
n_jobs
|
||||
):
|
||||
|
||||
n_jobs = qp.get_njobs(n_jobs)
|
||||
|
||||
if isinstance(val_split, int):
|
||||
assert fit_learner == True, \
|
||||
'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
|
||||
assert fit_classifier == True, \
|
||||
'the parameters for the adjustment cannot be estimated with kFCV with fit_classifier=False'
|
||||
|
||||
if probabilistic:
|
||||
learner = _ensure_probabilistic(learner)
|
||||
classifier = _ensure_probabilistic(classifier)
|
||||
predict = 'predict_proba'
|
||||
else:
|
||||
predict = 'predict'
|
||||
y_pred = cross_val_predict(learner, *data.Xy, cv=val_split, n_jobs=n_jobs, method=predict)
|
||||
y_pred = cross_val_predict(classifier, *data.Xy, cv=val_split, n_jobs=n_jobs, method=predict)
|
||||
class_count = data.counts()
|
||||
|
||||
# fit the learner on all data
|
||||
learner.fit(*data.Xy)
|
||||
classifier.fit(*data.Xy)
|
||||
y = data.y
|
||||
classes = data.classes_
|
||||
else:
|
||||
learner, val_data = _training_helper(
|
||||
learner, data, fit_learner, ensure_probabilistic=probabilistic, val_split=val_split
|
||||
classifier, val_data = _training_helper(
|
||||
classifier, data, fit_classifier, ensure_probabilistic=probabilistic, val_split=val_split
|
||||
)
|
||||
y_pred = learner.predict_proba(val_data.instances) if probabilistic else learner.predict(val_data.instances)
|
||||
y_pred = classifier.predict_proba(val_data.instances) if probabilistic else classifier.predict(val_data.instances)
|
||||
y = val_data.labels
|
||||
classes = val_data.classes_
|
||||
class_count = val_data.counts()
|
||||
|
||||
return learner, y, y_pred, classes, class_count
|
||||
return classifier, y, y_pred, classes, class_count
|
||||
|
||||
|
||||
def cross_generate_predictions_depr(
|
||||
data,
|
||||
learner,
|
||||
classifier,
|
||||
val_split,
|
||||
probabilistic,
|
||||
fit_learner,
|
||||
fit_classifier,
|
||||
method_name=''
|
||||
):
|
||||
predict = learner.predict_proba if probabilistic else learner.predict
|
||||
predict = classifier.predict_proba if probabilistic else classifier.predict
|
||||
if isinstance(val_split, int):
|
||||
assert fit_learner == True, \
|
||||
'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
|
||||
assert fit_classifier == True, \
|
||||
'the parameters for the adjustment cannot be estimated with kFCV with fit_classifier=False'
|
||||
# kFCV estimation of parameters
|
||||
y, y_ = [], []
|
||||
kfcv = StratifiedKFold(n_splits=val_split)
|
||||
|
@ -267,8 +267,8 @@ def cross_generate_predictions_depr(
|
|||
pbar.set_description(f'{method_name}\tfitting fold {k}')
|
||||
training = data.sampling_from_index(training_idx)
|
||||
validation = data.sampling_from_index(validation_idx)
|
||||
learner, val_data = _training_helper(
|
||||
learner, training, fit_learner, ensure_probabilistic=probabilistic, val_split=validation
|
||||
classifier, val_data = _training_helper(
|
||||
classifier, training, fit_classifier, ensure_probabilistic=probabilistic, val_split=validation
|
||||
)
|
||||
y_.append(predict(val_data.instances))
|
||||
y.append(val_data.labels)
|
||||
|
@ -278,21 +278,21 @@ def cross_generate_predictions_depr(
|
|||
class_count = data.counts()
|
||||
|
||||
# fit the learner on all data
|
||||
learner, _ = _training_helper(
|
||||
learner, data, fit_learner, ensure_probabilistic=probabilistic, val_split=None
|
||||
classifier, _ = _training_helper(
|
||||
classifier, data, fit_classifier, ensure_probabilistic=probabilistic, val_split=None
|
||||
)
|
||||
classes = data.classes_
|
||||
|
||||
else:
|
||||
learner, val_data = _training_helper(
|
||||
learner, data, fit_learner, ensure_probabilistic=probabilistic, val_split=val_split
|
||||
classifier, val_data = _training_helper(
|
||||
classifier, data, fit_classifier, ensure_probabilistic=probabilistic, val_split=val_split
|
||||
)
|
||||
y_ = predict(val_data.instances)
|
||||
y = val_data.labels
|
||||
classes = val_data.classes_
|
||||
class_count = val_data.counts()
|
||||
|
||||
return learner, y, y_, classes, class_count
|
||||
return classifier, y, y_, classes, class_count
|
||||
|
||||
# Methods
|
||||
# ------------------------------------
|
||||
|
@ -301,22 +301,22 @@ class CC(AggregativeQuantifier):
|
|||
The most basic Quantification method. One that simply classifies all instances and counts how many have been
|
||||
attributed to each of the classes in order to compute class prevalence estimates.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator):
|
||||
self.learner = learner
|
||||
def __init__(self, classifier: BaseEstimator):
|
||||
self.classifier = classifier
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
"""
|
||||
Trains the Classify & Count method unless `fit_learner` is False, in which case, the classifier is assumed to
|
||||
Trains the Classify & Count method unless `fit_classifier` is False, in which case, the classifier is assumed to
|
||||
be already fit and there is nothing else to do.
|
||||
|
||||
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||
:param fit_learner: if False, the classifier is assumed to be fit
|
||||
:param fit_classifier: if False, the classifier is assumed to be fit
|
||||
:return: self
|
||||
"""
|
||||
self.learner, _ = _training_helper(self.learner, data, fit_learner)
|
||||
self.classifier, _ = _training_helper(self.classifier, data, fit_classifier)
|
||||
return self
|
||||
|
||||
def aggregate(self, classif_predictions: np.ndarray):
|
||||
|
@ -335,7 +335,7 @@ class ACC(AggregativeQuantifier):
|
|||
the "adjusted" variant of :class:`CC`, that corrects the predictions of CC
|
||||
according to the `misclassification rates`.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
|
||||
misclassification rates are to be estimated.
|
||||
This parameter can be indicated as a real value (between 0 and 1, default 0.4), representing a proportion of
|
||||
|
@ -344,17 +344,17 @@ class ACC(AggregativeQuantifier):
|
|||
:class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4, n_jobs=None):
|
||||
self.learner = learner
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4, n_jobs=None):
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
self.n_jobs = qp.get_njobs(n_jobs)
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
"""
|
||||
Trains a ACC quantifier.
|
||||
|
||||
:param data: the training set
|
||||
:param fit_learner: set to False to bypass the training (the learner is assumed to be already fit)
|
||||
:param fit_classifier: set to False to bypass the training (the learner is assumed to be already fit)
|
||||
:param val_split: either a float in (0,1) indicating the proportion of training instances to use for
|
||||
validation (e.g., 0.3 for using 30% of the training set as validation data), or a LabelledCollection
|
||||
indicating the validation set itself, or an int indicating the number `k` of folds to be used in `k`-fold
|
||||
|
@ -365,11 +365,11 @@ class ACC(AggregativeQuantifier):
|
|||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
||||
self.learner, y, y_, classes, class_count = cross_generate_predictions(
|
||||
data, self.learner, val_split, probabilistic=False, fit_learner=fit_learner, n_jobs=self.n_jobs
|
||||
self.classifier, y, y_, classes, class_count = cross_generate_predictions(
|
||||
data, self.classifier, val_split, probabilistic=False, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
self.cc = CC(self.learner)
|
||||
self.cc = CC(self.classifier)
|
||||
self.Pte_cond_estim_ = self.getPteCondEstim(data.classes_, y, y_)
|
||||
|
||||
return self
|
||||
|
@ -422,14 +422,14 @@ class PCC(AggregativeProbabilisticQuantifier):
|
|||
`Probabilistic Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,
|
||||
the probabilistic variant of CC that relies on the posterior probabilities returned by a probabilistic classifier.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator):
|
||||
self.learner = learner
|
||||
def __init__(self, classifier: BaseEstimator):
|
||||
self.classifier = classifier
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||
self.learner, _ = _training_helper(self.learner, data, fit_learner, ensure_probabilistic=True)
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
self.classifier, _ = _training_helper(self.classifier, data, fit_classifier, ensure_probabilistic=True)
|
||||
return self
|
||||
|
||||
def aggregate(self, classif_posteriors):
|
||||
|
@ -441,7 +441,7 @@ class PACC(AggregativeProbabilisticQuantifier):
|
|||
`Probabilistic Adjusted Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,
|
||||
the probabilistic variant of ACC that relies on the posterior probabilities returned by a probabilistic classifier.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
|
||||
misclassification rates are to be estimated.
|
||||
This parameter can be indicated as a real value (between 0 and 1, default 0.4), representing a proportion of
|
||||
|
@ -451,17 +451,17 @@ class PACC(AggregativeProbabilisticQuantifier):
|
|||
:param n_jobs: number of parallel workers
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4, n_jobs=None):
|
||||
self.learner = learner
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4, n_jobs=None):
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
self.n_jobs = qp.get_njobs(n_jobs)
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
"""
|
||||
Trains a PACC quantifier.
|
||||
|
||||
:param data: the training set
|
||||
:param fit_learner: set to False to bypass the training (the learner is assumed to be already fit)
|
||||
:param fit_classifier: set to False to bypass the training (the learner is assumed to be already fit)
|
||||
:param val_split: either a float in (0,1) indicating the proportion of training instances to use for
|
||||
validation (e.g., 0.3 for using 30% of the training set as validation data), or a LabelledCollection
|
||||
indicating the validation set itself, or an int indicating the number k of folds to be used in kFCV
|
||||
|
@ -472,11 +472,11 @@ class PACC(AggregativeProbabilisticQuantifier):
|
|||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
||||
self.learner, y, y_, classes, class_count = cross_generate_predictions(
|
||||
data, self.learner, val_split, probabilistic=True, fit_learner=fit_learner, n_jobs=self.n_jobs
|
||||
self.classifier, y, y_, classes, class_count = cross_generate_predictions(
|
||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
self.pcc = PCC(self.learner)
|
||||
self.pcc = PCC(self.classifier)
|
||||
self.Pte_cond_estim_ = self.getPteCondEstim(classes, y, y_)
|
||||
|
||||
return self
|
||||
|
@ -510,7 +510,7 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
|||
probabilities generated by a probabilistic classifier and the class prevalence estimates obtained via
|
||||
maximum-likelihood estimation, in a mutually recursive way, until convergence.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
:param exact_train_prev: set to True (default) for using, as the initial observation, the true training prevalence;
|
||||
or set to False for computing the training prevalence as an estimate, akin to PCC, i.e., as the expected
|
||||
value of the posterior probabilities of the training instances as suggested in
|
||||
|
@ -523,30 +523,32 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
|||
MAX_ITER = 1000
|
||||
EPSILON = 1e-4
|
||||
|
||||
def __init__(self, learner: BaseEstimator, exact_train_prev=True, recalib=None):
|
||||
self.learner = learner
|
||||
def __init__(self, classifier: BaseEstimator, exact_train_prev=True, recalib=None):
|
||||
self.classifier = classifier
|
||||
self.exact_train_prev = exact_train_prev
|
||||
self.recalib = recalib
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
if self.recalib is not None:
|
||||
if self.recalib == 'nbvs':
|
||||
self.learner = NBVSCalibration(self.learner)
|
||||
self.classifier = NBVSCalibration(self.classifier)
|
||||
elif self.recalib == 'bcts':
|
||||
self.learner = BCTSCalibration(self.learner)
|
||||
self.classifier = BCTSCalibration(self.classifier)
|
||||
elif self.recalib == 'ts':
|
||||
self.learner = TSCalibration(self.learner)
|
||||
self.classifier = TSCalibration(self.classifier)
|
||||
elif self.recalib == 'vs':
|
||||
self.learner = VSCalibration(self.learner)
|
||||
self.classifier = VSCalibration(self.classifier)
|
||||
elif self.recalib == 'platt':
|
||||
self.classifier = CalibratedClassifierCV(self.classifier, ensemble=False)
|
||||
else:
|
||||
raise ValueError('invalid param argument for recalibration method; available ones are '
|
||||
'"nbvs", "bcts", "ts", and "vs".')
|
||||
self.learner, _ = _training_helper(self.learner, data, fit_learner, ensure_probabilistic=True)
|
||||
self.classifier, _ = _training_helper(self.classifier, data, fit_classifier, ensure_probabilistic=True)
|
||||
if self.exact_train_prev:
|
||||
self.train_prevalence = F.prevalence_from_labels(data.labels, self.classes_)
|
||||
else:
|
||||
self.train_prevalence = qp.model_selection.cross_val_predict(
|
||||
quantifier=PCC(deepcopy(self.learner)),
|
||||
quantifier=PCC(deepcopy(self.classifier)),
|
||||
data=data,
|
||||
nfolds=3,
|
||||
random_state=0
|
||||
|
@ -558,7 +560,7 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
|||
return priors
|
||||
|
||||
def predict_proba(self, instances, epsilon=EPSILON):
|
||||
classif_posteriors = self.learner.predict_proba(instances)
|
||||
classif_posteriors = self.classifier.predict_proba(instances)
|
||||
priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
|
||||
return posteriors
|
||||
|
||||
|
@ -611,21 +613,21 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
|||
class-conditional distributions of the posterior probabilities returned for the positive and negative validation
|
||||
examples, respectively. The parameters of the mixture thus represent the estimates of the class prevalence values.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a binary classifier
|
||||
:param classifier: a sklearn's Estimator that generates a binary classifier
|
||||
:param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
|
||||
validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4):
|
||||
self.learner = learner
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4):
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, LabelledCollection] = None):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
||||
"""
|
||||
Trains a HDy quantifier.
|
||||
|
||||
:param data: the training set
|
||||
:param fit_learner: set to False to bypass the training (the learner is assumed to be already fit)
|
||||
:param fit_classifier: set to False to bypass the training (the learner is assumed to be already fit)
|
||||
:param val_split: either a float in (0,1) indicating the proportion of training instances to use for
|
||||
validation (e.g., 0.3 for using 30% of the training set as validation data), or a
|
||||
:class:`quapy.data.base.LabelledCollection` indicating the validation set itself
|
||||
|
@ -635,11 +637,11 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
|||
val_split = self.val_split
|
||||
|
||||
self._check_binary(data, self.__class__.__name__)
|
||||
self.learner, validation = _training_helper(
|
||||
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
|
||||
self.classifier, validation = _training_helper(
|
||||
self.classifier, data, fit_classifier, ensure_probabilistic=True, val_split=val_split)
|
||||
Px = self.classify(validation.instances)[:, 1] # takes only the P(y=+1|x)
|
||||
self.Pxy1 = Px[validation.labels == self.learner.classes_[1]]
|
||||
self.Pxy0 = Px[validation.labels == self.learner.classes_[0]]
|
||||
self.Pxy1 = Px[validation.labels == self.classifier.classes_[1]]
|
||||
self.Pxy0 = Px[validation.labels == self.classifier.classes_[0]]
|
||||
# pre-compute the histogram for positive and negative examples
|
||||
self.bins = np.linspace(10, 110, 11, dtype=int) # [10, 20, 30, ..., 100, 110]
|
||||
self.Pxy1_density = {bins: np.histogram(self.Pxy1, bins=bins, range=(0, 1), density=True)[0] for bins in
|
||||
|
@ -684,7 +686,7 @@ class DyS(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
|||
minimizes the distance between distributions.
|
||||
Details for the ternary search have been got from <https://dl.acm.org/doi/pdf/10.1145/3219819.3220059>
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a binary classifier
|
||||
:param classifier: a sklearn's Estimator that generates a binary classifier
|
||||
:param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
|
||||
validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
:param n_bins: an int with the number of bins to use to compute the histograms.
|
||||
|
@ -693,8 +695,8 @@ class DyS(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
|||
:param tol: a float with the tolerance for the ternary search algorithm.
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4, n_bins=8, distance: Union[str, Callable]='HD', tol=1e-05):
|
||||
self.learner = learner
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4, n_bins=8, distance: Union[str, Callable]='HD', tol=1e-05):
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
self.tol = tol
|
||||
self.distance = distance
|
||||
|
@ -717,23 +719,23 @@ class DyS(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
|||
return (left + right) / 2
|
||||
|
||||
def _compute_distance(self, Px_train, Px_test, distance: Union[str, Callable]='HD'):
|
||||
if distance=='HD':
|
||||
if distance == 'HD':
|
||||
return F.HellingerDistance(Px_train, Px_test)
|
||||
elif distance=='topsoe':
|
||||
elif distance == 'topsoe':
|
||||
return F.TopsoeDistance(Px_train, Px_test)
|
||||
else:
|
||||
return distance(Px_train, Px_test)
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, LabelledCollection] = None):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
||||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
||||
self._check_binary(data, self.__class__.__name__)
|
||||
self.learner, validation = _training_helper(
|
||||
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
|
||||
self.classifier, validation = _training_helper(
|
||||
self.classifier, data, fit_classifier, ensure_probabilistic=True, val_split=val_split)
|
||||
Px = self.classify(validation.instances)[:, 1] # takes only the P(y=+1|x)
|
||||
self.Pxy1 = Px[validation.labels == self.learner.classes_[1]]
|
||||
self.Pxy0 = Px[validation.labels == self.learner.classes_[0]]
|
||||
self.Pxy1 = Px[validation.labels == self.classifier.classes_[1]]
|
||||
self.Pxy0 = Px[validation.labels == self.classifier.classes_[0]]
|
||||
self.Pxy1_density = np.histogram(self.Pxy1, bins=self.n_bins, range=(0, 1), density=True)[0]
|
||||
self.Pxy0_density = np.histogram(self.Pxy0, bins=self.n_bins, range=(0, 1), density=True)[0]
|
||||
return self
|
||||
|
@ -757,25 +759,25 @@ class SMM(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
|||
SMM is a simplification of matching distribution methods where the representation of the examples
|
||||
is created using the mean instead of a histogram.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a binary classifier.
|
||||
:param classifier: a sklearn's Estimator that generates a binary classifier.
|
||||
:param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
|
||||
validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4):
|
||||
self.learner = learner
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4):
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, LabelledCollection] = None):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
||||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
||||
self._check_binary(data, self.__class__.__name__)
|
||||
self.learner, validation = _training_helper(
|
||||
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
|
||||
self.classifier, validation = _training_helper(
|
||||
self.classifier, data, fit_classifier, ensure_probabilistic=True, val_split=val_split)
|
||||
Px = self.classify(validation.instances)[:, 1] # takes only the P(y=+1|x)
|
||||
self.Pxy1 = Px[validation.labels == self.learner.classes_[1]]
|
||||
self.Pxy0 = Px[validation.labels == self.learner.classes_[0]]
|
||||
self.Pxy1 = Px[validation.labels == self.classifier.classes_[1]]
|
||||
self.Pxy0 = Px[validation.labels == self.classifier.classes_[0]]
|
||||
self.Pxy1_mean = np.mean(self.Pxy1)
|
||||
self.Pxy0_mean = np.mean(self.Pxy0)
|
||||
return self
|
||||
|
@ -809,19 +811,19 @@ class ELM(AggregativeQuantifier, BinaryQuantifier):
|
|||
self.svmperf_base = svmperf_base if svmperf_base is not None else qp.environ['SVMPERF_HOME']
|
||||
self.loss = loss
|
||||
self.kwargs = kwargs
|
||||
self.learner = SVMperf(self.svmperf_base, loss=self.loss, **self.kwargs)
|
||||
self.classifier = SVMperf(self.svmperf_base, loss=self.loss, **self.kwargs)
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
self._check_binary(data, self.__class__.__name__)
|
||||
assert fit_learner, 'the method requires that fit_learner=True'
|
||||
self.learner.fit(data.instances, data.labels)
|
||||
assert fit_classifier, 'the method requires that fit_classifier=True'
|
||||
self.classifier.fit(data.instances, data.labels)
|
||||
return self
|
||||
|
||||
def aggregate(self, classif_predictions: np.ndarray):
|
||||
return F.prevalence_from_labels(classif_predictions, self.classes_)
|
||||
|
||||
def classify(self, X, y=None):
|
||||
return self.learner.predict(X)
|
||||
return self.classifier.predict(X)
|
||||
|
||||
|
||||
class SVMQ(ELM):
|
||||
|
@ -916,7 +918,7 @@ class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier):
|
|||
that would allow for more true positives and many more false positives, on the grounds this
|
||||
would deliver larger denominators.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
|
||||
misclassification rates are to be estimated.
|
||||
This parameter can be indicated as a real value (between 0 and 1, default 0.4), representing a proportion of
|
||||
|
@ -925,22 +927,22 @@ class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier):
|
|||
:class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4, n_jobs=None):
|
||||
self.learner = learner
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4, n_jobs=None):
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
self.n_jobs = qp.get_njobs(n_jobs)
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
self._check_binary(data, "Threshold Optimization")
|
||||
|
||||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
||||
self.learner, y, y_, classes, class_count = cross_generate_predictions(
|
||||
data, self.learner, val_split, probabilistic=True, fit_learner=fit_learner, n_jobs=self.n_jobs
|
||||
self.classifier, y, y_, classes, class_count = cross_generate_predictions(
|
||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
self.cc = CC(self.learner)
|
||||
self.cc = CC(self.classifier)
|
||||
|
||||
self.tpr, self.fpr = self._optimize_threshold(y, y_)
|
||||
|
||||
|
@ -1018,7 +1020,7 @@ class T50(ThresholdOptimization):
|
|||
for the threshold that makes `tpr` cosest to 0.5.
|
||||
The goal is to bring improved stability to the denominator of the adjustment.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
|
||||
misclassification rates are to be estimated.
|
||||
This parameter can be indicated as a real value (between 0 and 1, default 0.4), representing a proportion of
|
||||
|
@ -1027,8 +1029,8 @@ class T50(ThresholdOptimization):
|
|||
:class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4):
|
||||
super().__init__(learner, val_split)
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4):
|
||||
super().__init__(classifier, val_split)
|
||||
|
||||
def _condition(self, tpr, fpr) -> float:
|
||||
return abs(tpr - 0.5)
|
||||
|
@ -1042,7 +1044,7 @@ class MAX(ThresholdOptimization):
|
|||
for the threshold that maximizes `tpr-fpr`.
|
||||
The goal is to bring improved stability to the denominator of the adjustment.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
|
||||
misclassification rates are to be estimated.
|
||||
This parameter can be indicated as a real value (between 0 and 1, default 0.4), representing a proportion of
|
||||
|
@ -1051,8 +1053,8 @@ class MAX(ThresholdOptimization):
|
|||
:class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4):
|
||||
super().__init__(learner, val_split)
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4):
|
||||
super().__init__(classifier, val_split)
|
||||
|
||||
def _condition(self, tpr, fpr) -> float:
|
||||
# MAX strives to maximize (tpr - fpr), which is equivalent to minimize (fpr - tpr)
|
||||
|
@ -1067,7 +1069,7 @@ class X(ThresholdOptimization):
|
|||
for the threshold that yields `tpr=1-fpr`.
|
||||
The goal is to bring improved stability to the denominator of the adjustment.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
|
||||
misclassification rates are to be estimated.
|
||||
This parameter can be indicated as a real value (between 0 and 1, default 0.4), representing a proportion of
|
||||
|
@ -1076,8 +1078,8 @@ class X(ThresholdOptimization):
|
|||
:class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4):
|
||||
super().__init__(learner, val_split)
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4):
|
||||
super().__init__(classifier, val_split)
|
||||
|
||||
def _condition(self, tpr, fpr) -> float:
|
||||
return abs(1 - (tpr + fpr))
|
||||
|
@ -1091,7 +1093,7 @@ class MS(ThresholdOptimization):
|
|||
class prevalence estimates for all decision thresholds and returns the median of them all.
|
||||
The goal is to bring improved stability to the denominator of the adjustment.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
|
||||
misclassification rates are to be estimated.
|
||||
This parameter can be indicated as a real value (between 0 and 1, default 0.4), representing a proportion of
|
||||
|
@ -1099,8 +1101,8 @@ class MS(ThresholdOptimization):
|
|||
`k`-fold cross validation (this integer stands for the number of folds `k`), or as a
|
||||
:class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
"""
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4):
|
||||
super().__init__(learner, val_split)
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4):
|
||||
super().__init__(classifier, val_split)
|
||||
|
||||
def _condition(self, tpr, fpr) -> float:
|
||||
pass
|
||||
|
@ -1128,7 +1130,7 @@ class MS2(MS):
|
|||
which `tpr-fpr>0.25`
|
||||
The goal is to bring improved stability to the denominator of the adjustment.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a classifier
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
|
||||
misclassification rates are to be estimated.
|
||||
This parameter can be indicated as a real value (between 0 and 1, default 0.4), representing a proportion of
|
||||
|
@ -1136,8 +1138,8 @@ class MS2(MS):
|
|||
`k`-fold cross validation (this integer stands for the number of folds `k`), or as a
|
||||
:class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
"""
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4):
|
||||
super().__init__(learner, val_split)
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4):
|
||||
super().__init__(classifier, val_split)
|
||||
|
||||
def _optimize_threshold(self, y, probabilities):
|
||||
tprs = [0, 1]
|
||||
|
@ -1174,7 +1176,8 @@ class OneVsAll(AggregativeQuantifier):
|
|||
This variant was used, along with the :class:`EMQ` quantifier, in
|
||||
`Gao and Sebastiani, 2016 <https://link.springer.com/content/pdf/10.1007/s13278-016-0327-z.pdf>`_.
|
||||
|
||||
:param learner: a sklearn's Estimator that generates a binary classifier
|
||||
:param binary_quantifier: a quantifier (binary) that will be employed to work on multiclass model in a
|
||||
one-vs-all manner
|
||||
:param n_jobs: number of parallel workers
|
||||
"""
|
||||
|
||||
|
@ -1186,11 +1189,11 @@ class OneVsAll(AggregativeQuantifier):
|
|||
self.binary_quantifier = binary_quantifier
|
||||
self.n_jobs = qp.get_njobs(n_jobs)
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
assert not data.binary, \
|
||||
f'{self.__class__.__name__} expect non-binary data'
|
||||
assert fit_learner == True, \
|
||||
'fit_learner must be True'
|
||||
assert fit_classifier == True, \
|
||||
'fit_classifier must be True'
|
||||
|
||||
self.dict_binary_quantifiers = {c: deepcopy(self.binary_quantifier) for c in data.classes_}
|
||||
self.__parallel(self._delayed_binary_fit, data)
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from copy import deepcopy
|
||||
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
import quapy as qp
|
||||
from quapy.data import LabelledCollection
|
||||
|
||||
|
||||
# Base Quantifier abstract class
|
||||
# ------------------------------------
|
||||
class BaseQuantifier(metaclass=ABCMeta):
|
||||
class BaseQuantifier(BaseEstimator):
|
||||
"""
|
||||
Abstract Quantifier. A quantifier is defined as an object of a class that implements the method :meth:`fit` on
|
||||
:class:`quapy.data.base.LabelledCollection`, the method :meth:`quantify`, and the :meth:`set_params` and
|
||||
|
@ -33,24 +36,24 @@ class BaseQuantifier(metaclass=ABCMeta):
|
|||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def set_params(self, **parameters):
|
||||
"""
|
||||
Set the parameters of the quantifier.
|
||||
|
||||
:param parameters: dictionary of param-value pairs
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_params(self, deep=True):
|
||||
"""
|
||||
Return the current parameters of the quantifier.
|
||||
|
||||
:param deep: for compatibility with sklearn
|
||||
:return: a dictionary of param-value pairs
|
||||
"""
|
||||
...
|
||||
# @abstractmethod
|
||||
# def set_params(self, **parameters):
|
||||
# """
|
||||
# Set the parameters of the quantifier.
|
||||
#
|
||||
# :param parameters: dictionary of param-value pairs
|
||||
# """
|
||||
# ...
|
||||
#
|
||||
# @abstractmethod
|
||||
# def get_params(self, deep=True):
|
||||
# """
|
||||
# Return the current parameters of the quantifier.
|
||||
#
|
||||
# :param deep: for compatibility with sklearn
|
||||
# :return: a dictionary of param-value pairs
|
||||
# """
|
||||
# ...
|
||||
|
||||
|
||||
class BinaryQuantifier(BaseQuantifier):
|
||||
|
@ -67,7 +70,7 @@ class BinaryQuantifier(BaseQuantifier):
|
|||
class OneVsAllGeneric:
|
||||
"""
|
||||
Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
|
||||
quantifier for each class, and then l1-normalizes the outputs so that the class prevelences sum up to 1.
|
||||
quantifier for each class, and then l1-normalizes the outputs so that the class prevelence values sum up to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, binary_quantifier, n_jobs=None):
|
||||
|
@ -103,11 +106,11 @@ class OneVsAllGeneric:
|
|||
def get_params(self, deep=True):
|
||||
return self.binary_quantifier.get_params()
|
||||
|
||||
def _delayed_binary_predict(self, c, learners, X):
|
||||
return learners[c].quantify(X)[:,1] # the mean is the estimation for the positive class prevalence
|
||||
def _delayed_binary_predict(self, c, quantifiers, X):
|
||||
return quantifiers[c].quantify(X)[:, 1] # the mean is the estimation for the positive class prevalence
|
||||
|
||||
def _delayed_binary_fit(self, c, learners, data, **kwargs):
|
||||
def _delayed_binary_fit(self, c, quantifiers, data, **kwargs):
|
||||
bindata = LabelledCollection(data.instances, data.labels == c, n_classes=2)
|
||||
learners[c].fit(bindata, **kwargs)
|
||||
quantifiers[c].fit(bindata, **kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -146,7 +146,7 @@ class Ensemble(BaseQuantifier):
|
|||
This function should not be used within :class:`quapy.model_selection.GridSearchQ` (is here for compatibility
|
||||
with the abstract class).
|
||||
Instead, use `Ensemble(GridSearchQ(q),...)`, with `q` a Quantifier (recommended), or
|
||||
`Ensemble(Q(GridSearchCV(l)))` with `Q` a quantifier class that has a learner `l` optimized for
|
||||
`Ensemble(Q(GridSearchCV(l)))` with `Q` a quantifier class that has a classifier `l` optimized for
|
||||
classification (not recommended).
|
||||
|
||||
:param parameters: dictionary
|
||||
|
@ -154,7 +154,7 @@ class Ensemble(BaseQuantifier):
|
|||
"""
|
||||
raise NotImplementedError(f'{self.__class__.__name__} should not be used within GridSearchQ; '
|
||||
f'instead, use Ensemble(GridSearchQ(q),...), with q a Quantifier (recommended), '
|
||||
f'or Ensemble(Q(GridSearchCV(l))) with Q a quantifier class that has a learner '
|
||||
f'or Ensemble(Q(GridSearchCV(l))) with Q a quantifier class that has a classifier '
|
||||
f'l optimized for classification (not recommended).')
|
||||
|
||||
def get_params(self, deep=True):
|
||||
|
@ -162,7 +162,7 @@ class Ensemble(BaseQuantifier):
|
|||
This function should not be used within :class:`quapy.model_selection.GridSearchQ` (is here for compatibility
|
||||
with the abstract class).
|
||||
Instead, use `Ensemble(GridSearchQ(q),...)`, with `q` a Quantifier (recommended), or
|
||||
`Ensemble(Q(GridSearchCV(l)))` with `Q` a quantifier class that has a learner `l` optimized for
|
||||
`Ensemble(Q(GridSearchCV(l)))` with `Q` a quantifier class that has a classifier `l` optimized for
|
||||
classification (not recommended).
|
||||
|
||||
:return: raises an Exception
|
||||
|
@ -326,18 +326,18 @@ def _draw_simplex(ndim, min_val, max_trials=100):
|
|||
f'>= {min_val} is unlikely (it failed after {max_trials} trials)')
|
||||
|
||||
|
||||
def _instantiate_ensemble(learner, base_quantifier_class, param_grid, optim, param_model_sel, **kwargs):
|
||||
def _instantiate_ensemble(classifier, base_quantifier_class, param_grid, optim, param_model_sel, **kwargs):
|
||||
if optim is None:
|
||||
base_quantifier = base_quantifier_class(learner)
|
||||
base_quantifier = base_quantifier_class(classifier)
|
||||
elif optim in qp.error.CLASSIFICATION_ERROR:
|
||||
if optim == qp.error.f1e:
|
||||
scoring = make_scorer(f1_score)
|
||||
elif optim == qp.error.acce:
|
||||
scoring = make_scorer(accuracy_score)
|
||||
learner = GridSearchCV(learner, param_grid, scoring=scoring)
|
||||
base_quantifier = base_quantifier_class(learner)
|
||||
classifier = GridSearchCV(classifier, param_grid, scoring=scoring)
|
||||
base_quantifier = base_quantifier_class(classifier)
|
||||
else:
|
||||
base_quantifier = GridSearchQ(base_quantifier_class(learner),
|
||||
base_quantifier = GridSearchQ(base_quantifier_class(classifier),
|
||||
param_grid=param_grid,
|
||||
**param_model_sel,
|
||||
error=optim)
|
||||
|
@ -357,7 +357,7 @@ def _check_error(error):
|
|||
f'the name of an error function in {qp.error.ERROR_NAMES}')
|
||||
|
||||
|
||||
def ensembleFactory(learner, base_quantifier_class, param_grid=None, optim=None, param_model_sel: dict = None,
|
||||
def ensembleFactory(classifier, base_quantifier_class, param_grid=None, optim=None, param_model_sel: dict = None,
|
||||
**kwargs):
|
||||
"""
|
||||
Ensemble factory. Provides a unified interface for instantiating ensembles that can be optimized (via model
|
||||
|
@ -390,7 +390,7 @@ def ensembleFactory(learner, base_quantifier_class, param_grid=None, optim=None,
|
|||
>>>
|
||||
>>> ensembleFactory(LogisticRegression(), PACC, optim='mae', policy='mae', **common)
|
||||
|
||||
:param learner: sklearn's Estimator that generates a classifier
|
||||
:param classifier: sklearn's Estimator that generates a classifier
|
||||
:param base_quantifier_class: a class of quantifiers
|
||||
:param param_grid: a dictionary with the grid of parameters to optimize for
|
||||
:param optim: a valid quantification or classification error, or a string name of it
|
||||
|
@ -405,21 +405,21 @@ def ensembleFactory(learner, base_quantifier_class, param_grid=None, optim=None,
|
|||
if param_model_sel is None:
|
||||
raise ValueError(f'param_model_sel is None but optim was requested.')
|
||||
error = _check_error(optim)
|
||||
return _instantiate_ensemble(learner, base_quantifier_class, param_grid, error, param_model_sel, **kwargs)
|
||||
return _instantiate_ensemble(classifier, base_quantifier_class, param_grid, error, param_model_sel, **kwargs)
|
||||
|
||||
|
||||
def ECC(learner, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||
def ECC(classifier, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||
"""
|
||||
Implements an ensemble of :class:`quapy.method.aggregative.CC` quantifiers, as used by
|
||||
`Pérez-Gállego et al., 2019 <https://www.sciencedirect.com/science/article/pii/S1566253517303652>`_.
|
||||
|
||||
Equivalent to:
|
||||
|
||||
>>> ensembleFactory(learner, CC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
>>> ensembleFactory(classifier, CC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
||||
See :meth:`ensembleFactory` for further details.
|
||||
|
||||
:param learner: sklearn's Estimator that generates a classifier
|
||||
:param classifier: sklearn's Estimator that generates a classifier
|
||||
:param param_grid: a dictionary with the grid of parameters to optimize for
|
||||
:param optim: a valid quantification or classification error, or a string name of it
|
||||
:param param_model_sel: a dictionary containing any keyworded argument to pass to
|
||||
|
@ -428,21 +428,21 @@ def ECC(learner, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
|||
:return: an instance of :class:`Ensemble`
|
||||
"""
|
||||
|
||||
return ensembleFactory(learner, CC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
return ensembleFactory(classifier, CC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
||||
|
||||
def EACC(learner, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||
def EACC(classifier, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||
"""
|
||||
Implements an ensemble of :class:`quapy.method.aggregative.ACC` quantifiers, as used by
|
||||
`Pérez-Gállego et al., 2019 <https://www.sciencedirect.com/science/article/pii/S1566253517303652>`_.
|
||||
|
||||
Equivalent to:
|
||||
|
||||
>>> ensembleFactory(learner, ACC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
>>> ensembleFactory(classifier, ACC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
||||
See :meth:`ensembleFactory` for further details.
|
||||
|
||||
:param learner: sklearn's Estimator that generates a classifier
|
||||
:param classifier: sklearn's Estimator that generates a classifier
|
||||
:param param_grid: a dictionary with the grid of parameters to optimize for
|
||||
:param optim: a valid quantification or classification error, or a string name of it
|
||||
:param param_model_sel: a dictionary containing any keyworded argument to pass to
|
||||
|
@ -451,20 +451,20 @@ def EACC(learner, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
|||
:return: an instance of :class:`Ensemble`
|
||||
"""
|
||||
|
||||
return ensembleFactory(learner, ACC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
return ensembleFactory(classifier, ACC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
||||
|
||||
def EPACC(learner, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||
def EPACC(classifier, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||
"""
|
||||
Implements an ensemble of :class:`quapy.method.aggregative.PACC` quantifiers.
|
||||
|
||||
Equivalent to:
|
||||
|
||||
>>> ensembleFactory(learner, PACC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
>>> ensembleFactory(classifier, PACC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
||||
See :meth:`ensembleFactory` for further details.
|
||||
|
||||
:param learner: sklearn's Estimator that generates a classifier
|
||||
:param classifier: sklearn's Estimator that generates a classifier
|
||||
:param param_grid: a dictionary with the grid of parameters to optimize for
|
||||
:param optim: a valid quantification or classification error, or a string name of it
|
||||
:param param_model_sel: a dictionary containing any keyworded argument to pass to
|
||||
|
@ -473,21 +473,21 @@ def EPACC(learner, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
|||
:return: an instance of :class:`Ensemble`
|
||||
"""
|
||||
|
||||
return ensembleFactory(learner, PACC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
return ensembleFactory(classifier, PACC, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
||||
|
||||
def EHDy(learner, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||
def EHDy(classifier, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||
"""
|
||||
Implements an ensemble of :class:`quapy.method.aggregative.HDy` quantifiers, as used by
|
||||
`Pérez-Gállego et al., 2019 <https://www.sciencedirect.com/science/article/pii/S1566253517303652>`_.
|
||||
|
||||
Equivalent to:
|
||||
|
||||
>>> ensembleFactory(learner, HDy, param_grid, optim, param_mod_sel, **kwargs)
|
||||
>>> ensembleFactory(classifier, HDy, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
||||
See :meth:`ensembleFactory` for further details.
|
||||
|
||||
:param learner: sklearn's Estimator that generates a classifier
|
||||
:param classifier: sklearn's Estimator that generates a classifier
|
||||
:param param_grid: a dictionary with the grid of parameters to optimize for
|
||||
:param optim: a valid quantification or classification error, or a string name of it
|
||||
:param param_model_sel: a dictionary containing any keyworded argument to pass to
|
||||
|
@ -496,20 +496,20 @@ def EHDy(learner, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
|||
:return: an instance of :class:`Ensemble`
|
||||
"""
|
||||
|
||||
return ensembleFactory(learner, HDy, param_grid, optim, param_mod_sel, **kwargs)
|
||||
return ensembleFactory(classifier, HDy, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
||||
|
||||
def EEMQ(learner, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||
def EEMQ(classifier, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||
"""
|
||||
Implements an ensemble of :class:`quapy.method.aggregative.EMQ` quantifiers.
|
||||
|
||||
Equivalent to:
|
||||
|
||||
>>> ensembleFactory(learner, EMQ, param_grid, optim, param_mod_sel, **kwargs)
|
||||
>>> ensembleFactory(classifier, EMQ, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
||||
See :meth:`ensembleFactory` for further details.
|
||||
|
||||
:param learner: sklearn's Estimator that generates a classifier
|
||||
:param classifier: sklearn's Estimator that generates a classifier
|
||||
:param param_grid: a dictionary with the grid of parameters to optimize for
|
||||
:param optim: a valid quantification or classification error, or a string name of it
|
||||
:param param_model_sel: a dictionary containing any keyworded argument to pass to
|
||||
|
@ -518,4 +518,4 @@ def EEMQ(learner, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
|||
:return: an instance of :class:`Ensemble`
|
||||
"""
|
||||
|
||||
return ensembleFactory(learner, EMQ, param_grid, optim, param_mod_sel, **kwargs)
|
||||
return ensembleFactory(classifier, EMQ, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
|
|
@ -31,14 +31,14 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
>>>
|
||||
>>> # the text classifier is a CNN trained by NeuralClassifierTrainer
|
||||
>>> cnn = CNNnet(dataset.vocabulary_size, dataset.n_classes)
|
||||
>>> learner = NeuralClassifierTrainer(cnn, device='cuda')
|
||||
>>> classifier = NeuralClassifierTrainer(cnn, device='cuda')
|
||||
>>>
|
||||
>>> # train QuaNet (QuaNet is an alias to QuaNetTrainer)
|
||||
>>> model = QuaNet(learner, qp.environ['SAMPLE_SIZE'], device='cuda')
|
||||
>>> model = QuaNet(classifier, qp.environ['SAMPLE_SIZE'], device='cuda')
|
||||
>>> model.fit(dataset.training)
|
||||
>>> estim_prevalence = model.quantify(dataset.test.instances)
|
||||
|
||||
:param learner: an object implementing `fit` (i.e., that can be trained on labelled data),
|
||||
:param classifier: an object implementing `fit` (i.e., that can be trained on labelled data),
|
||||
`predict_proba` (i.e., that can generate posterior probabilities of unlabelled examples) and
|
||||
`transform` (i.e., that can generate embedded representations of the unlabelled instances).
|
||||
:param sample_size: integer, the sample size
|
||||
|
@ -60,7 +60,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
learner,
|
||||
classifier,
|
||||
sample_size,
|
||||
n_epochs=100,
|
||||
tr_iter_per_poch=500,
|
||||
|
@ -76,13 +76,13 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
checkpointname=None,
|
||||
device='cuda'):
|
||||
|
||||
assert hasattr(learner, 'transform'), \
|
||||
f'the learner {learner.__class__.__name__} does not seem to be able to produce document embeddings ' \
|
||||
assert hasattr(classifier, 'transform'), \
|
||||
f'the classifier {classifier.__class__.__name__} does not seem to be able to produce document embeddings ' \
|
||||
f'since it does not implement the method "transform"'
|
||||
assert hasattr(learner, 'predict_proba'), \
|
||||
f'the learner {learner.__class__.__name__} does not seem to be able to produce posterior probabilities ' \
|
||||
assert hasattr(classifier, 'predict_proba'), \
|
||||
f'the classifier {classifier.__class__.__name__} does not seem to be able to produce posterior probabilities ' \
|
||||
f'since it does not implement the method "predict_proba"'
|
||||
self.learner = learner
|
||||
self.classifier = classifier
|
||||
self.sample_size = sample_size
|
||||
self.n_epochs = n_epochs
|
||||
self.tr_iter = tr_iter_per_poch
|
||||
|
@ -105,26 +105,26 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
self.checkpoint = os.path.join(checkpointdir, checkpointname)
|
||||
self.device = torch.device(device)
|
||||
|
||||
self.__check_params_colision(self.quanet_params, self.learner.get_params())
|
||||
self.__check_params_colision(self.quanet_params, self.classifier.get_params())
|
||||
self._classes_ = None
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
"""
|
||||
Trains QuaNet.
|
||||
|
||||
:param data: the training data on which to train QuaNet. If `fit_learner=True`, the data will be split in
|
||||
:param data: the training data on which to train QuaNet. If `fit_classifier=True`, the data will be split in
|
||||
40/40/20 for training the classifier, training QuaNet, and validating QuaNet, respectively. If
|
||||
`fit_learner=False`, the data will be split in 66/34 for training QuaNet and validating it, respectively.
|
||||
:param fit_learner: if True, trains the classifier on a split containing 40% of the data
|
||||
`fit_classifier=False`, the data will be split in 66/34 for training QuaNet and validating it, respectively.
|
||||
:param fit_classifier: if True, trains the classifier on a split containing 40% of the data
|
||||
:return: self
|
||||
"""
|
||||
self._classes_ = data.classes_
|
||||
os.makedirs(self.checkpointdir, exist_ok=True)
|
||||
|
||||
if fit_learner:
|
||||
if fit_classifier:
|
||||
classifier_data, unused_data = data.split_stratified(0.4)
|
||||
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
|
||||
self.learner.fit(*classifier_data.Xy)
|
||||
self.classifier.fit(*classifier_data.Xy)
|
||||
else:
|
||||
classifier_data = None
|
||||
train_data, valid_data = data.split_stratified(0.66)
|
||||
|
@ -133,21 +133,21 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
self.tr_prev = data.prevalence()
|
||||
|
||||
# compute the posterior probabilities of the instances
|
||||
valid_posteriors = self.learner.predict_proba(valid_data.instances)
|
||||
train_posteriors = self.learner.predict_proba(train_data.instances)
|
||||
valid_posteriors = self.classifier.predict_proba(valid_data.instances)
|
||||
train_posteriors = self.classifier.predict_proba(train_data.instances)
|
||||
|
||||
# turn instances' original representations into embeddings
|
||||
valid_data_embed = LabelledCollection(self.learner.transform(valid_data.instances), valid_data.labels, self._classes_)
|
||||
train_data_embed = LabelledCollection(self.learner.transform(train_data.instances), train_data.labels, self._classes_)
|
||||
valid_data_embed = LabelledCollection(self.classifier.transform(valid_data.instances), valid_data.labels, self._classes_)
|
||||
train_data_embed = LabelledCollection(self.classifier.transform(train_data.instances), train_data.labels, self._classes_)
|
||||
|
||||
self.quantifiers = {
|
||||
'cc': CC(self.learner).fit(None, fit_learner=False),
|
||||
'acc': ACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
|
||||
'pcc': PCC(self.learner).fit(None, fit_learner=False),
|
||||
'pacc': PACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
|
||||
'cc': CC(self.classifier).fit(None, fit_classifier=False),
|
||||
'acc': ACC(self.classifier).fit(None, fit_classifier=False, val_split=valid_data),
|
||||
'pcc': PCC(self.classifier).fit(None, fit_classifier=False),
|
||||
'pacc': PACC(self.classifier).fit(None, fit_classifier=False, val_split=valid_data),
|
||||
}
|
||||
if classifier_data is not None:
|
||||
self.quantifiers['emq'] = EMQ(self.learner).fit(classifier_data, fit_learner=False)
|
||||
self.quantifiers['emq'] = EMQ(self.classifier).fit(classifier_data, fit_classifier=False)
|
||||
|
||||
self.status = {
|
||||
'tr-loss': -1,
|
||||
|
@ -199,8 +199,8 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
return prevs_estim
|
||||
|
||||
def quantify(self, instances):
|
||||
posteriors = self.learner.predict_proba(instances)
|
||||
embeddings = self.learner.transform(instances)
|
||||
posteriors = self.classifier.predict_proba(instances)
|
||||
embeddings = self.classifier.transform(instances)
|
||||
quant_estims = self._get_aggregative_estims(posteriors)
|
||||
self.quanet.eval()
|
||||
with torch.no_grad():
|
||||
|
@ -264,7 +264,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
f'patience={early_stop.patience}/{early_stop.PATIENCE_LIMIT}')
|
||||
|
||||
def get_params(self, deep=True):
|
||||
return {**self.learner.get_params(), **self.quanet_params}
|
||||
return {**self.classifier.get_params(), **self.quanet_params}
|
||||
|
||||
def set_params(self, **parameters):
|
||||
learner_params = {}
|
||||
|
@ -273,7 +273,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
self.quanet_params[key] = val
|
||||
else:
|
||||
learner_params[key] = val
|
||||
self.learner.set_params(**learner_params)
|
||||
self.classifier.set_params(**learner_params)
|
||||
|
||||
def __check_params_colision(self, quanet_params, learner_params):
|
||||
quanet_keys = set(quanet_params.keys())
|
||||
|
@ -281,7 +281,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
intersection = quanet_keys.intersection(learner_keys)
|
||||
if len(intersection) > 0:
|
||||
raise ValueError(f'the use of parameters {intersection} is ambiguous sine those can refer to '
|
||||
f'the parameters of QuaNet or the learner {self.learner.__class__.__name__}')
|
||||
f'the parameters of QuaNet or the learner {self.classifier.__class__.__name__}')
|
||||
|
||||
def clean_checkpoint(self):
|
||||
"""
|
||||
|
|
|
@ -88,7 +88,12 @@ class GridSearchQ(BaseQuantifier):
|
|||
|
||||
hyper = [dict({k: values[i] for i, k in enumerate(params_keys)}) for values in itertools.product(*params_values)]
|
||||
#pass a seed to parallel so it is set in clild processes
|
||||
scores = qp.util.parallel(self._delayed_eval, ((params, training) for params in hyper), seed=qp.environ.get('_R_SEED', None), n_jobs=self.n_jobs)
|
||||
scores = qp.util.parallel(
|
||||
self._delayed_eval,
|
||||
((params, training) for params in hyper),
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
for params, score, model in scores:
|
||||
if score is not None:
|
||||
|
@ -103,7 +108,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
tend = time()-tinit
|
||||
|
||||
if self.best_score_ is None:
|
||||
raise TimeoutError('all jobs took more than the timeout time to end')
|
||||
raise TimeoutError('no combination of hyperparameters seem to work')
|
||||
|
||||
self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
|
||||
f'[took {tend:.4f}s]')
|
||||
|
@ -150,6 +155,13 @@ class GridSearchQ(BaseQuantifier):
|
|||
except TimeoutError:
|
||||
self._sout(f'timeout ({self.timeout}s) reached for config {params}')
|
||||
score = None
|
||||
except ValueError as e:
|
||||
self._sout(f'the combination of hyperparameters {params} is invalid')
|
||||
raise e
|
||||
except Exception as e:
|
||||
self._sout(f'something went wrong for config {params}; skipping:')
|
||||
self._sout(f'\tException: {e}')
|
||||
score = None
|
||||
|
||||
return params, score, model
|
||||
|
||||
|
|
Loading…
Reference in New Issue