fixing threshold optimization-based techniques

This commit is contained in:
Alejandro Moreo Fernandez 2024-01-17 09:33:39 +01:00
parent 6d53b68d7f
commit 896fa042d6
4 changed files with 241 additions and 103 deletions

.gitignore

@@ -158,3 +158,4 @@ TweetSentQuant
*.png

@@ -0,0 +1,136 @@
from copy import deepcopy
import quapy as qp
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from quapy.classification.methods import LowRankLogisticRegression
from quapy.method.meta import QuaNet
from quapy.protocol import APP
from quapy.method.aggregative import CC, ACC, PCC, PACC, MAX, MS, MS2, EMQ, HDy, newSVMAE, T50, X
from quapy.method.meta import EHDy
import numpy as np
import os
import pickle
import itertools
import argparse
import torch
import shutil
N_JOBS = -1
CUDA_N_JOBS = 2
ENSEMBLE_N_JOBS = -1
qp.environ['SAMPLE_SIZE'] = 100
def newLR():
return LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1)
def calibratedLR():
return CalibratedClassifierCV(LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1))
__C_range = np.logspace(-3, 3, 7)
lr_params = {'classifier__C': __C_range, 'classifier__class_weight': [None, 'balanced']}
svmperf_params = {'classifier__C': __C_range}
def quantification_models():
yield 'acc', ACC(newLR()), lr_params
yield 'T50', T50(newLR()), lr_params
yield 'X', X(newLR()), lr_params
yield 'MAX', MAX(newLR()), lr_params
yield 'MS', MS(newLR()), lr_params
yield 'MS2', MS2(newLR()), lr_params
def evaluate_experiment(true_prevalences, estim_prevalences):
print('\nEvaluation Metrics:\n' + '=' * 22)
for eval_measure in [qp.error.mae, qp.error.mrae]:
err = eval_measure(true_prevalences, estim_prevalences)
print(f'\t{eval_measure.__name__}={err:.4f}')
print()
def result_path(path, dataset_name, model_name, run, optim_loss):
return os.path.join(path, f'{dataset_name}-{model_name}-run{run}-{optim_loss}.pkl')
def is_already_computed(dataset_name, model_name, run, optim_loss):
return os.path.exists(result_path(args.results, dataset_name, model_name, run, optim_loss))
def save_results(dataset_name, model_name, run, optim_loss, *results):
rpath = result_path(args.results, dataset_name, model_name, run, optim_loss)
qp.util.create_parent_dir(rpath)
with open(rpath, 'wb') as foo:
pickle.dump(tuple(results), foo, pickle.HIGHEST_PROTOCOL)
def run(experiment):
optim_loss, dataset_name, (model_name, model, hyperparams) = experiment
if dataset_name in ['acute.a', 'acute.b', 'iris.1']: return
collection = qp.datasets.fetch_UCILabelledCollection(dataset_name)
for run, data in enumerate(qp.data.Dataset.kFCV(collection, nfolds=5, nrepeats=1)):
if is_already_computed(dataset_name, model_name, run=run, optim_loss=optim_loss):
print(f'result for dataset={dataset_name} model={model_name} loss={optim_loss} run={run+1}/5 already computed.')
continue
print(f'running dataset={dataset_name} model={model_name} loss={optim_loss} run={run+1}/5')
# model selection (hyperparameter optimization for a quantification-oriented loss)
train, test = data.train_test
train, val = train.split_stratified()
if hyperparams is not None:
model_selection = qp.model_selection.GridSearchQ(
deepcopy(model),
param_grid=hyperparams,
protocol=APP(val, n_prevalences=21, repeats=25),
error=optim_loss,
refit=True,
timeout=60*60,
verbose=True
)
model_selection.fit(data.training)
model = model_selection.best_model()
best_params = model_selection.best_params_
else:
model.fit(data.training)
best_params = {}
# model evaluation
true_prevalences, estim_prevalences = qp.evaluation.prediction(
model,
protocol=APP(test, n_prevalences=21, repeats=100)
)
test_true_prevalence = data.test.prevalence()
evaluate_experiment(true_prevalences, estim_prevalences)
save_results(dataset_name, model_name, run, optim_loss,
true_prevalences, estim_prevalences,
data.training.prevalence(), test_true_prevalence,
best_params)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Run experiments for UCI datasets')
parser.add_argument('--results', metavar='RESULT_PATH', type=str, default='results_tmp',
help='path to the directory where to store the results')
parser.add_argument('--svmperfpath', metavar='SVMPERF_PATH', type=str, default='../svm_perf_quantification',
help='path to the directory with svmperf')
parser.add_argument('--checkpointdir', metavar='PATH', type=str, default='./checkpoint',
help='path to a temporary checkpoint directory (removed at the end of the run)')  # added because args.checkpointdir is used below; the default path is an assumption
args = parser.parse_args()
print(f'Result folder: {args.results}')
np.random.seed(0)
qp.environ['SVMPERF_HOME'] = args.svmperfpath
optim_losses = ['mae']
datasets = qp.datasets.UCI_DATASETS
models = quantification_models()
qp.util.parallel(run, itertools.product(optim_losses, datasets, models), n_jobs=N_JOBS)
shutil.rmtree(args.checkpointdir, ignore_errors=True)
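For reference, the snippet below is a minimal, purely illustrative way of fitting and applying one of the threshold-optimization quantifiers benchmarked by this script; it is not part of the commit, and the dataset name and split proportion are arbitrary choices.

# illustrative usage sketch (not part of the commit)
import quapy as qp
from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import T50

collection = qp.datasets.fetch_UCILabelledCollection('haberman')  # any name from qp.datasets.UCI_DATASETS
train, test = collection.split_stratified(train_prop=0.7)
quantifier = T50(LogisticRegression(max_iter=1000))
quantifier.fit(train)
print('estimated prevalence:', quantifier.quantify(test.instances))
print('true prevalence     :', test.prevalence())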

@@ -66,6 +66,23 @@ def prevalence_from_probabilities(posteriors, binarize: bool = False):
return prevalences
def as_binary_prevalence(positive_prevalence: float, clip_if_necessary=False):
"""
Helper that, given a float representing the prevalence for the positive class, returns a np.ndarray of two
values representing a binary distribution.
:param positive_prevalence: prevalence for the positive class
:param clip_if_necessary: if True, clips the value in [0,1] in order to guarantee the resulting distribution
is valid. If False, it then checks that the value is in the valid range, and raises an error if not.
:return: np.ndarray of shape `(2,)`
"""
if clip_if_necessary:
positive_prevalence = np.clip(positive_prevalence, 0, 1)
else:
assert 0 <= positive_prevalence <= 1, 'the value provided is not a valid prevalence for the positive class'
return np.asarray([1-positive_prevalence, positive_prevalence])
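A quick sanity check of the new helper (illustrative only; the input values are arbitrary and not part of the commit):

import quapy.functional as F
print(F.as_binary_prevalence(0.3))                           # [0.7 0.3]
print(F.as_binary_prevalence(1.2, clip_if_necessary=True))   # out-of-range input gets clipped -> [0. 1.]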
def HellingerDistance(P, Q) -> float:
"""
Computes the Hellinger Distance (HD) between (discretized) distributions `P` and `Q`.

@@ -159,28 +159,25 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
"""
self.classifier_ = classifier
@abstractmethod
def classify(self, instances):
"""
Provides the label predictions for the given instances. The predictions should respect the format expected by
:meth:`aggregate`, e.g., posterior probabilities for probabilistic quantifiers, or crisp predictions for
non-probabilistic quantifiers
non-probabilistic quantifiers. The default one is "decision_function".
:param instances: array-like of shape `(n_instances, n_features,)`
:return: np.ndarray of shape `(n_instances,)` with label predictions
"""
...
return getattr(self, self._classifier_method())(instances)
@abstractmethod
def _classifier_method(self):
"""
Name of the method that must be used for issuing label predictions.
Name of the method that must be used for issuing label predictions. The default one is "decision_function".
:return: string
"""
...
return 'decision_function'
@abstractmethod
def _check_classifier(self, adapt_if_necessary=False):
"""
Guarantees that the underlying classifier implements the method required for issuing predictions, i.e.,
@@ -188,7 +185,8 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
:param adapt_if_necessary: if True, the method will try to comply with the required specifications
"""
...
assert hasattr(self.classifier, self._classifier_method()), \
f"the method does not implement the required {self._classifier_method()} method"
def quantify(self, instances):
"""
@@ -229,32 +227,15 @@ class AggregativeCrispQuantifier(AggregativeQuantifier, ABC):
Quantifiers by implementing specifications about crisp predictions.
"""
def classify(self, instances):
"""
Provides the label (crisp) predictions for the given instances.
:param instances: array-like of shape `(n_instances, n_dimensions,)`
:return: np.ndarray of shape `(n_instances,)` with label predictions
"""
return self.classifier.predict(instances)
def _classifier_method(self):
"""
Name of the method that must be used for issuing label predictions.
Name of the method that must be used for issuing label predictions. For crisp quantifiers, the method
is 'predict', that returns an array of shape `(n_instances,)` of label predictions.
:return: the string "predict", i.e., the standard method name for scikit-learn hard predictions
"""
return 'predict'
def _check_classifier(self, adapt_if_necessary=False):
"""
Guarantees that the underlying classifier implements the method indicated by the :meth:`_classifier_method`
:param adapt_if_necessary: unused, added for compatibility
"""
assert hasattr(self.classifier, self._classifier_method()), \
f"the method does not implement the required {self._classifier_method()} method"
class AggregativeSoftQuantifier(AggregativeQuantifier, ABC):
"""
@@ -264,18 +245,11 @@ class AggregativeSoftQuantifier(AggregativeQuantifier, ABC):
about soft predictions.
"""
def classify(self, instances):
"""
Provides the posterior probabilities for the given instances.
:param instances: array-like of shape `(n_instances, n_dimensions,)`
:return: np.ndarray of shape `(n_instances, n_classes,)` with posterior probabilities
"""
return self.classifier.predict_proba(instances)
def _classifier_method(self):
"""
Name of the method that must be used for issuing label predictions.
Name of the method that must be used for issuing label predictions. For probabilistic quantifiers, the method
is 'predict_proba', which returns an array of shape `(n_instances, n_classes,)` with posterior
probabilities.
:return: the string "predict_proba", i.e., the standard method name for scikit-learn soft predictions
"""
@@ -731,7 +705,7 @@ class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
prev_estimations.append(prev_selected)
class1_prev = np.median(prev_estimations)
return np.asarray([1 - class1_prev, class1_prev])
return F.as_binary_prevalence(class1_prev)
class DyS(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
@@ -793,7 +767,7 @@ class DyS(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
return divergence(Px_train, Px_test)
class1_prev = self._ternary_search(f=distribution_distance, left=0, right=1, tol=self.tol)
return np.asarray([1 - class1_prev, class1_prev])
return F.as_binary_prevalence(class1_prev)
class SMM(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
@@ -825,9 +799,7 @@ class SMM(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
Px_mean = np.mean(Px)
class1_prev = (Px_mean - self.Pxy0_mean)/(self.Pxy1_mean - self.Pxy0_mean)
class1_prev = np.clip(class1_prev, 0, 1)
return np.asarray([1 - class1_prev, class1_prev])
return F.as_binary_prevalence(class1_prev, clip_if_necessary=True)
class DMy(AggregativeSoftQuantifier):
@@ -1086,7 +1058,7 @@ def newSVMRAE(svmperf_base=None, C=1):
return newELM(svmperf_base, loss='mrae', C=C)
class ThresholdOptimization(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
class ThresholdOptimization(BinaryAggregativeQuantifier):
"""
Abstract class of Threshold Optimization variants for :class:`ACC` as proposed by
`Forman 2006 <https://dl.acm.org/doi/abs/10.1145/1150402.1150423>`_ and
@@ -1110,13 +1082,8 @@ class ThresholdOptimization(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
self.val_split = val_split
self.n_jobs = qp._get_njobs(n_jobs)
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
P, y = classif_predictions.Xy
self.tpr, self.fpr, self.threshold = self._optimize_threshold(y, P)
return self
@abstractmethod
def _condition(self, tpr, fpr) -> float:
def condition(self, tpr, fpr) -> float:
"""
Implements the criterion according to which the threshold should be selected.
This function should return the (float) score to be minimized.
@@ -1127,46 +1094,63 @@ class ThresholdOptimization(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
"""
...
def _optimize_threshold(self, y, probabilities):
def discard(self, tpr, fpr) -> bool:
"""
Indicates whether a combination of tpr and fpr should be discarded
:param tpr: float, true positive rate
:param fpr: float, false positive rate
:return: true if the combination is to be discarded, false otherwise
"""
return (tpr + fpr) == 0
def _eval_candidate_thresholds(self, decision_scores, y):
"""
Seeks for the best `tpr` and `fpr` according to the score obtained at different
decision thresholds. The scoring function is implemented in function `condition`.
:param decision_scores: array-like with the classification scores
:param y: true labels for the validation set (or for the training set via `k`-fold cross validation)
:param probabilities: array-like with the posterior probabilities
:return: best `tpr` and `fpr` and `threshold` according to `_condition`
"""
best_candidate_threshold_score = None
best_tpr = 0
best_fpr = 0
candidate_thresholds = np.unique(probabilities[:, self.pos_label])
candidate_thresholds = np.unique(decision_scores)
candidates = []
scores = []
for candidate_threshold in candidate_thresholds:
y_ = self.classes_[1*(probabilities[:,1]>candidate_threshold)]
#y_ = [self.pos_label if p > candidate_threshold else self.neg_label for p in probabilities[:, 1]]
y_ = self.classes_[1 * (decision_scores > candidate_threshold)]
TP, FP, FN, TN = self._compute_table(y, y_)
tpr = self._compute_tpr(TP, FP)
fpr = self._compute_fpr(FP, TN)
condition_score = self._condition(tpr, fpr)
if best_candidate_threshold_score is None or condition_score < best_candidate_threshold_score:
best_candidate_threshold_score = condition_score
best_tpr = tpr
best_fpr = fpr
if not self.discard(tpr, fpr):
candidate_score = self.condition(tpr, fpr)
candidates.append([tpr, fpr, candidate_threshold])
scores.append(candidate_score)
return best_tpr, best_fpr, best_candidate_threshold_score
if len(candidates) == 0:
# if no candidate gives rise to a valid combination of tpr and fpr, this method defaults to the standard
# classify & count; this is akin to assigning tpr=1, fpr=0, threshold=0
tpr, fpr, threshold = 1, 0, 0
candidates.append([tpr, fpr, threshold])
scores.append(0)  # candidate rows carry (tpr, fpr, threshold); the score goes into `scores` so argsort stays aligned
def aggregate(self, classif_predictions):
class_scores = classif_predictions[:, self.pos_label]
prev_estim = np.mean(class_scores > self.threshold)
if self.tpr - self.fpr != 0:
prevs_estim = np.clip((prev_estim - self.fpr) / (self.tpr - self.fpr), 0, 1)
candidates = np.asarray(candidates)
candidates = candidates[np.argsort(scores)] # sort candidates by candidate_score
return candidates
def aggregate_with_threshold(self, classif_predictions, tpr, fpr, threshold):
prevs_estim = np.mean(classif_predictions > threshold)
if tpr - fpr != 0:
prevs_estim = np.clip((prevs_estim - fpr) / (tpr - fpr), 0, 1)
prevs_estim = np.array((1 - prevs_estim, prevs_estim))
return prevs_estim
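As a worked check of the correction applied by `aggregate_with_threshold` (illustrative numbers, not part of the commit): with tpr=0.8, fpr=0.1 and an observed positive rate of 0.6, the adjusted prevalence is (0.6 - 0.1) / (0.8 - 0.1) ≈ 0.714.

# illustrative numeric check (not part of the commit)
import numpy as np
scores = np.array([0.9, 0.7, 0.6, 0.4, 0.2])           # hypothetical positive-class scores
tpr, fpr, threshold = 0.8, 0.1, 0.5                    # hypothetical rates and threshold
raw = np.mean(scores > threshold)                      # observed positive rate: 3/5 = 0.6
adjusted = np.clip((raw - fpr) / (tpr - fpr), 0, 1)    # (0.6 - 0.1) / 0.7 ≈ 0.714
print(np.array([1 - adjusted, adjusted]))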
def _compute_table(self, y, y_):
TP = np.logical_and(y == y_, y == self.classes_[1]).sum()
FP = np.logical_and(y != y_, y == self.classes_[0]).sum()
FN = np.logical_and(y != y_, y == self.classes_[1]).sum()
TN = np.logical_and(y == y_, y == self.classes_[0]).sum()
TP = np.logical_and(y == y_, y == self.pos_label).sum()
FP = np.logical_and(y != y_, y == self.neg_label).sum()
FN = np.logical_and(y != y_, y == self.pos_label).sum()
TN = np.logical_and(y == y_, y == self.neg_label).sum()
return TP, FP, FN, TN
def _compute_tpr(self, TP, FP):
@@ -1179,13 +1163,23 @@ class ThresholdOptimization(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
return 0
return FP / (FP + TN)
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
# the standard behavior is to keep the best threshold only
decision_scores, y = classif_predictions.Xy
self.tpr, self.fpr, self.threshold = self._eval_candidate_thresholds(decision_scores, y)[0]
return self
def aggregate(self, classif_predictions: np.ndarray):
# the standard behavior is to compute the adjusted count using the best threshold found
return self.aggregate_with_threshold(classif_predictions, self.tpr, self.fpr, self.threshold)
class T50(ThresholdOptimization):
"""
Threshold Optimization variant for :class:`ACC` as proposed by
`Forman 2006 <https://dl.acm.org/doi/abs/10.1145/1150402.1150423>`_ and
`Forman 2008 <https://link.springer.com/article/10.1007/s10618-008-0097-y>`_ that looks
for the threshold that makes `tpr` cosest to 0.5.
for the threshold that makes `tpr` closest to 0.5.
The goal is to bring improved stability to the denominator of the adjustment.
:param classifier: a sklearn's Estimator that generates a classifier
@@ -1200,7 +1194,7 @@ class T50(ThresholdOptimization):
def __init__(self, classifier: BaseEstimator, val_split=5):
super().__init__(classifier, val_split)
def _condition(self, tpr, fpr) -> float:
def condition(self, tpr, fpr) -> float:
return abs(tpr - 0.5)
@@ -1224,7 +1218,7 @@ class MAX(ThresholdOptimization):
def __init__(self, classifier: BaseEstimator, val_split=5):
super().__init__(classifier, val_split)
def _condition(self, tpr, fpr) -> float:
def condition(self, tpr, fpr) -> float:
# MAX strives to maximize (tpr - fpr), which is equivalent to minimizing (fpr - tpr)
return (fpr - tpr)
@@ -1249,7 +1243,7 @@ class X(ThresholdOptimization):
def __init__(self, classifier: BaseEstimator, val_split=5):
super().__init__(classifier, val_split)
def _condition(self, tpr, fpr) -> float:
def condition(self, tpr, fpr) -> float:
return abs(1 - (tpr + fpr))
@@ -1272,21 +1266,22 @@ class MS(ThresholdOptimization):
def __init__(self, classifier: BaseEstimator, val_split=5):
super().__init__(classifier, val_split)
def _condition(self, tpr, fpr) -> float:
pass
def condition(self, tpr, fpr) -> float:
return 1
def _optimize_threshold(self, y, probabilities):
tprs = []
fprs = []
candidate_thresholds = np.unique(probabilities[:, 1])
for candidate_threshold in candidate_thresholds:
y_ = [self.classes_[1] if p > candidate_threshold else self.classes_[0] for p in probabilities[:, 1]]
TP, FP, FN, TN = self._compute_table(y, y_)
tpr = self._compute_tpr(TP, FP)
fpr = self._compute_fpr(FP, TN)
tprs.append(tpr)
fprs.append(fpr)
return np.median(tprs), np.median(fprs)
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
# keeps all candidates
decision_scores, y = classif_predictions.Xy
self.tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y)
return self
def aggregate(self, classif_predictions: np.ndarray):
prevalences = []
for tpr, fpr, threshold in self.tprs_fprs_thresholds:
pos_prev = self.aggregate_with_threshold(classif_predictions, tpr, fpr, threshold)[1]
prevalences.append(pos_prev)
median = np.median(prevalences)
return F.as_binary_prevalence(median)
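In plain numpy, the median-sweep aggregation above boils down to the following sketch; the (tpr, fpr, threshold) triples and the score values are made up for illustration and are not part of the commit.

# illustrative sketch of the median-sweep idea (not part of the commit)
import numpy as np
scores = np.array([0.9, 0.7, 0.6, 0.4, 0.2])                         # hypothetical positive-class scores
candidates = [(0.9, 0.2, 0.3), (0.8, 0.1, 0.5), (0.7, 0.05, 0.65)]   # made-up (tpr, fpr, threshold) triples
prevalences = []
for tpr, fpr, thr in candidates:
    raw = np.mean(scores > thr)                                       # observed positive rate at this threshold
    prevalences.append(np.clip((raw - fpr) / (tpr - fpr), 0, 1))      # adjusted count for this candidate
print(np.median(prevalences))                                         # MS estimate of the positive-class prevalence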
class MS2(MS):
@@ -1309,19 +1304,8 @@ class MS2(MS):
def __init__(self, classifier: BaseEstimator, val_split=5):
super().__init__(classifier, val_split)
def _optimize_threshold(self, y, probabilities):
tprs = [0, 1]
fprs = [0, 1]
candidate_thresholds = np.unique(probabilities[:, 1])
for candidate_threshold in candidate_thresholds:
y_ = [self.classes_[1] if p > candidate_threshold else self.classes_[0] for p in probabilities[:, 1]]
TP, FP, FN, TN = self._compute_table(y, y_)
tpr = self._compute_tpr(TP, FP)
fpr = self._compute_fpr(FP, TN)
if (tpr - fpr) > 0.25:
tprs.append(tpr)
fprs.append(fpr)
return np.median(tprs), np.median(fprs)
def discard(self, tpr, fpr) -> bool:
return (tpr-fpr) <= 0.25
class OneVsAllAggregative(OneVsAllGeneric, AggregativeQuantifier):