forked from moreo/QuaPy
added MedianEstimator quantifier
This commit is contained in:
parent
66ad7295df
commit
daca2bd1cb
|
@ -1,13 +1,18 @@
|
|||
Change Log 0.1.8
|
||||
----------------
|
||||
|
||||
- Added HDx and DistributionMatchingX to non-aggregative quantifiers (see also the new example "comparing_HDy_HDx.py")
|
||||
- New UCI multiclass datasets added (thanks to Pablo González). The 5 UCI multiclass datasets are those corresponding
|
||||
to the following criteria:
|
||||
- >1000 instances
|
||||
- >2 classes
|
||||
- classification datasets
|
||||
- Python API available
|
||||
- Added NAE, NRAE
|
||||
- New IFCB (plankton) dataset added. See fetch_IFCB.
|
||||
- Added new evaluation measures NAE, NRAE
|
||||
- Added new meta method "MedianEstimator"; an ensemble of binary base quantifiers that receives as input a dictionary
|
||||
of hyperparameters that will explore exhaustively, fitting and generating predictions for each combination of
|
||||
hyperparameters, and that returns, as the prevalence estimates, the median across all predictions.
|
||||
|
||||
Change Log 0.1.7
|
||||
----------------
|
||||
|
|
|
@ -11,7 +11,7 @@ from . import util
|
|||
from . import model_selection
|
||||
from . import classification
|
||||
|
||||
__version__ = '0.1.7'
|
||||
__version__ = '0.1.8'
|
||||
|
||||
environ = {
|
||||
'SAMPLE_SIZE': None,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from . import aggregative
|
||||
from . import base
|
||||
from . import meta
|
||||
from . import aggregative
|
||||
from . import non_aggregative
|
||||
from . import meta
|
||||
|
||||
AGGREGATIVE_METHODS = {
|
||||
aggregative.CC,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import itertools
|
||||
from copy import deepcopy
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
|
@ -10,13 +11,14 @@ import quapy as qp
|
|||
from quapy import functional as F
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.model_selection import GridSearchQ
|
||||
from quapy.method.base import BaseQuantifier, BinaryQuantifier
|
||||
from quapy.method.aggregative import CC, ACC, PACC, HDy, EMQ
|
||||
|
||||
try:
|
||||
from . import neural
|
||||
except ModuleNotFoundError:
|
||||
neural = None
|
||||
from .base import BaseQuantifier
|
||||
from quapy.method.aggregative import CC, ACC, PACC, HDy, EMQ
|
||||
|
||||
|
||||
if neural:
|
||||
QuaNet = neural.QuaNetTrainer
|
||||
|
@ -24,6 +26,67 @@ else:
|
|||
QuaNet = "QuaNet is not available due to missing torch package"
|
||||
|
||||
|
||||
class MedianEstimator(BinaryQuantifier):
|
||||
"""
|
||||
This method is a meta-quantifier that returns, as the estimated class prevalence values, the median of the
|
||||
estimation returned by differently (hyper)parameterized base quantifiers.
|
||||
The median of unit-vectors is only guaranteed to be a unit-vector for n=2 dimensions,
|
||||
i.e., in cases of binary quantification.
|
||||
|
||||
:param base_quantifier: the base, binary quantifier
|
||||
:param random_state: a seed to be set before fitting any base quantifier (default None)
|
||||
:param param_grid: the grid or parameters towards which the median will be computed
|
||||
:param n_jobs: number of parllel workes
|
||||
"""
|
||||
def __init__(self, base_quantifier: BinaryQuantifier, param_grid: dict, random_state=None, n_jobs=None):
|
||||
self.base_quantifier = base_quantifier
|
||||
self.param_grid = param_grid
|
||||
self.random_state = random_state
|
||||
self.n_jobs = qp._get_njobs(n_jobs)
|
||||
|
||||
def get_params(self, deep=True):
|
||||
return self.base_quantifier.get_params(deep)
|
||||
|
||||
def set_params(self, **params):
|
||||
self.base_quantifier.set_params(**params)
|
||||
|
||||
def _delayed_fit(self, args):
|
||||
with qp.util.temp_seed(self.random_state):
|
||||
params, training = args
|
||||
model = deepcopy(self.base_quantifier)
|
||||
model.set_params(**params)
|
||||
model.fit(training)
|
||||
return model
|
||||
|
||||
def fit(self, training: LabelledCollection):
|
||||
self._check_binary(training, self.__class__.__name__)
|
||||
params_keys = list(self.param_grid.keys())
|
||||
params_values = list(self.param_grid.values())
|
||||
hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
|
||||
self.models = qp.util.parallel(
|
||||
self._delayed_fit,
|
||||
((params, training) for params in hyper),
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
)
|
||||
return self
|
||||
|
||||
def _delayed_predict(self, args):
|
||||
model, instances = args
|
||||
return model.quantify(instances)
|
||||
|
||||
def quantify(self, instances):
|
||||
prev_preds = qp.util.parallel(
|
||||
self._delayed_predict,
|
||||
((model, instances) for model in self.models),
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
)
|
||||
prev_preds = np.asarray(prev_preds)
|
||||
return np.median(prev_preds, axis=0)
|
||||
|
||||
|
||||
|
||||
class Ensemble(BaseQuantifier):
|
||||
VALID_POLICIES = {'ave', 'ptr', 'ds'} | qp.error.QUANTIFICATION_ERROR_NAMES
|
||||
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
import numpy
|
||||
import numpy as np
|
||||
import pytest
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.svm import LinearSVC
|
||||
|
||||
import quapy as qp
|
||||
from quapy.model_selection import GridSearchQ
|
||||
from quapy.method.base import BinaryQuantifier
|
||||
from quapy.data import Dataset, LabelledCollection
|
||||
from quapy.method import AGGREGATIVE_METHODS, NON_AGGREGATIVE_METHODS
|
||||
from quapy.method.aggregative import ACC, PACC, HDy
|
||||
from quapy.method.meta import Ensemble
|
||||
from quapy.protocol import APP
|
||||
from quapy.method.aggregative import DistributionMatching
|
||||
from quapy.method.meta import MedianEstimator
|
||||
|
||||
datasets = [pytest.param(qp.datasets.fetch_twitter('hcr', pickle=True), id='hcr'),
|
||||
pytest.param(qp.datasets.fetch_UCIDataset('ionosphere'), id='ionosphere')]
|
||||
|
@ -36,7 +39,7 @@ def test_aggregative_methods(dataset: Dataset, aggregative_method, learner):
|
|||
true_prevalences = dataset.test.prevalence()
|
||||
error = qp.error.mae(true_prevalences, estim_prevalences)
|
||||
|
||||
assert type(error) == numpy.float64
|
||||
assert type(error) == np.float64
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dataset', datasets)
|
||||
|
@ -55,7 +58,7 @@ def test_non_aggregative_methods(dataset: Dataset, non_aggregative_method):
|
|||
true_prevalences = dataset.test.prevalence()
|
||||
error = qp.error.mae(true_prevalences, estim_prevalences)
|
||||
|
||||
assert type(error) == numpy.float64
|
||||
assert type(error) == np.float64
|
||||
|
||||
|
||||
@pytest.mark.parametrize('base_method', AGGREGATIVE_METHODS)
|
||||
|
@ -80,7 +83,7 @@ def test_ensemble_method(base_method, learner, dataset: Dataset, policy):
|
|||
true_prevalences = dataset.test.prevalence()
|
||||
error = qp.error.mae(true_prevalences, estim_prevalences)
|
||||
|
||||
assert type(error) == numpy.float64
|
||||
assert type(error) == np.float64
|
||||
|
||||
|
||||
def test_quanet_method():
|
||||
|
@ -119,7 +122,7 @@ def test_quanet_method():
|
|||
true_prevalences = dataset.test.prevalence()
|
||||
error = qp.error.mae(true_prevalences, estim_prevalences)
|
||||
|
||||
assert type(error) == numpy.float64
|
||||
assert type(error) == np.float64
|
||||
|
||||
|
||||
def test_str_label_names():
|
||||
|
@ -130,32 +133,103 @@ def test_str_label_names():
|
|||
dataset.test.sampling(1000, 0.25, 0.75))
|
||||
qp.data.preprocessing.text2tfidf(dataset, min_df=5, inplace=True)
|
||||
|
||||
numpy.random.seed(0)
|
||||
np.random.seed(0)
|
||||
model.fit(dataset.training)
|
||||
|
||||
int_estim_prevalences = model.quantify(dataset.test.instances)
|
||||
true_prevalences = dataset.test.prevalence()
|
||||
|
||||
error = qp.error.mae(true_prevalences, int_estim_prevalences)
|
||||
assert type(error) == numpy.float64
|
||||
assert type(error) == np.float64
|
||||
|
||||
dataset_str = Dataset(LabelledCollection(dataset.training.instances,
|
||||
['one' if label == 1 else 'zero' for label in dataset.training.labels]),
|
||||
LabelledCollection(dataset.test.instances,
|
||||
['one' if label == 1 else 'zero' for label in dataset.test.labels]))
|
||||
assert all(dataset_str.training.classes_ == dataset_str.test.classes_), 'wrong indexation'
|
||||
numpy.random.seed(0)
|
||||
np.random.seed(0)
|
||||
model.fit(dataset_str.training)
|
||||
|
||||
str_estim_prevalences = model.quantify(dataset_str.test.instances)
|
||||
true_prevalences = dataset_str.test.prevalence()
|
||||
|
||||
error = qp.error.mae(true_prevalences, str_estim_prevalences)
|
||||
assert type(error) == numpy.float64
|
||||
assert type(error) == np.float64
|
||||
|
||||
print(true_prevalences)
|
||||
print(int_estim_prevalences)
|
||||
print(str_estim_prevalences)
|
||||
|
||||
numpy.testing.assert_almost_equal(int_estim_prevalences[1],
|
||||
np.testing.assert_almost_equal(int_estim_prevalences[1],
|
||||
str_estim_prevalences[list(model.classes_).index('one')])
|
||||
|
||||
# helper
|
||||
def __fit_test(quantifier, train, test):
|
||||
quantifier.fit(train)
|
||||
test_samples = APP(test)
|
||||
true_prevs, estim_prevs = qp.evaluation.prediction(quantifier, test_samples)
|
||||
return qp.error.mae(true_prevs, estim_prevs), estim_prevs
|
||||
|
||||
|
||||
def test_median_meta():
|
||||
"""
|
||||
This test compares the performance of the MedianQuantifier with respect to computing the median of the predictions
|
||||
of a differently parameterized quantifier. We use the DistributionMatching base quantifier and the median is
|
||||
computed across different values of nbins
|
||||
"""
|
||||
|
||||
qp.environ['SAMPLE_SIZE'] = 100
|
||||
|
||||
# grid of values
|
||||
nbins_grid = list(range(2, 11))
|
||||
|
||||
dataset = 'kindle'
|
||||
train, test = qp.datasets.fetch_reviews(dataset, tfidf=True, min_df=10).train_test
|
||||
prevs = []
|
||||
errors = []
|
||||
for nbins in nbins_grid:
|
||||
with qp.util.temp_seed(0):
|
||||
q = DistributionMatching(LogisticRegression(), nbins=nbins)
|
||||
mae, estim_prevs = __fit_test(q, train, test)
|
||||
prevs.append(estim_prevs)
|
||||
errors.append(mae)
|
||||
print(f'{dataset} DistributionMatching(nbins={nbins}) got MAE {mae:.4f}')
|
||||
prevs = np.asarray(prevs)
|
||||
mae = np.mean(errors)
|
||||
print(f'\tMAE={mae:.4f}')
|
||||
|
||||
q = DistributionMatching(LogisticRegression())
|
||||
q = MedianEstimator(q, param_grid={'nbins': nbins_grid}, random_state=0, n_jobs=-1)
|
||||
median_mae, prev = __fit_test(q, train, test)
|
||||
print(f'\tMAE={median_mae:.4f}')
|
||||
|
||||
np.testing.assert_almost_equal(np.median(prevs, axis=0), prev)
|
||||
assert median_mae < mae, 'the median-based quantifier provided a higher error...'
|
||||
|
||||
|
||||
def test_median_meta_modsel():
|
||||
"""
|
||||
This test checks the median-meta quantifier with model selection
|
||||
"""
|
||||
|
||||
qp.environ['SAMPLE_SIZE'] = 100
|
||||
|
||||
dataset = 'kindle'
|
||||
train, test = qp.datasets.fetch_reviews(dataset, tfidf=True, min_df=10).train_test
|
||||
train, val = train.split_stratified(random_state=0)
|
||||
|
||||
nbins_grid = [2, 4, 5, 10, 15]
|
||||
|
||||
q = DistributionMatching(LogisticRegression())
|
||||
q = MedianEstimator(q, param_grid={'nbins': nbins_grid}, random_state=0, n_jobs=-1)
|
||||
median_mae, _ = __fit_test(q, train, test)
|
||||
print(f'\tMAE={median_mae:.4f}')
|
||||
|
||||
q = DistributionMatching(LogisticRegression())
|
||||
lr_params = {'classifier__C': np.logspace(-1, 1, 3)}
|
||||
q = MedianEstimator(q, param_grid={'nbins': nbins_grid}, random_state=0, n_jobs=-1)
|
||||
q = GridSearchQ(q, param_grid=lr_params, protocol=APP(val), n_jobs=-1)
|
||||
optimized_median_ave, _ = __fit_test(q, train, test)
|
||||
print(f'\tMAE={optimized_median_ave:.4f}')
|
||||
|
||||
assert optimized_median_ave < median_mae, "the optimized method yielded worse performance..."
|
Loading…
Reference in New Issue