improving custom quantifier

Alejandro Moreo Fernandez 2024-03-06 11:45:08 +01:00
parent ea7a574185
commit d81bf305a3
5 changed files with 141 additions and 43 deletions

View File

@@ -12,7 +12,7 @@ if PROBLEM == 'binary':
gen_datasets = gen_bin_datasets
elif PROBLEM == 'multiclass':
qp.environ['SAMPLE_SIZE'] = 250
NUM_TEST = 100
NUM_TEST = 1000
gen_datasets = gen_multi_datasets
@@ -34,13 +34,14 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifier
# instances of ClassifierAccuracyPrediction are bound to the evaluation measure, so they
# must be nested in the acc-for
for acc_name, acc_fn in gen_acc_measure():
print(f'\tfor measure {acc_name}')
for (method_name, method) in gen_CAP(h, acc_fn, with_oracle=ORACLE):
result_path = getpath(basedir, cls_name, acc_name, dataset_name, method_name)
if os.path.exists(result_path):
print(f'\t{method_name}-{acc_name} exists, skipping')
print(f'\t\t{method_name}-{acc_name} exists, skipping')
continue
print(f'\t{method_name} computing...')
print(f'\t\t{method_name} computing...')
method, t_train = fit_method(method, V)
estim_accs, t_test_ave = predictionsCAP(method, test_prot, ORACLE)
save_json_result(result_path, true_accs[acc_name], estim_accs, t_train, t_test_ave)
@@ -49,10 +50,10 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifier
# be nested to the predictions to speed up things
for (method_name, method) in gen_CAP_cont_table(h):
if not any_missing(basedir, cls_name, dataset_name, method_name):
print(f'\tmethod {method_name} has all results already computed. Skipping.')
print(f'\t\tmethod {method_name} has all results already computed. Skipping.')
continue
print(f'\tmethod {method_name} computing...')
print(f'\t\tmethod {method_name} computing...')
method, t_train = fit_method(method, V)
estim_accs_dict, t_test_ave = predictionsCAPcont_table(method, test_prot, gen_acc_measure, ORACLE)
@@ -65,7 +66,6 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifier
# generate diagonal plots
print('generating plots')
for (cls_name, _), (acc_name, _) in itertools.product(gen_classifiers(), gen_acc_measure()):
methods = get_method_names()
plot_diagonal(basedir, cls_name, acc_name)
for dataset_name, _ in gen_datasets(only_names=True):
plot_diagonal(basedir, cls_name, acc_name, dataset_name=dataset_name)

View File

@@ -329,7 +329,7 @@ class SebastianiCAP(ClassifierAccuracyPrediction):
class PabloCAP(ClassifierAccuracyPrediction):
def __init__(self, h, acc_fn, q_class, n_val_samples=50, aggr='mean'):
def __init__(self, h, acc_fn, q_class, n_val_samples=100, aggr='mean'):
self.h = h
self.acc = acc_fn
self.q = q_class(h)
@@ -434,7 +434,7 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
add_maxinfsoft=False):
self.h = h
self.acc = acc
self.q = EmptySaveQuantifier(q_class)
self.q = EmptySafeQuantifier(q_class)
self.add_X = add_X
self.add_posteriors = add_posteriors
self.add_maxconf = add_maxconf
@@ -490,7 +490,7 @@ class QuAccNxN(CAPContingencyTableQ, QuAcc):
y_i = true_labels[pred_labels==class_i]
data_i = LabelledCollection(X_dot_i, y_i, classes=val.classes_)
q_i = EmptySaveQuantifier(deepcopy(self.q_class))
q_i = EmptySafeQuantifier(deepcopy(self.q_class))
q_i.fit(data_i)
self.q.append(q_i)
@@ -518,7 +518,7 @@ def safehstack(X, P):
return XP
class EmptySaveQuantifier(BaseQuantifier):
class EmptySafeQuantifier(BaseQuantifier):
def __init__(self, surrogate_quantifier: BaseQuantifier):
self.surrogate = surrogate_quantifier
@@ -616,11 +616,12 @@ class ATC(ClassifierAccuracyPrediction):
class DoC(ClassifierAccuracyPrediction):
def __init__(self, h, acc, sample_size, num_samples=500):
def __init__(self, h, acc, sample_size, num_samples=500, clip_vals=(0,1)):
self.h = h
self.acc = acc
self.sample_size = sample_size
self.num_samples = num_samples
self.clip_vals = clip_vals
def _get_post_stats(self, X, y):
P = get_posteriors_from_h(self.h, X)
@@ -660,6 +661,8 @@ class DoC(ClassifierAccuracyPrediction):
P = get_posteriors_from_h(self.h, X)
mc = max_conf(P)
acc_pred = self.predict_regression(mc)[0]
if self.clip_vals is not None:
acc_pred = np.clip(acc_pred, *self.clip_vals)
return acc_pred
"""

View File

@@ -6,12 +6,14 @@ from glob import glob
from pathlib import Path
from time import time
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score, f1_score
from sklearn.datasets import fetch_rcv1
from sklearn.datasets import fetch_rcv1, fetch_20newsgroups
from sklearn.model_selection import GridSearchCV
from ClassifierAccuracy.models_multiclass import *
from ClassifierAccuracy.util.tabular import Table
from quapy.method.aggregative import EMQ, ACC, KDEyML
from quapy.data import LabelledCollection
@@ -41,6 +43,16 @@ def gen_multi_datasets(only_names=False)-> [str,[LabelledCollection,LabelledColl
else:
dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
yield dataset_name, split(dataset)
train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))
tfidf = TfidfVectorizer(min_df=5, sublinear_tf=True)
Xtr = tfidf.fit_transform(train.data)
Xte = tfidf.transform(test.data)
train = LabelledCollection(instances=Xtr, labels=train.target)
U = LabelledCollection(instances=Xte, labels=test.target)
T, V = train.split_stratified(train_prop=0.5, random_state=0)
yield "20news", (T, V, U)
def gen_bin_datasets(only_names=False) -> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
@@ -71,7 +83,7 @@ def gen_CAP(h, acc_fn, with_oracle=False)->[str, ClassifierAccuracyPrediction]:
#yield 'SebCAP-KDE', SebastianiCAP(h, acc_fn, KDEyML)
#yield 'SebCAPweight', SebastianiCAP(h, acc_fn, ACC, alpha=0)
#yield 'PabCAP', PabloCAP(h, acc_fn, ACC)
#yield 'PabCAP-SLD-median', PabloCAP(h, acc_fn, EMQ, aggr='median')
yield 'PabCAP-SLD-median', PabloCAP(h, acc_fn, EMQ, aggr='median')
yield 'ATC-MC', ATC(h, acc_fn, scoring_fn='maxconf')
#yield 'ATC-NE', ATC(h, acc_fn, scoring_fn='neg_entropy')
yield 'DoC', DoC(h, acc_fn, sample_size=qp.environ['SAMPLE_SIZE'])
@@ -288,7 +300,7 @@ def get_dataset_stats(path, test_prot, L, V):
def gen_tables(basedir, datasets):
from tabular import Table
mock_h = LogisticRegression()
methods = [method for method, _ in gen_CAP(mock_h, None)] + [method for method, _ in gen_CAP_cont_table(mock_h)]
@@ -313,7 +325,7 @@ def gen_tables(basedir, datasets):
classifier = classifiers[0]
for metric in [measure for measure, _ in gen_acc_measure()]:
table = Table(datasets, methods)
table = Table(datasets, methods, prec_mean=5, clean_zero=True)
for method, dataset in itertools.product(methods, datasets):
path = getpath(basedir, classifier, metric, dataset, method)
if not os.path.exists(path):

View File

@@ -16,7 +16,7 @@ def plot_diagonal(basedir, cls_name, measure_name, dataset_name='*'):
xs.append(results[method_name]['true_acc'])
ys.append(results[method_name]['estim_acc'])
plotsubdir = 'all' if dataset_name=='*' else dataset_name
save_path = join('plots', basedir, plotsubdir, 'diagonal.png')
save_path = join('plots', basedir, measure_name, plotsubdir, 'diagonal.png')
_plot_diagonal(methods, xs, ys, save_path, measure_name)
@@ -31,7 +31,7 @@ def _plot_diagonal(methods_names, true_xs, estim_ys, save_path, measure_name, ti
plt.plot([0, 1], [0, 1], color='black', linestyle='--')
for (method_name, xs, ys) in zip(methods_names, true_xs, estim_ys):
plt.scatter(xs, ys, label=f'{method_name}', alpha=0.6)
plt.scatter(xs, ys, label=f'{method_name}', alpha=0.5, linewidths=0)
plt.legend()

View File

@@ -1,33 +1,79 @@
import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.base import BinaryQuantifier
from quapy.method.base import BinaryQuantifier, BaseQuantifier
from quapy.model_selection import GridSearchQ
from quapy.method.aggregative import AggregativeSoftQuantifier
from quapy.protocol import APP
import numpy as np
from sklearn.linear_model import LogisticRegression
from time import time
# Define a custom quantifier: for this example, we will consider a new quantification algorithm that uses a
# logistic regressor for generating posterior probabilities, and then applies a custom threshold value to the
# posteriors. Since the quantifier internally uses a classifier, it is an aggregative quantifier; and since it
# relies on posterior probabilities, it is a probabilistic-aggregative quantifier. Note also it has an
# internal hyperparameter (let say, alpha) which is the decision threshold. Let's also assume the quantifier
# is binary, for simplicity.
# relies on posterior probabilities, it is a probabilistic-aggregative quantifier (aka AggregativeSoftQuantifier).
# Note also that it has an internal hyperparameter (let's say, alpha), which is the decision threshold.
#
# Let's also assume the quantifier is binary, for simplicity. Any quantifier (i.e., any subclass of BaseQuantifier)
# is required to implement the "fit" and "quantify" methods. Aggregative quantifiers are special subtypes of base
# quantifiers, i.e., are quantifiers that undertake a classification-phase followed by an aggregation-phase. QuaPy
# already implements most common functionality, and requires the developer to simply implement the "aggregation_fit"
# and the "aggregation" methods.
#
# We are providing two implementations of the same method to illustrate this characteristic of QuaPy. Let us begin
# with the general case, in which we implement a (base) quantifier:
class MyQuantifier(BaseQuantifier):
class MyQuantifier(AggregativeSoftQuantifier, BinaryQuantifier):
def __init__(self, classifier, alpha=0.5):
self.alpha = alpha
# aggregative quantifiers have an internal self.classifier attribute
self.classifier = classifier
def fit(self, data: LabelledCollection, fit_classifier=True):
assert fit_classifier, 'this quantifier needs to fit the classifier!'
# in general, we would need to implement the method fit(self, data: LabelledCollection, fit_classifier=True,
# val_split=None); this would amount to:
def fit(self, data: LabelledCollection):
assert data.n_classes==2, \
'this quantifier is only valid for binary problems [abort]'
self.classifier.fit(*data.Xy)
return self
# in general, we would need to implement the method quantify(self, instances) but, since this method is of
# type aggregative, we can simply implement the method aggregate, which has the following interface
# in general, we would need to implement the method quantify(self, instances); this would amount to:
def quantify(self, instances):
assert hasattr(self.classifier, 'predict_proba'), \
'the underlying classifier is not probabilistic! [abort]'
posterior_probabilities = self.classifier.predict_proba(instances)
positive_probabilities = posterior_probabilities[:, 1]
crisp_decisions = positive_probabilities > self.alpha
pos_prev = crisp_decisions.mean()
neg_prev = 1 - pos_prev
return np.asarray([neg_prev, pos_prev])
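# For intuition, a tiny worked example of the rule implemented above (hypothetical numbers, not part of this
# example's data): with alpha=0.5, positive posteriors [0.2, 0.7, 0.9, 0.4] yield crisp decisions [0, 1, 1, 0],
# and hence an estimated prevalence vector of [0.5, 0.5]:
#
#   >>> posteriors = np.asarray([0.2, 0.7, 0.9, 0.4])
#   >>> pos_prev = (posteriors > 0.5).mean()
#   >>> np.asarray([1 - pos_prev, pos_prev])
#   array([0.5, 0.5])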
# Note that the above implementation contains a lot of boilerplate code. Many parts can be omitted since QuaPy
# provides implementations for them. Some of these routines (like, for example, training a classifier and generating
# posterior probabilities) are often carried out in a k-fold cross-validation manner. These, along with many other
# common routines, are already provided (and highly optimized) in QuaPy. Let's see a much better implementation
# of the method, this time adhering to the AggregativeSoftQuantifier interface:
class MyAggregativeSoftQuantifier(AggregativeSoftQuantifier, BinaryQuantifier):
def __init__(self, classifier, alpha=0.5):
# aggregative quantifiers have an internal attribute called self.classifier
self.classifier = classifier
self.alpha = alpha
# since this method is of type aggregative, we can simply implement the method aggregation_fit, which
# assumes the classifier has already been fitted properly and the predictions for the training set required
# to train the aggregation function have been properly generated (i.e., on a validation split, or using a
# k-fold cross-validation strategy). All that remains is to learn the aggregation function. In our case
# this amounts to doing... nothing, since our method is pretty basic. BinaryQuantifier also adds some
# basic functionality for checking binary consistency.
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
pass
# since this method is of type aggregative, we can simply implement the method aggregate (i.e., we should
# only describe what to do with the classifier predictions, which in this case are posterior probabilities
# because we are inheriting from the "Soft" subtype). This comes down to:
def aggregate(self, classif_predictions: np.ndarray):
# the posterior probabilities have already been generated by the quantify method; we only need to
# specify what to do with them
@@ -38,31 +84,68 @@ class MyQuantifier(AggregativeSoftQuantifier, BinaryQuantifier):
return np.asarray([neg_prev, pos_prev])
# a small example using these two implementations of our method
if __name__ == '__main__':
qp.environ['SAMPLE_SIZE'] = 100
# define an instance of our custom quantifier
quantifier = MyQuantifier(LogisticRegression(), alpha=0.5)
qp.environ['SAMPLE_SIZE'] = 250
# load the IMDb dataset
train, test = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=5).train_test
train, val = train.split_stratified(train_prop=0.75) # let's create a validation set for optimizing hyperparams
# model selection
# let us assume we want to explore our hyperparameter alpha along with one hyperparameter of the classifier
train, val = train.split_stratified(train_prop=0.75)
param_grid = {
'alpha': np.linspace(0, 1, 11), # quantifier-dependent hyperparameter
'classifier__C': np.logspace(-2, 2, 5) # classifier-dependent hyperparameter
}
quantifier = GridSearchQ(quantifier, param_grid, protocol=APP(val), n_jobs=-1, verbose=True).fit(train)
def test_implementation(quantifier):
class_name = quantifier.__class__.__name__
print(f'\ntesting implementation {class_name}...')
# model selection
# let us assume we want to explore our hyperparameter alpha along with one hyperparameter of the classifier
tinit = time()
param_grid = {
'alpha': np.linspace(0, 1, 11), # quantifier-dependent hyperparameter
'classifier__C': np.logspace(-2, 2, 5) # classifier-dependent hyperparameter
}
gridsearch = GridSearchQ(quantifier, param_grid, protocol=APP(val), n_jobs=-1, verbose=False).fit(train)
t_modsel = time() - tinit
print(f'\tmodel selection took {t_modsel:.2f}s', flush=True)
# evaluation
mae = qp.evaluation.evaluate(quantifier, protocol=APP(test), error_metric='mae')
# evaluation
optimized_model = gridsearch.best_model_
mae = qp.evaluation.evaluate(
optimized_model,
protocol=APP(test, repeats=5000, sanity_check=None), # disable the check, we want to generate many tests!
error_metric='mae',
verbose=True)
print(f'MAE = {mae:.4f}')
t_eval = time() - t_modsel - tinit
print(f'\tevaluation took {t_eval:.2f}s [MAE = {mae:.4f}]')
# final remarks: this method is only for demonstration purposes and makes little sense in general. The method relies
# define an instance of our custom quantifier and test it!
quantifier = MyQuantifier(LogisticRegression(), alpha=0.5)
test_implementation(quantifier)
# define an instance of our custom quantifier, with the second implementation, and test it!
quantifier = MyAggregativeSoftQuantifier(LogisticRegression(), alpha=0.5)
test_implementation(quantifier)
# the output should look like this:
"""
testing implementation MyQuantifier...
model selection took 12.86s
predicting: 100%|| 105000/105000 [00:22<00:00, 4626.30it/s]
evaluation took 22.75s [MAE = 0.0630]
testing implementation MyAggregativeSoftQuantifier...
model selection took 3.10s
speeding up the prediction for the aggregative quantifier, total classifications 25000 instead of 26250000
predicting: 100%|| 105000/105000 [00:04<00:00, 22779.62it/s]
evaluation took 4.66s [MAE = 0.0630]
"""
# Note that the first implementation is much slower, both in terms of grid-search optimization and in terms of
# evaluation. The reason is that QuaPy is highly optimized for aggregative quantifiers (by far, the most
# popular type of quantification method), which significantly speeds up model selection and test routines.
# Furthermore, it is simpler to extend an aggregative type, since QuaPy implements the boilerplate functions for you.
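# A rough sketch of where the figures above come from (an illustration of the idea, not QuaPy's actual
# internals): APP(test, repeats=5000) with the default 21 prevalence points generates 21*5000 = 105000
# samples of SAMPLE_SIZE=250 instances each, i.e., 26,250,000 classifier calls for a non-aggregative
# quantifier. An aggregative quantifier instead classifies the 25,000 pooled test documents only once, and
# each sample merely aggregates the cached posteriors of the instances it draws; conceptually:
#
#   P = classifier.predict_proba(test.instances)          # 25,000 classifications, computed once
#   for sample_idx in indices_drawn_by_the_protocol:      # hypothetical per-sample index lists
#       estim_prev = quantifier.aggregate(P[sample_idx])  # cheap numpy work per sample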
# Final remarks: this method is only for demonstration purposes and makes little sense in general. The method relies
# on a hyperparameter alpha for binarizing the posterior probabilities. A much better way of fulfilling this
# goal would be to calibrate the classifier (LogisticRegression is already reasonably well calibrated) and then
# simply cut at 0.5.
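# A minimal sketch of that suggestion (hypothetical, not part of this commit): wrap the classifier in
# scikit-learn's CalibratedClassifierCV so that the posteriors are better calibrated, and keep the cut fixed
# at 0.5 instead of tuning alpha via model selection.
from sklearn.calibration import CalibratedClassifierCV

class MyCalibratedQuantifier(MyAggregativeSoftQuantifier):
    """Same classify-and-count idea, but with (approximately) calibrated posteriors and a fixed 0.5 cut."""
    def __init__(self, classifier):
        super().__init__(classifier=CalibratedClassifierCV(classifier, cv=5), alpha=0.5)

# usage would mirror the examples above, e.g.:
#   MyCalibratedQuantifier(LogisticRegression()).fit(train).quantify(test.instances)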