refactoring: no LabelledCollection in fit/predict, and other improvements in EMQ
parent 5738821d10
commit 960ca5076e

@@ -1,10 +1,32 @@
Change Log 0.1.10
-----------------

- Base code refactor:
  - Removing the coupling between LabelledCollection and the quantification methods. E.g.:
        def fit(data:LabelledCollection): -> def fit(X, y):
  - Adding function "predict" (function "quantify" is still present as an alias).
  - The behavior of aggregative methods in terms of fit_classifier and how to treat the val_split is now
    indicated exclusively at construction time; it is no longer possible to indicate it at fit time.
    This is because, in v<=0.1.9, one could create a method (e.g., ACC) and then indicate:
        my_acc.fit(tr_data, fit_classifier=False, val_split=val_data)
    in which case the first argument is unused, which was ambiguous with
        my_acc.fit(the_data, fit_classifier=False)
    in which case the_data is to be used for validation purposes. Moreover, val_split could be set as a
    fraction, indicating that only part of the_data must be used for validation, with the rest wasted;
    this was confusing.
- EMQ has been modified, so that the representation function "classify" now only provides posterior
  probabilities and, if required, these are recalibrated (e.g., by "bcts") during the aggregation function.
  - A new parameter "on_calib_error" is passed to the constructor, which indicates the policy to follow
    in case the calibration function fails. Options include:
    - 'raise': raises a RuntimeError (default)
    - 'backup': silently skips calibration
  - Parameter "recalib" has been renamed to "calib".
- Added aggregative bootstrap for deriving confidence regions (confidence intervals, ellipses in the simplex, or
  ellipses in the CLR space). This method is efficient because it leverages the two phases of aggregative
  quantifiers: it applies resampling only to the aggregation phase, thus avoiding training many quantifiers or
  classifying the instances of a sample multiple times. See:
  - quapy/method/confidence.py (new)
  - the new example no. 15.
- BayesianCC moved to confidence.py, where the methods related to confidence intervals live.
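
For illustration, a minimal sketch of the new API (variable names X_train, y_train, X_test are placeholders; any scikit-learn classifier can be passed, or None for qp.environ['DEFAULT_CLS']):

    import quapy as qp
    from quapy.method.aggregative import EMQ

    # v0.1.10 style: fit_classifier, val_split, and the calibration policy are fixed at construction time
    quantifier = EMQ(classifier=None, fit_classifier=True, val_split=5,
                     calib='bcts', on_calib_error='backup')
    quantifier.fit(X_train, y_train)         # formerly fit(data: LabelledCollection)
    prev_estim = quantifier.predict(X_test)  # "quantify" remains available as an alias
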
Change Log 0.1.9
@@ -50,7 +50,7 @@ def quantification_models():
    yield 'MAX', MAX(newLR()), lr_params
    yield 'MS', MS(newLR()), lr_params
    yield 'MS2', MS2(newLR()), lr_params
    yield 'sldc', EMQ(newLR(), recalib='platt'), lr_params
    yield 'sldc', EMQ(newLR(), calib='platt'), lr_params
    yield 'svmmae', newSVMAE(), svmperf_params
    yield 'hdy', HDy(newLR()), lr_params

@@ -99,26 +99,27 @@ class KDEyML(AggregativeSoftQuantifier, KDEBase):

    which corresponds to the maximum likelihood estimate.

    :param classifier: a sklearn's Estimator that generates a binary classifier.
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a collection defining the specific set of data to use for validation.
        Alternatively, this set can be specified at fit time by indicating the exact set of data
        on which the predictions are to be generated.
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
    :param bandwidth: float, the bandwidth of the Kernel
    :param random_state: a seed to be set before fitting any base quantifier (default None)
    """

    def __init__(self, classifier: BaseEstimator=None, val_split=5, bandwidth=0.1, random_state=None):
        self.classifier = qp._get_classifier(classifier)
        self.val_split = val_split
    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, bandwidth=0.1,
                 random_state=None):
        super().__init__(classifier, fit_classifier, val_split)
        self.bandwidth = KDEBase._check_bandwidth(bandwidth)
        self.random_state = random_state

    def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
        self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.classes_, self.bandwidth)
    def aggregation_fit(self, classif_predictions, labels):
        self.mix_densities = self.get_mixture_components(classif_predictions, labels, self.classes_, self.bandwidth)
        return self

    def aggregate(self, posteriors: np.ndarray):
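
As a usage sketch, the three ways of specifying val_split under the new interface (X, y, X_val, y_val are placeholder arrays; the import path is assumed):

    from quapy.method.aggregative import KDEyML

    KDEyML(val_split=5).fit(X, y)                # predictions via 5-fold cross-validation
    KDEyML(val_split=0.3).fit(X, y)              # 30% stratified held-out validation split
    KDEyML(val_split=(X_val, y_val)).fit(X, y)   # an explicit validation set, as a tuple
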
@@ -173,35 +174,35 @@ class KDEyHD(AggregativeSoftQuantifier, KDEBase):
    where the datapoints (trials) :math:`x_1,\\ldots,x_t\\sim_{\\mathrm{iid}} r` with :math:`r` the
    uniform distribution.

    :param classifier: a sklearn's Estimator that generates a binary classifier.
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a collection defining the specific set of data to use for validation.
        Alternatively, this set can be specified at fit time by indicating the exact set of data
        on which the predictions are to be generated.
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
    :param bandwidth: float, the bandwidth of the Kernel
    :param random_state: a seed to be set before fitting any base quantifier (default None)
    :param montecarlo_trials: number of Monte Carlo trials (default 10000)
    """

    def __init__(self, classifier: BaseEstimator=None, val_split=5, divergence: str='HD',
    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, divergence: str='HD',
                 bandwidth=0.1, random_state=None, montecarlo_trials=10000):

        self.classifier = qp._get_classifier(classifier)
        self.val_split = val_split
        super().__init__(classifier, fit_classifier, val_split)
        self.divergence = divergence
        self.bandwidth = KDEBase._check_bandwidth(bandwidth)
        self.random_state = random_state
        self.montecarlo_trials = montecarlo_trials

    def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
        self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.classes_, self.bandwidth)
    def aggregation_fit(self, classif_predictions, labels):
        self.mix_densities = self.get_mixture_components(classif_predictions, labels, self.classes_, self.bandwidth)

        N = self.montecarlo_trials
        rs = self.random_state
        n = data.n_classes
        n = len(self.classes_)
        self.reference_samples = np.vstack([kde_i.sample(N//n, random_state=rs) for kde_i in self.mix_densities])
        self.reference_classwise_densities = np.asarray([self.pdf(kde_j, self.reference_samples) for kde_j in self.mix_densities])
        self.reference_density = np.mean(self.reference_classwise_densities, axis=0)  # equiv. to (uniform @ self.reference_classwise_densities)

@@ -265,20 +266,20 @@ class KDEyCS(AggregativeSoftQuantifier):

    The authors showed that this distribution matching admits a closed-form solution

    :param classifier: a sklearn's Estimator that generates a binary classifier.
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a collection defining the specific set of data to use for validation.
        Alternatively, this set can be specified at fit time by indicating the exact set of data
        on which the predictions are to be generated.
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
    :param bandwidth: float, the bandwidth of the Kernel
    """

    def __init__(self, classifier: BaseEstimator=None, val_split=5, bandwidth=0.1):
        self.classifier = qp._get_classifier(classifier)
        self.val_split = val_split
    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, bandwidth=0.1):
        super().__init__(classifier, fit_classifier, val_split)
        self.bandwidth = KDEBase._check_bandwidth(bandwidth)

    def gram_matrix_mix_sum(self, X, Y=None):

@@ -293,17 +294,17 @@ class KDEyCS(AggregativeSoftQuantifier):
        gram = norm_factor * rbf_kernel(X, Y, gamma=gamma)
        return gram.sum()

    def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
    def aggregation_fit(self, classif_predictions, labels):

        P, y = classif_predictions.Xy
        n = data.n_classes
        P, y = classif_predictions, labels
        n = len(self.classes_)

        assert all(sorted(np.unique(y)) == np.arange(n)), \
            'label name gaps not allowed in current implementation'

        # counts_inv keeps track of the relative weight of each datapoint within its class
        # (i.e., the weight in its KDE model)
        counts_inv = 1 / (data.counts())
        counts_inv = 1 / (F.counts_from_labels(y, classes=self.classes_))

        # tr_tr_sums corresponds to symbol \overline{B} in the paper
        tr_tr_sums = np.zeros(shape=(n,n), dtype=float)

@@ -21,13 +21,13 @@ class QuaNetTrainer(BaseQuantifier):
    Example:

    >>> import quapy as qp
    >>> from quapy.method_name.meta import QuaNet
    >>> from quapy.method.meta import QuaNet
    >>> from quapy.classification.neural import NeuralClassifierTrainer, CNNnet
    >>>
    >>> # use samples of 100 elements
    >>> qp.environ['SAMPLE_SIZE'] = 100
    >>>
    >>> # load the kindle dataset as text, and convert words to numerical indexes
    >>> # load the Kindle dataset as text, and convert words to numerical indexes
    >>> dataset = qp.datasets.fetch_reviews('kindle', pickle=True)
    >>> qp.train.preprocessing.index(dataset, min_df=5, inplace=True)
    >>>

@@ -37,12 +37,14 @@ class QuaNetTrainer(BaseQuantifier):
    >>>
    >>> # train QuaNet (QuaNet is an alias to QuaNetTrainer)
    >>> model = QuaNet(classifier, qp.environ['SAMPLE_SIZE'], device='cuda')
    >>> model.fit(dataset.training)
    >>> model.fit(*dataset.training.Xy)
    >>> estim_prevalence = model.predict(dataset.test.instances)

    :param classifier: an object implementing `fit` (i.e., that can be trained on labelled data),
        `predict_proba` (i.e., that can generate posterior probabilities of unlabelled examples) and
        `transform` (i.e., that can generate embedded representations of the unlabelled instances).
    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.
    :param sample_size: integer, the sample size; default is None, meaning that the sample size should be
        taken from qp.environ["SAMPLE_SIZE"]
    :param n_epochs: integer, maximum number of training epochs

@@ -64,6 +66,7 @@ class QuaNetTrainer(BaseQuantifier):

    def __init__(self,
                 classifier,
                 fit_classifier=True,
                 sample_size=None,
                 n_epochs=100,
                 tr_iter_per_poch=500,

@@ -86,6 +89,7 @@ class QuaNetTrainer(BaseQuantifier):
            f'the classifier {classifier.__class__.__name__} does not seem to be able to produce posterior probabilities ' \
            f'since it does not implement the method "predict_proba"'
        self.classifier = classifier
        self.fit_classifier = fit_classifier
        self.sample_size = qp._get_sample_size(sample_size)
        self.n_epochs = n_epochs
        self.tr_iter = tr_iter_per_poch

@@ -111,20 +115,21 @@ class QuaNetTrainer(BaseQuantifier):
        self.__check_params_colision(self.quanet_params, self.classifier.get_params())
        self._classes_ = None

    def fit(self, data: LabelledCollection, fit_classifier=True):
    def fit(self, X, y):
        """
        Trains QuaNet.

        :param data: the training data on which to train QuaNet. If `fit_classifier=True`, the data will be split in
        :param X: the training instances on which to train QuaNet. If `fit_classifier=True`, the data will be split in
            40/40/20 for training the classifier, training QuaNet, and validating QuaNet, respectively. If
            `fit_classifier=False`, the data will be split in 66/34 for training QuaNet and validating it, respectively.
        :param fit_classifier: if True, trains the classifier on a split containing 40% of the data
        :param y: the labels of X
        :return: self
        """
        data = LabelledCollection(X, y)
        self._classes_ = data.classes_
        os.makedirs(self.checkpointdir, exist_ok=True)

        if fit_classifier:
        if self.fit_classifier:
            classifier_data, unused_data = data.split_stratified(0.4)
            train_data, valid_data = unused_data.split_stratified(0.66)  # 0.66 split of 60% makes 40% and 20%
            self.classifier.fit(*classifier_data.Xy)

@@ -18,18 +18,23 @@ class ThresholdOptimization(BinaryAggregativeQuantifier):
    that would allow for more true positives and many more false positives, on the grounds this
    would deliver larger denominators.

    :param classifier: a sklearn's Estimator that generates a classifier
    :param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
        misclassification rates are to be estimated.
        This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
        validation data, or as an integer, indicating that the misclassification rates should be estimated via
        `k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
        :class:`quapy.data.base.LabelledCollection` (the split itself).
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`

    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.

    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.

    :param n_jobs: number of parallel workers
    """

    def __init__(self, classifier: BaseEstimator=None, val_split=None, n_jobs=None):
        self.classifier = qp._get_classifier(classifier)
        self.val_split = val_split
    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=None, n_jobs=None):
        super().__init__(classifier, fit_classifier, val_split)
        self.n_jobs = qp._get_njobs(n_jobs)

    @abstractmethod

@@ -115,8 +120,8 @@ class ThresholdOptimization(BinaryAggregativeQuantifier):
            return 0
        return FP / (FP + TN)

    def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
        decision_scores, y = classif_predictions.Xy
    def aggregation_fit(self, classif_predictions, labels):
        decision_scores, y = classif_predictions, labels
        # the standard behavior is to keep the best threshold only
        self.tpr, self.fpr, self.threshold = self._eval_candidate_thresholds(decision_scores, y)[0]
        return self

@@ -134,17 +139,22 @@ class T50(ThresholdOptimization):
    for the threshold that makes `tpr` closest to 0.5.
    The goal is to bring improved stability to the denominator of the adjustment.

    :param classifier: a sklearn's Estimator that generates a classifier
    :param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
        misclassification rates are to be estimated.
        This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
        validation data, or as an integer, indicating that the misclassification rates should be estimated via
        `k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
        :class:`quapy.data.base.LabelledCollection` (the split itself).
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`

    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.

    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.

    """

    def __init__(self, classifier: BaseEstimator=None, val_split=5):
        super().__init__(classifier, val_split)
    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5):
        super().__init__(classifier, fit_classifier, val_split)

    def condition(self, tpr, fpr) -> float:
        return abs(tpr - 0.5)

@@ -158,17 +168,20 @@ class MAX(ThresholdOptimization):
    for the threshold that maximizes `tpr-fpr`.
    The goal is to bring improved stability to the denominator of the adjustment.

    :param classifier: a sklearn's Estimator that generates a classifier
    :param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
        misclassification rates are to be estimated.
        This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
        validation data, or as an integer, indicating that the misclassification rates should be estimated via
        `k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
        :class:`quapy.data.base.LabelledCollection` (the split itself).
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.

    """

    def __init__(self, classifier: BaseEstimator=None, val_split=5):
        super().__init__(classifier, val_split)
    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5):
        super().__init__(classifier, fit_classifier, val_split)

    def condition(self, tpr, fpr) -> float:
        # MAX strives to maximize (tpr - fpr), which is equivalent to minimize (fpr - tpr)

@@ -183,17 +196,20 @@ class X(ThresholdOptimization):
    for the threshold that yields `tpr=1-fpr`.
    The goal is to bring improved stability to the denominator of the adjustment.

    :param classifier: a sklearn's Estimator that generates a classifier
    :param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
        misclassification rates are to be estimated.
        This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
        validation data, or as an integer, indicating that the misclassification rates should be estimated via
        `k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
        :class:`quapy.data.base.LabelledCollection` (the split itself).
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.

    """

    def __init__(self, classifier: BaseEstimator=None, val_split=5):
        super().__init__(classifier, val_split)
    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5):
        super().__init__(classifier, fit_classifier, val_split)

    def condition(self, tpr, fpr) -> float:
        return abs(1 - (tpr + fpr))

@@ -207,22 +223,25 @@ class MS(ThresholdOptimization):
    class prevalence estimates for all decision thresholds and returns the median of them all.
    The goal is to bring improved stability to the denominator of the adjustment.

    :param classifier: a sklearn's Estimator that generates a classifier
    :param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
        misclassification rates are to be estimated.
        This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
        validation data, or as an integer, indicating that the misclassification rates should be estimated via
        `k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
        :class:`quapy.data.base.LabelledCollection` (the split itself).
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
    """
    def __init__(self, classifier: BaseEstimator=None, val_split=5):
        super().__init__(classifier, val_split)

    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5):
        super().__init__(classifier, fit_classifier, val_split)

    def condition(self, tpr, fpr) -> float:
        return 1

    def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
        decision_scores, y = classif_predictions.Xy
    def aggregation_fit(self, classif_predictions, labels):
        decision_scores, y = classif_predictions, labels
        # keeps all candidates
        tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y)
        self.tprs = tprs_fprs_thresholds[:, 0]

@@ -246,16 +265,19 @@ class MS2(MS):
    which `tpr-fpr>0.25`
    The goal is to bring improved stability to the denominator of the adjustment.

    :param classifier: a sklearn's Estimator that generates a classifier
    :param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
        misclassification rates are to be estimated.
        This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
        validation data, or as an integer, indicating that the misclassification rates should be estimated via
        `k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
        :class:`quapy.data.base.LabelledCollection` (the split itself).
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
    """
    def __init__(self, classifier: BaseEstimator=None, val_split=5):
        super().__init__(classifier, val_split)

    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5):
        super().__init__(classifier, fit_classifier, val_split)

    def discard(self, tpr, fpr) -> bool:
        return (tpr-fpr) <= 0.25
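
To summarize the selection policies implemented by these subclasses, each candidate threshold is scored by condition(tpr, fpr), lowest score first; a plain-Python paraphrase (not library code):

    t50         = lambda tpr, fpr: abs(tpr - 0.5)        # T50: tpr closest to 0.5
    max_crit    = lambda tpr, fpr: fpr - tpr             # MAX: minimizing fpr-tpr maximizes tpr-fpr
    x_crit      = lambda tpr, fpr: abs(1 - (tpr + fpr))  # X: looks for tpr = 1 - fpr
    ms_crit     = lambda tpr, fpr: 1                     # MS: keeps all thresholds, takes the median estimate
    ms2_discard = lambda tpr, fpr: (tpr - fpr) <= 0.25   # MS2: additionally discards weak thresholds
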
@@ -3,6 +3,7 @@ from copy import deepcopy
from typing import Callable, Literal, Union
import numpy as np
from abstention.calibration import NoBiasVectorScaling, TempScaling, VectorScaling
from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.exceptions import NotFittedError

@@ -36,6 +37,18 @@ class AggregativeQuantifier(BaseQuantifier, ABC):

    The method :meth:`quantify` comes with a default implementation based on :meth:`classify`
    and :meth:`aggregate`.

    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a tuple `(X,y)` defining the specific set of data to use for validation. Set to
        None when the method does not require any validation data, in order to avoid wasting a portion of
        the training data.
    """

    def __init__(self, classifier: Union[None,BaseEstimator], fit_classifier:bool=True, val_split:Union[int,float,tuple,None]=5):
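
To illustrate the constructor-time convention, a hypothetical minimal subclass (MyQuantifier is not part of the library):

    class MyQuantifier(AggregativeQuantifier):
        def __init__(self, classifier=None, fit_classifier=True, val_split=5):
            super().__init__(classifier, fit_classifier, val_split)

        def aggregation_fit(self, classif_predictions, labels):
            # learn any statistics the aggregation function needs from the validation predictions
            return self

        def aggregate(self, classif_predictions):
            # map a sample's classifier predictions onto a prevalence vector
            ...
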
@@ -116,34 +129,13 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        Trains the aggregative quantifier. This comes down to training a classifier (if requested) and an
        aggregation function.

        :param X: array-like, the training instances
        :param y: array-like, the labels
        :param X: array-like of shape `(n_samples, n_features)`, the training instances
        :param y: array-like of shape `(n_samples,)`, the labels
        :return: self
        """
        self._check_init_parameters()
        classif_predictions = self.classifier_fit_predict(X, y)
        self.aggregation_fit(classif_predictions)
        return self

    def fit_depr(self, data: LabelledCollection, fit_classifier=True, val_split=None):
        """
        Trains the aggregative quantifier. This comes down to training a classifier and an aggregation function.

        :param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
        :param fit_classifier: whether to train the learner (default is True). Set to False if the
            learner has been trained outside the quantifier.
        :param val_split: specifies the data used for generating classifier predictions. This specification
            can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
            be extracted from the training set; or as an integer (default 5), indicating that the predictions
            are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
            for `k`); or as a collection defining the specific set of data to use for validation.
            Alternatively, this set can be specified at fit time by indicating the exact set of data
            on which the predictions are to be generated.
        :return: self
        """
        self._check_init_parameters()
        classif_predictions = self.classifier_fit_predict_depr(data, fit_classifier, predict_on=val_split)
        self.aggregation_fit_depr(classif_predictions, data)
        classif_predictions, labels = self.classifier_fit_predict(X, y)
        self.aggregation_fit(classif_predictions, labels)
        return self

    def classifier_fit_predict(self, X, y):

@@ -151,19 +143,20 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        Trains the classifier if requested (`fit_classifier=True`) and generates the necessary predictions to
        train the aggregation function.

        :param X: array-like, the training instances
        :param y: array-like, the labels
        :param X: array-like of shape `(n_samples, n_features)`, the training instances
        :param y: array-like of shape `(n_samples,)`, the labels
        """
        self._check_classifier()

        # self._check_non_empty_classes(y)

        predictions, labels = None, None
        if isinstance(self.val_split, int):
            assert self.fit_classifier, f'unexpected value for {self.fit_classifier=}'
            num_folds = self.val_split
            n_jobs = self.n_jobs if hasattr(self, 'n_jobs') else qp._get_njobs(None)
            predictions = cross_val_predict(self.classifier, X, y, cv=num_folds, n_jobs=n_jobs, method=self._classifier_method())
            yval = y
            labels = y
            self.classifier.fit(X, y)
        elif isinstance(self.val_split, float):
            assert self.fit_classifier, f'unexpected value for {self.fit_classifier=}'

@@ -171,26 +164,30 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
            Xtr, Xval, ytr, yval = train_test_split(X, y, train_size=train_prop, stratify=y)
            self.classifier.fit(Xtr, ytr)
            predictions = self.classify(Xval)
            labels = yval
        elif isinstance(self.val_split, tuple):
            Xval, yval = self.val_split
            if self.fit_classifier:
                self.classifier.fit(X, y)
            predictions = self.classify(Xval)
            labels = yval
        elif self.val_split is None:
            if self.fit_classifier:
                self.classifier.fit(X, y)
            predictions, yval = None, None
            predictions, labels = None, None
        else:
            raise ValueError(f'unexpected type for {self.val_split=}')

        return predictions, yval
        return predictions, labels

    @abstractmethod
    def aggregation_fit(self, classif_predictions, **kwargs):
    def aggregation_fit(self, classif_predictions, labels):
        """
        Trains the aggregation function.

        :param classif_predictions: the classification predictions; whatever the method
            :meth:`classify` returns
        :param classif_predictions: array-like with the classification predictions
            (whatever the method :meth:`classify` returns)
        :param labels: array-like with the true labels associated to each classifier prediction
        """
        ...
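
As a usage sketch of the branching above (clf is assumed to be a classifier fitted elsewhere, and (X_val, y_val) a held-out validation set):

    from quapy.method.aggregative import ACC

    acc = ACC(classifier=clf, fit_classifier=False, val_split=(X_val, y_val))
    acc.fit(X_train, y_train)  # clf is not refitted; validation predictions are generated on (X_val, y_val)
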
@@ -218,8 +215,8 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        :meth:`aggregate`, e.g., posterior probabilities for probabilistic quantifiers, or crisp predictions for
        non-probabilistic quantifiers. The default one is "decision_function".

        :param X: array-like of shape `(n_instances, n_features,)`
        :return: np.ndarray of shape `(n_instances,)` with label predictions
        :param X: array-like of shape `(n_samples, n_features)`, the data instances
        :return: np.ndarray of shape `(n_instances,)` with classifier predictions
        """
        return getattr(self.classifier, self._classifier_method())(X)

@@ -236,7 +233,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        Guarantees that the underlying classifier implements the method required for issuing predictions, i.e.,
        the method indicated by the :meth:`_classifier_method`

        :param adapt_if_necessary: unused unless overriden
        :param adapt_if_necessary: unused unless overridden
        """
        assert hasattr(self.classifier, self._classifier_method()), \
            f"the method does not implement the required {self._classifier_method()} method"

@@ -246,7 +243,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        Generate class prevalence estimates for the sample's instances by aggregating the label predictions generated
        by the classifier.

        :param X: array-like
        :param X: array-like of shape `(n_samples, n_features)`, the data instances
        :return: `np.ndarray` of shape `(n_classes)` with class prevalence estimates.
        """
        classif_predictions = self.classify(X)

@@ -257,7 +254,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        """
        Implements the aggregation of the classifier predictions.

        :param classif_predictions: `np.ndarray` of label predictions
        :param classif_predictions: `np.ndarray` of classifier predictions
        :return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
        """
        ...

@@ -268,7 +265,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        Class labels, in the same order in which class prevalence values are to be computed.
        This default implementation actually returns the class labels of the learner.

        :return: array-like
        :return: array-like, the class labels
        """
        return self.classifier.classes_

@@ -356,11 +353,12 @@ class CC(AggregativeCrispQuantifier):
    def __init__(self, classifier: BaseEstimator = None, fit_classifier: bool = True):
        super().__init__(classifier, fit_classifier, val_split=None)

    def aggregation_fit(self, classif_predictions):
    def aggregation_fit(self, classif_predictions, labels):
        """
        Nothing to do here!

        :param classif_predictions: not used
        :param classif_predictions: unused
        :param labels: unused
        """
        pass

@@ -368,7 +366,7 @@ class CC(AggregativeCrispQuantifier):
        """
        Computes class prevalence estimates by counting the prevalence of each of the predicted labels.

        :param classif_predictions: array-like with label predictions
        :param classif_predictions: array-like with classifier predictions
        :return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
        """
        return F.prevalence_from_labels(classif_predictions, self.classes_)

@@ -385,11 +383,12 @@ class PCC(AggregativeSoftQuantifier):
    def __init__(self, classifier: BaseEstimator = None, fit_classifier: bool = True):
        super().__init__(classifier, fit_classifier, val_split=None)

    def aggregation_fit(self, classif_predictions):
    def aggregation_fit(self, classif_predictions, labels):
        """
        Nothing to do here!

        :param classif_predictions: not used
        :param classif_predictions: unused
        :param labels: unused
        """
        pass

@@ -403,15 +402,17 @@ class ACC(AggregativeCrispQuantifier):
    the "adjusted" variant of :class:`CC`, that corrects the predictions of CC
    according to the `misclassification rates`.

    :param classifier: a sklearn's Estimator that generates a classifier
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`

    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.

    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a collection defining the specific set of data to use for validation.
        Alternatively, this set can be specified at fit time by indicating the exact set of data
        on which the predictions are to be generated.
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.

    :param str method: adjustment method to be used:

@@ -470,16 +471,20 @@ class ACC(AggregativeCrispQuantifier):
        `Vaz et al. 2018 <https://jmlr.org/papers/v20/18-456.html>`_. This amounts
        to setting method to 'invariant-ratio' and clipping to 'project'.

        :param classifier: a sklearn's Estimator that generates a classifier
        :param fit_classifier: bool, whether to fit the classifier or not
        :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
            the one indicated in `qp.environ['DEFAULT_CLS']`

        :param fit_classifier: whether to train the learner (default is True). Set to False if the
            learner has been trained outside the quantifier.

        :param val_split: specifies the data used for generating classifier predictions. This specification
            can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
            be extracted from the training set; or as an integer (default 5), indicating that the predictions
            are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
            for `k`); or as a collection defining the specific set of data to use for validation.
            Alternatively, this set can be specified at fit time by indicating the exact set of data
            on which the predictions are to be generated.
            for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.

        :param n_jobs: number of parallel workers

        :return: an instance of ACC configured so that it implements the Invariant Ratio Estimator
        """
        return ACC(classifier, fit_classifier=fit_classifier, val_split=val_split, method='invariant-ratio', norm='mapsimplex', n_jobs=n_jobs)

@@ -492,15 +497,14 @@ class ACC(AggregativeCrispQuantifier):
        if self.norm not in ACC.NORMALIZATIONS:
            raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")

    def aggregation_fit(self, classif_predictions):
    def aggregation_fit(self, classif_predictions, labels):
        """
        Estimates the misclassification rates.

        :param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
            as instances, the label predictions issued by the classifier and, as labels, the true labels
        :param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
        :param classif_predictions: array-like with the predicted labels
        :param labels: array-like with the true labels associated to each predicted label
        """
        pred_labels, true_labels = classif_predictions
        true_labels = labels
        pred_labels = classif_predictions
        self.cc = CC(self.classifier, fit_classifier=False)
        self.Pte_cond_estim_ = ACC.getPteCondEstim(self.classifier.classes_, true_labels, pred_labels)

@@ -541,16 +545,17 @@ class PACC(AggregativeSoftQuantifier):
    `Probabilistic Adjusted Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,
    the probabilistic variant of ACC that relies on the posterior probabilities returned by a probabilistic classifier.

    :param classifier: a sklearn's Estimator that generates a classifier
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`

    :param fit_classifier: bool, whether to fit the classifier or not
    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.

    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`). Alternatively, this set can be specified at fit time by indicating the exact set of data
        on which the predictions are to be generated.
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.

    :param str method: adjustment method to be used:

@@ -606,15 +611,15 @@ class PACC(AggregativeSoftQuantifier):
        if self.norm not in ACC.NORMALIZATIONS:
            raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")

    def aggregation_fit(self, classif_predictions):
    def aggregation_fit(self, classif_predictions, labels):
        """
        Estimates the misclassification rates

        :param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
            as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
        :param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
        :param classif_predictions: array-like with posterior probabilities
        :param labels: array-like with the true labels associated to each vector of posterior probabilities
        """
        posteriors, true_labels = classif_predictions
        posteriors = classif_predictions
        true_labels = labels
        self.pcc = PCC(self.classifier, fit_classifier=False)
        self.Pte_cond_estim_ = PACC.getPteCondEstim(self.classifier.classes_, true_labels, posteriors)

@@ -656,61 +661,98 @@ class EMQ(AggregativeSoftQuantifier):
    prevalence, an estimate of it obtained via k-fold cross validation (instead of the true training prevalence),
    and to recalibrate the posterior probabilities of the classifier.

    :param classifier: a sklearn's Estimator that generates a classifier
    :param fit_classifier: bool, whether to fit the classifier or not
    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`

    :param fit_classifier: whether to train the learner (default is True). Set to False if the
        learner has been trained outside the quantifier.

    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer, indicating that the predictions
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`, default 5); or as a collection defining the specific set of data to use for validation.
        Alternatively, this set can be specified at fit time by indicating the exact set of data
        on which the predictions are to be generated. This hyperparameter is only meant to be used when the
        heuristics are to be applied, i.e., if a recalibration is required. The default value is None (meaning
        the recalibration is not required). In case this hyperparameter is set to a value other than None, but
        the recalibration is not required (recalib=None), a warning message will be raised.
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
        This hyperparameter is only meant to be used when the heuristics are to be applied, i.e., if a
        calibration is required. The default value is None (meaning the calibration is not required). In
        case this hyperparameter is set to a value other than None, but the calibration is not required
        (calib=None), a warning message will be raised.

    :param exact_train_prev: set to True (default) for using the true training prevalence as the initial observation;
        set to False for computing the training prevalence as an estimate of it, i.e., as the expected
        value of the posterior probabilities of the training instances.

    :param recalib: a string indicating the method of recalibration.
    :param calib: a string indicating the method of calibration.
        Available choices include "nbvs" (No-Bias Vector Scaling), "bcts" (Bias-Corrected Temperature Scaling,
        default), "ts" (Temperature Scaling), and "vs" (Vector Scaling). Default is None (no recalibration).
        default), "ts" (Temperature Scaling), and "vs" (Vector Scaling). Default is None (no calibration).

    :param on_calib_error: a string indicating the policy to follow in case the calibrator fails at runtime.
        Options include "raise" (default), in which case a RuntimeError is raised; and "backup", in which
        case the calibrator is silently skipped.

    :param n_jobs: number of parallel workers. Only used for recalibrating the classifier if `val_split` is set to
        an integer `k` --the number of folds.
    """

    MAX_ITER = 1000
    EPSILON = 1e-4
    ON_CALIB_ERROR_VALUES = ['raise', 'backup']
    CALIB_OPTIONS = [None, 'nbvs', 'bcts', 'ts', 'vs']

    def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=None, exact_train_prev=True,
                 calib=None, on_calib_error='raise', n_jobs=None):

        assert calib in EMQ.CALIB_OPTIONS, \
            f'invalid value for {calib=}; valid ones are {EMQ.CALIB_OPTIONS}'
        assert on_calib_error in EMQ.ON_CALIB_ERROR_VALUES, \
            f'invalid value for {on_calib_error=}; valid ones are {EMQ.ON_CALIB_ERROR_VALUES}'

    def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=None, exact_train_prev=True, recalib=None,
                 n_jobs=None):
        super().__init__(classifier, fit_classifier, val_split)
        self.exact_train_prev = exact_train_prev
        self.recalib = recalib
        self.calib = calib
        self.on_calib_errors = on_calib_error
        self.n_jobs = n_jobs

    @classmethod
    def EMQ_BCTS(cls, classifier: BaseEstimator, fit_classifier=True, val_split=5, n_jobs=None):
    def EMQ_BCTS(cls, classifier: BaseEstimator, fit_classifier=True, val_split=5, on_calib_error="raise", n_jobs=None):
        """
        Constructs an instance of EMQ using the best configuration found in the `Alexandari et al. paper
        <http://proceedings.mlr.press/v119/alexandari20a.html>`_, i.e., one that relies on Bias-Corrected Temperature
        Scaling (BCTS) as a recalibration function, and that uses an estimate of the training prevalence instead of
        Scaling (BCTS) as a calibration function, and that uses an estimate of the training prevalence instead of
        the true training prevalence.

        :param classifier: a sklearn's Estimator that generates a classifier
        :param n_jobs: number of parallel workers.
        :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
            the one indicated in `qp.environ['DEFAULT_CLS']`

        :param fit_classifier: whether to train the learner (default is True). Set to False if the
            learner has been trained outside the quantifier.

        :param val_split: specifies the data used for generating classifier predictions. This specification
            can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
            be extracted from the training set; or as an integer (default 5), indicating that the predictions
            are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
            for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.

        :param on_calib_error: a string indicating the policy to follow in case the calibrator fails at runtime.
            Options include "raise" (default), in which case a RuntimeError is raised; and "backup", in which
            case the calibrator is silently skipped.

        :param n_jobs: number of parallel workers. Only used for recalibrating the classifier if `val_split` is set to
            an integer `k` --the number of folds.

        :return: An instance of EMQ with BCTS
        """
        return EMQ(classifier, fit_classifier=fit_classifier, val_split=val_split, exact_train_prev=False, recalib='bcts', n_jobs=n_jobs)
        return EMQ(classifier, fit_classifier=fit_classifier, val_split=val_split, exact_train_prev=False,
                   calib='bcts', on_calib_error=on_calib_error, n_jobs=n_jobs)
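
For example (a sketch with placeholder data arrays):

    from sklearn.linear_model import LogisticRegression

    emq = EMQ.EMQ_BCTS(LogisticRegression(), val_split=5, on_calib_error='backup')
    emq.fit(X_train, y_train)
    prev_estim = emq.predict(X_test)
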
    def _check_init_parameters(self):
        if self.val_split is not None:
            if self.exact_train_prev and self.recalib is None:
            if self.exact_train_prev and self.calib is None:
                raise RuntimeWarning(f'The parameter {self.val_split=} was specified for EMQ, while the parameters '
                                     f'{self.exact_train_prev=} and {self.recalib=}. This has no effect and causes an unnecessary '
                                     f'{self.exact_train_prev=} and {self.calib=}. This has no effect and causes an unnecessary '
                                     f'overload.')
        else:
            if self.recalib is not None:
                print(f'[warning] The parameter {self.recalib=} requires the val_split be different from None. '
            if self.calib is not None:
                print(f'[warning] The parameter {self.calib=} requires the val_split be different from None. '
                      f'This parameter will be set to 5. To avoid this warning, set this value to a float value '
                      f'indicating the proportion of training data to be used as validation, or to an integer '
                      f'indicating the number of folds for kFCV.')

@ -718,56 +760,78 @@ class EMQ(AggregativeSoftQuantifier):
|
|||
|
||||
def classify(self, X):
|
||||
"""
|
||||
Provides the posterior probabilities for the given instances. If the classifier was required
|
||||
to be recalibrated, then these posteriors are recalibrated accordingly.
|
||||
Provides the posterior probabilities for the given instances. The calibration function, if required,
|
||||
has no effect in this step, and is only involved in the aggregate method.
|
||||
|
||||
:param X: array-like of shape `(n_instances, n_dimensions,)`
|
||||
:return: np.ndarray of shape `(n_instances, n_classes,)` with posterior probabilities
|
||||
"""
|
||||
posteriors = self.classifier.predict_proba(X)
|
||||
if hasattr(self, 'calibration_function') and self.calibration_function is not None:
|
||||
posteriors = self.calibration_function(posteriors)
|
||||
return posteriors
|
||||
return self.classifier.predict_proba(X)
|
||||
|
||||
def classifier_fit_predict(self, X, y):
|
||||
classif_predictions = super().classifier_fit_predict(X, y)
|
||||
self.train_prevalence = F.prevalence_from_labels(y, classes=self.classes_)
|
||||
return classif_predictions
|
||||
|
||||
def aggregation_fit(self, classif_predictions):
|
||||
def _fit_calibration(self, calibrator, P, y):
|
||||
n_classes = len(self.classes_)
|
||||
|
||||
if not np.issubdtype(y.dtype, np.number):
|
||||
y = np.searchsorted(self.classes_, y)
|
||||
|
||||
try:
|
||||
self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True)
|
||||
except Exception as e:
|
||||
if self.on_calib_errors == 'raise':
|
||||
raise RuntimeError(f'calibration {self.calib} failed at fit time: {e}')
|
||||
elif self.on_calib_errors == 'backup':
|
||||
self.calibration_function = lambda P: P
def _calibrate_if_requested(self, uncalib_posteriors):
if hasattr(self, 'calibration_function') and self.calibration_function is not None:
try:
calib_posteriors = self.calibration_function(uncalib_posteriors)
except Exception as e:
if self.on_calib_error == 'raise':
raise RuntimeError(f'calibration {self.calib} failed at predict time: {e}')
elif self.on_calib_error == 'backup':
calib_posteriors = uncalib_posteriors
else:
raise ValueError(f'unexpected {self.on_calib_error=}; '
f'valid options are {EMQ.ON_CALIB_ERROR_VALUES}')
return calib_posteriors
return uncalib_posteriors
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function of EMQ. This comes down to recalibrating the posterior probabilities
if requested.

:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param classif_predictions: array-like with the raw (i.e., uncalibrated) posterior probabilities
returned by the classifier
:param labels: array-like with the true labels associated to each classifier prediction
"""
P, y = classif_predictions
n_classes = len(self.classes_)
if self.recalib is not None:
if self.recalib == 'nbvs':
calibrator = NoBiasVectorScaling()
elif self.recalib == 'bcts':
calibrator = TempScaling(bias_positions='all')
elif self.recalib == 'ts':
calibrator = TempScaling()
elif self.recalib == 'vs':
calibrator = VectorScaling()
else:
raise ValueError('invalid param argument for recalibration method; available ones are '
'"nbvs", "bcts", "ts", and "vs".')
P = classif_predictions
y = labels
if self.calib is not None:
calibrator = {
'nbvs': NoBiasVectorScaling(),
'bcts': TempScaling(bias_positions='all'),
'ts': TempScaling(),
'vs': VectorScaling()
}.get(self.calib, None)

if not np.issubdtype(y.dtype, np.number):
y = np.searchsorted(self.classes_, y)
self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True)
if calibrator is None:
raise ValueError(f'invalid value for {self.calib=}; valid ones are {EMQ.CALIB_OPTIONS}')

self._fit_calibration(calibrator, P, y)

if not self.exact_train_prev:
train_posteriors = classif_predictions.X
if self.recalib is not None:
train_posteriors = self.calibration_function(train_posteriors)
self.train_prevalence = F.prevalence_from_probabilities(train_posteriors)
P = self._calibrate_if_requested(P)
self.train_prevalence = F.prevalence_from_probabilities(P)
def aggregate(self, classif_posteriors, epsilon=EPSILON):
classif_posteriors = self._calibrate_if_requested(classif_posteriors)
priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
return priors
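The `EM` routine invoked above implements the classic expectation-maximization prior adjustment of Saerens et al.; a self-contained sketch of the update (illustrative names, not the library's exact implementation):

import numpy as np

def em_adjustment(train_prev, posteriors, epsilon=1e-4, max_iter=1000):
    qs = np.copy(train_prev)                     # current prevalence estimate
    ps = np.copy(posteriors)
    for _ in range(max_iter):
        ps = posteriors * (qs / train_prev)      # E-step: reweight posteriors by the prior ratio
        ps /= ps.sum(axis=1, keepdims=True)      # renormalize each instance
        new_qs = ps.mean(axis=0)                 # M-step: re-estimate the prevalence
        if np.abs(new_qs - qs).sum() < epsilon:  # stop once the estimate stabilizes
            break
        qs = new_qs
    return qs, ps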
@@ -780,6 +844,7 @@ class EMQ(AggregativeSoftQuantifier):
:return: np.ndarray of shape `(n_instances, n_classes)`
"""
classif_posteriors = self.classify(instances)
classif_posteriors = self._calibrate_if_requested(classif_posteriors)
priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
return posteriors

@@ -827,101 +892,6 @@ class EMQ(AggregativeSoftQuantifier):
return qs, ps
class BayesianCC(AggregativeCrispQuantifier):
"""
`Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ method,
which is a variant of :class:`ACC` that calculates the posterior probability distribution
over the prevalence vectors, rather than providing a point estimate obtained
by matrix inversion.

Can be used to diagnose degeneracy in the predictions visible when the confusion
matrix has high condition number or to quantify uncertainty around the point estimate.

This method relies on extra dependencies, which have to be installed via:
`$ pip install quapy[bayes]`

:param classifier: a sklearn's Estimator that generates a classifier
:param val_split: a float in (0, 1) indicating the proportion of the training data to be used,
as a stratified held-out validation set, for generating classifier predictions.
:param num_warmup: number of warmup iterations for the MCMC sampler (default 500)
:param num_samples: number of samples to draw from the posterior (default 1000)
:param mcmc_seed: random seed for the MCMC sampler (default 0)
"""

def __init__(self,
classifier: BaseEstimator = None,
val_split: float = 0.75,
num_warmup: int = 500,
num_samples: int = 1_000,
mcmc_seed: int = 0):

if num_warmup <= 0:
raise ValueError(f'parameter {num_warmup=} must be a positive integer')
if num_samples <= 0:
raise ValueError(f'parameter {num_samples=} must be a positive integer')

if (not isinstance(val_split, float)) or val_split <= 0 or val_split >= 1:
raise ValueError(f'val_split must be a float in (0, 1), got {val_split}')

if _bayesian.DEPENDENCIES_INSTALLED is False:
raise ImportError("Auxiliary dependencies are required. Run `$ pip install quapy[bayes]` to install them.")

self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
self.num_warmup = num_warmup
self.num_samples = num_samples
self.mcmc_seed = mcmc_seed

# Array of shape (n_classes, n_predicted_classes,) where entry (y, c) is the number of instances
# labeled as class y and predicted as class c.
# By default, this array is set to None and later defined as part of the `aggregation_fit` phase
self._n_and_c_labeled = None

# Dictionary with posterior samples, set when `aggregate` is provided.
self._samples = None

def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
"""
Estimates the misclassification rates.

:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the label predictions issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
"""
pred_labels, true_labels = classif_predictions.Xy
self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels,
labels=self.classifier.classes_)

def sample_from_posterior(self, classif_predictions):
if self._n_and_c_labeled is None:
raise ValueError("aggregation_fit must be called before sample_from_posterior")

n_c_unlabeled = F.counts_from_labels(classif_predictions, self.classifier.classes_)

self._samples = _bayesian.sample_posterior(
n_c_unlabeled=n_c_unlabeled,
n_y_and_c_labeled=self._n_and_c_labeled,
num_warmup=self.num_warmup,
num_samples=self.num_samples,
seed=self.mcmc_seed,
)
return self._samples

def get_prevalence_samples(self):
if self._samples is None:
raise ValueError("sample_from_posterior must be called before get_prevalence_samples")
return self._samples[_bayesian.P_TEST_Y]

def get_conditional_probability_samples(self):
if self._samples is None:
raise ValueError("sample_from_posterior must be called before get_conditional_probability_samples")
return self._samples[_bayesian.P_C_COND_Y]

def aggregate(self, classif_predictions):
samples = self.sample_from_posterior(classif_predictions)[_bayesian.P_TEST_Y]
return np.asarray(samples.mean(axis=0), dtype=float)
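This class is removed here and now lives alongside the other confidence-oriented methods (see the confidence.py hunks below). A hedged usage sketch, assuming the refactored (X, y) interface applies to it as well and that the optional dependencies are installed (`Xtr`, `ytr`, `Xte` are assumed arrays):

from sklearn.linear_model import LogisticRegression
from quapy.method.confidence import BayesianCC

bcc = BayesianCC(LogisticRegression(), val_split=0.75)
bcc.fit(Xtr, ytr)                       # (X, y) replaces the LabelledCollection argument
prev = bcc.predict(Xte)                 # posterior mean of the prevalence vector
samples = bcc.get_prevalence_samples()  # MCMC samples, e.g., for credible intervals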
class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
"""
`Hellinger Distance y <https://www.sciencedirect.com/science/article/pii/S0020025512004069>`_ (HDy).
@@ -932,24 +902,30 @@ class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
class-conditional distributions of the posterior probabilities returned for the positive and negative validation
examples, respectively. The parameters of the mixture thus represent the estimates of the class prevalence values.

:param classifier: a sklearn's Estimator that generates a binary classifier
:param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself), or an integer indicating the number of folds (default 5).
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
"""
def __init__(self, classifier: BaseEstimator = None, val_split=5):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=5):
super().__init__(classifier, fit_classifier, val_split)

def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function of HDy.

:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the posterior probabilities returned by the classifier
:param labels: array-like with the true labels associated to each posterior
"""
P, y = classif_predictions.Xy
P, y = classif_predictions, labels
Px = P[:, self.pos_label]  # takes only the P(y=+1|x)
self.Pxy1 = Px[y == self.pos_label]
self.Pxy0 = Px[y == self.neg_label]
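The aggregation step (not shown in this hunk) searches for the mixture parameter that minimizes the Hellinger distance between the mixed validation histograms and the test histogram. A simplified sketch of that idea, using a fixed grid of candidate prevalences rather than the binning strategy of the original method (illustrative names):

import numpy as np

def hellinger(p, q):
    # Hellinger distance between two discrete distributions
    return np.sqrt(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / np.sqrt(2)

def hdy_search(Pxy1, Pxy0, Px_test, n_bins=8):
    # mixes the positive and negative validation histograms with weight alpha
    # and compares the mixture against the test histogram
    bins = np.linspace(0, 1, n_bins + 1)
    h1 = np.histogram(Pxy1, bins=bins)[0]; h1 = h1 / h1.sum()
    h0 = np.histogram(Pxy0, bins=bins)[0]; h0 = h0 / h0.sum()
    ht = np.histogram(Px_test, bins=bins)[0]; ht = ht / ht.sum()
    alphas = np.linspace(0, 1, 101)
    divergences = [hellinger(a * h1 + (1 - a) * h0, ht) for a in alphas]
    return alphas[np.argmin(divergences)]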
@@ -1003,20 +979,31 @@ class DyS(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
minimizes the distance between distributions.
Details of the ternary search are taken from <https://dl.acm.org/doi/pdf/10.1145/3219819.3220059>
:param classifier: a sklearn's Estimator that generates a binary classifier
:param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself), or an integer indicating the number of folds (default 5).
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param n_bins: an int with the number of bins to use to compute the histograms.
:param divergence: a str indicating the name of divergence (currently supported ones are "HD" or "topsoe"), or a
callable function that computes the divergence between two distributions (two equally sized arrays).
:param tol: a float with the tolerance for the ternary search algorithm.
:param n_jobs: number of parallel workers.
"""
def __init__(self, classifier: BaseEstimator = None, val_split=5, n_bins=8, divergence: Union[str, Callable] = 'HD',
tol=1e-05, n_jobs=None):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=5, n_bins=8,
divergence: Union[str, Callable] = 'HD', tol=1e-05, n_jobs=None):
super().__init__(classifier, fit_classifier, val_split)
self.tol = tol
self.divergence = divergence
self.n_bins = n_bins
@@ -1038,15 +1025,14 @@ class DyS(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
# Left and right are the current bounds; the maximum is between them
return (left + right) / 2
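For reference, a generic ternary-search sketch consistent with the `tol` parameter documented above; shown here as minimization of the divergence over a unimodal objective (the bracketing logic is symmetric for maximization; names are illustrative):

def ternary_search(f, left=0.0, right=1.0, tol=1e-5):
    # narrows the bracket [left, right] around the optimum of a unimodal function f
    while abs(right - left) >= tol:
        left_third = left + (right - left) / 3
        right_third = right - (right - left) / 3
        if f(left_third) > f(right_third):
            left = left_third
        else:
            right = right_third
    # left and right are the final bounds; the optimum lies between them
    return (left + right) / 2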
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function of DyS.

:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the posterior probabilities returned by the classifier
:param labels: array-like with the true labels associated to each posterior
"""
Px, y = classif_predictions.Xy
Px, y = classif_predictions, labels
Px = Px[:, self.pos_label]  # takes only the P(y=+1|x)
self.Pxy1 = Px[y == self.pos_label]
self.Pxy0 = Px[y == self.neg_label]
@@ -1074,24 +1060,30 @@ class SMM(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
SMM is a simplification of matching distribution methods where the representation of the examples
is created using the mean instead of a histogram (conceptually equivalent to PACC).

:param classifier: a sklearn's Estimator that generates a binary classifier.
:param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself), or an integer indicating the number of folds (default 5).
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
"""
def __init__(self, classifier: BaseEstimator = None, val_split=5):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=5):
super().__init__(classifier, fit_classifier, val_split)

def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function of SMM.

:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the posterior probabilities returned by the classifier
:param labels: array-like with the true labels associated to each posterior
"""
Px, y = classif_predictions.Xy
Px, y = classif_predictions, labels
Px = Px[:, self.pos_label]  # takes only the P(y=+1|x)
self.Pxy1 = Px[y == self.pos_label]
self.Pxy0 = Px[y == self.neg_label]
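Since SMM matches means rather than histograms, its aggregation reduces to a closed form; a minimal sketch of that idea (illustrative names, not the library code):

import numpy as np

def smm_aggregate(Pxy1, Pxy0, Px_test):
    # solve mean(test) = prev * mean(pos) + (1 - prev) * mean(neg) for prev
    mu1, mu0, mu = Pxy1.mean(), Pxy0.mean(), Px_test.mean()
    prev_pos = (mu - mu0) / (mu1 - mu0)
    return np.clip(prev_pos, 0., 1.)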
@@ -1113,25 +1105,32 @@ class DMy(AggregativeSoftQuantifier):
probabilities. This implementation takes the number of bins, the divergence, and the possibility to work on CDF
as hyperparameters.

:param classifier: a `sklearn`'s Estimator that generates a probabilistic classifier
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set to model the
validation distribution.
This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
validation data, or as an integer, indicating that the validation distribution should be estimated via
`k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
:class:`quapy.data.base.LabelledCollection` (the split itself).
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.

:param nbins: number of bins used to discretize the distributions (default 8)
:param divergence: a string representing a divergence measure (currently, "HD" and "topsoe" are implemented)
or a callable function taking two ndarrays of the same dimension as input (default "HD", meaning Hellinger
Distance)
:param cdf: whether to use CDF instead of PDF (default False)
:param n_jobs: number of parallel workers (default None)
"""
def __init__(self, classifier: BaseEstimator = None, val_split=5, nbins=8, divergence: Union[str, Callable] = 'HD',
cdf=False, search='optim_minimize', n_jobs=None):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=5, nbins=8,
divergence: Union[str, Callable] = 'HD', cdf=False, search='optim_minimize', n_jobs=None):
super().__init__(classifier, fit_classifier, val_split)
self.nbins = nbins
self.divergence = divergence
self.cdf = cdf
@@ -1162,7 +1161,7 @@ class DMy(AggregativeSoftQuantifier):
distributions = np.cumsum(distributions, axis=1)
return distributions

def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function of a distribution matching method. This comes down to generating the
validation distributions out of the training data.
@@ -1172,11 +1171,10 @@ class DMy(AggregativeSoftQuantifier):
distribution of posterior probabilities `P(Y=j|X=x)` for training data labelled with class `i`, and `dij[k]`
is the fraction of instances with a value in the `k`-th bin.

:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the posterior probabilities returned by the classifier
:param labels: array-like with the true labels associated to each posterior
"""
posteriors, true_labels = classif_predictions.Xy
posteriors, true_labels = classif_predictions, labels
n_classes = len(self.classifier.classes_)

self.validation_distribution = qp.util.parallel(
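A sketch of what building the validation distributions amounts to, per the docstring above (an illustrative serial helper, not the parallelized library code):

import numpy as np

def class_conditional_histograms(posteriors, labels, n_classes, nbins=8):
    # for each class i, builds the per-dimension histogram of the posterior
    # probabilities of the training instances labelled with class i
    distributions = []
    for i in range(n_classes):
        Pi = posteriors[labels == i]
        hists = [np.histogram(Pi[:, j], bins=nbins, range=(0, 1))[0] for j in range(n_classes)]
        hists = np.vstack(hists).astype(float)
        hists /= hists.sum(axis=1, keepdims=True)  # normalize each channel
        distributions.append(hists)
    return np.asarray(distributions)  # shape (n_classes, n_classes, nbins)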
@@ -1432,7 +1430,7 @@ class AggregativeMedianEstimator(BinaryQuantifier):
:param base_quantifier: the base, binary quantifier
:param random_state: a seed to be set before fitting any base quantifier (default None)
:param param_grid: the grid of parameters towards which the median will be computed
:param n_jobs: number of parllel workes
:param n_jobs: number of parallel workers
"""

def __init__(self, base_quantifier: AggregativeQuantifier, param_grid: dict, random_state=None, n_jobs=None):
@@ -1449,32 +1447,32 @@ class AggregativeMedianEstimator(BinaryQuantifier):

def _delayed_fit(self, args):
with qp.util.temp_seed(self.random_state):
params, training = args
params, X, y = args
model = deepcopy(self.base_quantifier)
model.set_params(**params)
model.fit(training)
model.fit(X, y)
return model

def _delayed_fit_classifier(self, args):
with qp.util.temp_seed(self.random_state):
cls_params, training, kwargs = args
cls_params, X, y = args
model = deepcopy(self.base_quantifier)
model.set_params(**cls_params)
predictions = model.classifier_fit_predict(training, **kwargs)
return (model, predictions)
predictions, labels = model.classifier_fit_predict(X, y)
return (model, predictions, labels)

def _delayed_fit_aggregation(self, args):
with qp.util.temp_seed(self.random_state):
((model, predictions), q_params), training = args
((model, predictions, y), q_params) = args
model = deepcopy(model)
model.set_params(**q_params)
model.aggregation_fit(predictions, training)
model.aggregation_fit(predictions, y)
return model

def fit(self, training: LabelledCollection, **kwargs):
def fit(self, X, y):
import itertools

self._check_binary(training, self.__class__.__name__)
self._check_binary(y, self.__class__.__name__)

if isinstance(self.base_quantifier, AggregativeQuantifier):
cls_configs, q_configs = qp.model_selection.group_params(self.param_grid)
@@ -1482,7 +1480,7 @@ class AggregativeMedianEstimator(BinaryQuantifier):
if len(cls_configs) > 1:
models_preds = qp.util.parallel(
self._delayed_fit_classifier,
((params, training, kwargs) for params in cls_configs),
((params, X, y) for params in cls_configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
asarray=False,

@@ -1491,12 +1489,12 @@ class AggregativeMedianEstimator(BinaryQuantifier):
else:
model = self.base_quantifier
model.set_params(**cls_configs[0])
predictions = model.classifier_fit_predict(training, **kwargs)
models_preds = [(model, predictions)]
predictions, labels = model.classifier_fit_predict(X, y)
models_preds = [(model, predictions, labels)]

self.models = qp.util.parallel(
self._delayed_fit_aggregation,
((setup, training) for setup in itertools.product(models_preds, q_configs)),
itertools.product(models_preds, q_configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
backend='threading'

@@ -1505,7 +1503,7 @@ class AggregativeMedianEstimator(BinaryQuantifier):
configs = qp.model_selection.expand_grid(self.param_grid)
self.models = qp.util.parallel(
self._delayed_fit,
((params, training) for params in configs),
((params, X, y) for params in configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
backend='threading'

@@ -1514,9 +1512,9 @@ class AggregativeMedianEstimator(BinaryQuantifier):

def _delayed_predict(self, args):
model, instances = args
return model.quantify(instances)
return model.predict(instances)

def quantify(self, instances):
def predict(self, instances):
prev_preds = qp.util.parallel(
self._delayed_predict,
((model, instances) for model in self.models),
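A hedged usage sketch of the refactored interface (PACC and the grid are arbitrary choices; `Xtr`, `ytr`, `Xte` are assumed arrays):

from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import PACC, AggregativeMedianEstimator

median_q = AggregativeMedianEstimator(
    base_quantifier=PACC(LogisticRegression()),
    param_grid={'classifier__C': [0.1, 1.0, 10.0]},
    random_state=0)
median_q.fit(Xtr, ytr)        # (X, y) replaces the LabelledCollection argument
prev = median_q.predict(Xte)  # predict replaces quantify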
@@ -375,18 +375,20 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):
self.region = region
self.random_state = random_state

def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
data = LabelledCollection(classif_predictions, labels, classes=self.classes_)
self.quantifiers = []
if self.n_train_samples==1:
self.quantifier.aggregation_fit(classif_predictions, data)
self.quantifier.aggregation_fit(classif_predictions, labels)
self.quantifiers.append(self.quantifier)
else:
# model-based bootstrap (only on the aggregative part)
full_index = np.arange(len(data))
n_examples = len(data)
full_index = np.arange(n_examples)
with qp.util.temp_seed(self.random_state):
for i in range(self.n_train_samples):
quantifier = copy.deepcopy(self.quantifier)
index = resample(full_index, n_samples=len(data))
index = resample(full_index, n_samples=n_examples)
classif_predictions_i = classif_predictions.sampling_from_index(index)
data_i = data.sampling_from_index(index)
quantifier.aggregation_fit(classif_predictions_i, data_i)
@@ -415,10 +417,10 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):

return prev_estim, conf

def fit(self, data: LabelledCollection, fit_classifier=True, val_split=None):
def fit(self, X, y):
self.quantifier._check_init_parameters()
classif_predictions = self.quantifier.classifier_fit_predict(data, fit_classifier, predict_on=val_split)
self.aggregation_fit(classif_predictions, data)
classif_predictions, labels = self.quantifier.classifier_fit_predict(X, y)
self.aggregation_fit(classif_predictions, labels)
return self

def quantify_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
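A hedged sketch of the bootstrap workflow, assuming the constructor exposes the attributes shown above (`n_train_samples`, `random_state`) as parameters; `Xtr`, `ytr`, `Xte` are assumed arrays:

from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import PACC
from quapy.method.confidence import AggregativeBootstrap

boot = AggregativeBootstrap(PACC(LogisticRegression()),
                            n_train_samples=100,  # bootstrap resamples of the aggregation phase only
                            random_state=0)
boot.fit(Xtr, ytr)
prev, conf_region = boot.quantify_conf(Xte, confidence_level=0.95)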
@@ -446,7 +448,8 @@ class BayesianCC(AggregativeCrispQuantifier, WithConfidenceABC):
This method relies on extra dependencies, which have to be installed via:
`$ pip install quapy[bayes]`

:param classifier: a sklearn's Estimator that generates a classifier
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param val_split: a float in (0, 1) indicating the proportion of the training data to be used,
as a stratified held-out validation set, for generating classifier predictions.
:param num_warmup: number of warmup iterations for the MCMC sampler (default 500)
@@ -493,16 +496,17 @@ class BayesianCC(AggregativeCrispQuantifier, WithConfidenceABC):
# Dictionary with posterior samples, set when `aggregate` is provided.
self._samples = None

def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
"""
Estimates the misclassification rates.

:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the label predictions issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the label predictions returned by the classifier
:param labels: array-like with the true labels associated to each classifier prediction
"""
pred_labels, true_labels = classif_predictions.Xy
self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels, labels=self.classifier.classes_).astype(float)
pred_labels = classif_predictions
true_labels = labels
self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels,
labels=self.classifier.classes_)

def sample_from_posterior(self, classif_predictions):
if self._n_and_c_labeled is None:
@@ -52,19 +52,19 @@ class MedianEstimator2(BinaryQuantifier):

def _delayed_fit(self, args):
with qp.util.temp_seed(self.random_state):
params, training = args
params, X, y = args
model = deepcopy(self.base_quantifier)
model.set_params(**params)
model.fit(training)
model.fit(X, y)
return model

def fit(self, training: LabelledCollection):
self._check_binary(training, self.__class__.__name__)
def fit(self, X, y):
self._check_binary(y, self.__class__.__name__)

configs = qp.model_selection.expand_grid(self.param_grid)
self.models = qp.util.parallel(
self._delayed_fit,
((params, training) for params in configs),
((params, X, y) for params in configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs
)
@@ -95,7 +95,7 @@ class MedianEstimator(BinaryQuantifier):
:param base_quantifier: the base, binary quantifier
:param random_state: a seed to be set before fitting any base quantifier (default None)
:param param_grid: the grid of parameters towards which the median will be computed
:param n_jobs: number of parllel workes
:param n_jobs: number of parallel workers
"""
def __init__(self, base_quantifier: BinaryQuantifier, param_grid: dict, random_state=None, n_jobs=None):
self.base_quantifier = base_quantifier
@@ -111,61 +111,19 @@ class MedianEstimator(BinaryQuantifier):

def _delayed_fit(self, args):
with qp.util.temp_seed(self.random_state):
params, training = args
params, X, y = args
model = deepcopy(self.base_quantifier)
model.set_params(**params)
model.fit(training)
model.fit(X, y)
return model

def _delayed_fit_classifier(self, args):
with qp.util.temp_seed(self.random_state):
cls_params, training = args
model = deepcopy(self.base_quantifier)
model.set_params(**cls_params)
predictions = model.classifier_fit_predict(training, predict_on=model.val_split)
return (model, predictions)
def fit(self, X, y):
self._check_binary(y, self.__class__.__name__)

def _delayed_fit_aggregation(self, args):
with qp.util.temp_seed(self.random_state):
((model, predictions), q_params), training = args
model = deepcopy(model)
model.set_params(**q_params)
model.aggregation_fit(predictions, training)
return model

def fit(self, training: LabelledCollection):
self._check_binary(training, self.__class__.__name__)

if isinstance(self.base_quantifier, AggregativeQuantifier):
cls_configs, q_configs = qp.model_selection.group_params(self.param_grid)

if len(cls_configs) > 1:
models_preds = qp.util.parallel(
self._delayed_fit_classifier,
((params, training) for params in cls_configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
asarray=False
)
else:
model = self.base_quantifier
model.set_params(**cls_configs[0])
predictions = model.classifier_fit_predict(training, predict_on=model.val_split)
models_preds = [(model, predictions)]

self.models = qp.util.parallel(
self._delayed_fit_aggregation,
((setup, training) for setup in itertools.product(models_preds, q_configs)),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
asarray=False
)
else:
configs = qp.model_selection.expand_grid(self.param_grid)
self.models = qp.util.parallel(
self._delayed_fit,
((params, training) for params in configs),
((params, X, y) for params in configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
asarray=False
@@ -257,12 +215,13 @@ class Ensemble(BaseQuantifier):
if self.verbose:
print('[Ensemble]' + msg)

def fit(self, data: qp.data.LabelledCollection, val_split: Union[qp.data.LabelledCollection, float] = None):
def fit(self, X, y):

data = LabelledCollection(X, y)

if self.policy == 'ds' and not data.binary:
raise ValueError(f'ds policy is only defined for binary quantification, but this dataset is not binary')

if val_split is None:
val_split = self.val_split

# randomly chooses the prevalences for each member of the ensemble (preventing classes with less than
@@ -704,10 +663,10 @@ class SCMQ(AggregativeSoftQuantifier):
self.merge_fun = merge_fun
self.val_split = val_split

def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
for quantifier in self.quantifiers:
quantifier.classifier = self.classifier
quantifier.aggregation_fit(classif_predictions, data)
quantifier.aggregation_fit(classif_predictions, labels)
return self

def aggregate(self, classif_predictions: np.ndarray):
@@ -100,7 +100,7 @@ class DMx(BaseQuantifier):

return distributions

def fit(self, data: LabelledCollection):
def fit(self, X, y):
"""
Generates the validation distributions out of the training data (covariates).
The validation distributions have shape `(n, nfeats, nbins)`, with `n` the number of classes, `nfeats`
@@ -109,10 +109,9 @@ class DMx(BaseQuantifier):
training data labelled with class `i`; while `dij = di[j]` is the discrete distribution for feature j in
training data labelled with class `i`, and `dij[k]` is the fraction of instances with a value in the `k`-th bin.

:param data: the training set
:param X: array-like of shape `(n_samples, n_features)`, the training instances
:param y: array-like of shape `(n_samples,)`, the labels
"""
X, y = data.Xy

self.nfeats = X.shape[1]
self.feat_ranges = _get_features_range(X)
@@ -147,53 +146,53 @@ class DMx(BaseQuantifier):
return F.argmin_prevalence(loss, n_classes, method=self.search)
class ReadMe(BaseQuantifier):

def __init__(self, bootstrap_trials=100, bootstrap_range=100, bagging_trials=100, bagging_range=25, **vectorizer_kwargs):
raise NotImplementedError('under development ...')
self.bootstrap_trials = bootstrap_trials
self.bootstrap_range = bootstrap_range
self.bagging_trials = bagging_trials
self.bagging_range = bagging_range
self.vectorizer_kwargs = vectorizer_kwargs

def fit(self, data: LabelledCollection):
X, y = data.Xy
self.vectorizer = CountVectorizer(binary=True, **self.vectorizer_kwargs)
X = self.vectorizer.fit_transform(X)
self.class_conditional_X = {i: X[y==i] for i in range(data.classes_)}

def predict(self, X):
X = self.vectorizer.transform(X)

# number of features
num_docs, num_feats = X.shape

# bootstrap
p_boots = []
for _ in range(self.bootstrap_trials):
docs_idx = np.random.choice(num_docs, size=self.bootstrap_range, replace=False)
class_conditional_X = {i: X[docs_idx] for i, X in self.class_conditional_X.items()}
Xboot = X[docs_idx]

# bagging
p_bags = []
for _ in range(self.bagging_trials):
feat_idx = np.random.choice(num_feats, size=self.bagging_range, replace=False)
class_conditional_Xbag = {i: X[:, feat_idx] for i, X in class_conditional_X.items()}
Xbag = Xboot[:,feat_idx]
p = self.std_constrained_linear_ls(Xbag, class_conditional_Xbag)
p_bags.append(p)
p_boots.append(np.mean(p_bags, axis=0))

p_mean = np.mean(p_boots, axis=0)
p_std = np.std(p_bags, axis=0)

return p_mean


def std_constrained_linear_ls(self, X, class_cond_X: dict):
pass
# class ReadMe(BaseQuantifier):
#
# def __init__(self, bootstrap_trials=100, bootstrap_range=100, bagging_trials=100, bagging_range=25, **vectorizer_kwargs):
# raise NotImplementedError('under development ...')
# self.bootstrap_trials = bootstrap_trials
# self.bootstrap_range = bootstrap_range
# self.bagging_trials = bagging_trials
# self.bagging_range = bagging_range
# self.vectorizer_kwargs = vectorizer_kwargs
#
# def fit(self, data: LabelledCollection):
# X, y = data.Xy
# self.vectorizer = CountVectorizer(binary=True, **self.vectorizer_kwargs)
# X = self.vectorizer.fit_transform(X)
# self.class_conditional_X = {i: X[y==i] for i in range(data.classes_)}
#
# def predict(self, X):
# X = self.vectorizer.transform(X)
#
# # number of features
# num_docs, num_feats = X.shape
#
# # bootstrap
# p_boots = []
# for _ in range(self.bootstrap_trials):
# docs_idx = np.random.choice(num_docs, size=self.bootstrap_range, replace=False)
# class_conditional_X = {i: X[docs_idx] for i, X in self.class_conditional_X.items()}
# Xboot = X[docs_idx]
#
# # bagging
# p_bags = []
# for _ in range(self.bagging_trials):
# feat_idx = np.random.choice(num_feats, size=self.bagging_range, replace=False)
# class_conditional_Xbag = {i: X[:, feat_idx] for i, X in class_conditional_X.items()}
# Xbag = Xboot[:,feat_idx]
# p = self.std_constrained_linear_ls(Xbag, class_conditional_Xbag)
# p_bags.append(p)
# p_boots.append(np.mean(p_bags, axis=0))
#
# p_mean = np.mean(p_boots, axis=0)
# p_std = np.std(p_bags, axis=0)
#
# return p_mean
#
#
# def std_constrained_linear_ls(self, X, class_cond_X: dict):
# pass
def _get_features_range(X):

@@ -1,3 +1,4 @@
from sklearn.linear_model import LogisticRegression
import quapy as qp
from method.aggregative import *
@@ -5,11 +6,20 @@ datasets = qp.datasets.UCI_MULTICLASS_DATASETS[1]
data = qp.datasets.fetch_UCIMulticlassDataset(datasets)
train, test = data.train_test

quant = EMQ()
quant.fit(*train.Xy)
prev = quant.predict(test.X)
Xtr, ytr = train.Xy
Xte = test.X

quant = EMQ(LogisticRegression(), calib='bcts')
quant.fit(Xtr, ytr)
prev = quant.predict(Xte)

print(prev)
post = quant.predict_proba(Xte)
print(post)
post = quant.classify(Xte)
print(post)
# AggregativeMedianEstimator()


# test CC, prevent from doing 5FCV for nothing