1
0
Fork 0
QuaPy/quapy/method/base.py

110 lines
4.2 KiB
Python
Raw Normal View History

2020-12-03 18:12:28 +01:00
from abc import ABCMeta, abstractmethod
from copy import deepcopy
2023-02-09 19:39:16 +01:00
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator
import quapy as qp
2021-01-15 18:32:32 +01:00
from quapy.data import LabelledCollection
2023-02-09 19:39:16 +01:00
import numpy as np
2020-12-03 18:12:28 +01:00
# Base Quantifier abstract class
# ------------------------------------
class BaseQuantifier(BaseEstimator):
    """
    Abstract Quantifier: the interface shared by all quantification methods.

    A quantifier is an object of a class that implements :meth:`fit` (training on a
    :class:`quapy.data.base.LabelledCollection`) and :meth:`quantify` (estimating class prevalence values),
    together with :meth:`set_params` and :meth:`get_params`, which are needed for model selection
    (see :meth:`quapy.model_selection.GridSearchQ`).
    """
    # NOTE(review): no ABCMeta metaclass is declared here, so unless the base class provides one the
    # @abstractmethod markers below are presumably documentation only and not enforced -- TODO confirm.

    @abstractmethod
    def fit(self, data: LabelledCollection):
        """
        Trains the quantifier.

        :param data: a :class:`quapy.data.base.LabelledCollection` holding the training data
        :return: self
        """

    @abstractmethod
    def quantify(self, instances):
        """
        Produces class prevalence estimates for the instances of a sample.

        :param instances: array-like
        :return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
        """
2020-12-03 18:12:28 +01:00
class BinaryQuantifier(BaseQuantifier):
    """
    Abstract class of binary quantifiers, that is, quantifiers that estimate class prevalence values for exactly
    two classes (typically read as one class and its complement).
    """

    def _check_binary(self, data: LabelledCollection, quantifier_name):
        """
        Asserts that the given collection poses a binary problem.

        :param data: the collection to check; its `binary` attribute is inspected
        :param quantifier_name: name of the quantifier, used only to compose the error message
        :raises AssertionError: if `data.binary` is falsy
        """
        # The message is the second operand of the assert, so it is only built when the check fails.
        assert data.binary, (
            f'{quantifier_name} works only on problems of binary classification. '
            f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'
        )
2023-02-09 19:39:16 +01:00
class OneVsAll:
    """
    Empty base class for one-vs-all quantifiers (e.g., :class:`OneVsAllGeneric` inherits from it).
    It declares no attributes and no methods of its own.
    """
2023-02-13 12:01:52 +01:00
def newOneVsAll(binary_quantifier, n_jobs=None):
    """
    Factory that wraps a binary quantifier into the matching one-vs-all quantifier.

    :param binary_quantifier: an instance of :class:`BaseQuantifier`
    :param n_jobs: forwarded, as is, to the constructor of the wrapper that gets instantiated
    :return: a `qp.method.aggregative.OneVsAllAggregative` if `binary_quantifier` is an instance of
        `qp.method.aggregative.AggregativeQuantifier`, or a :class:`OneVsAllGeneric` otherwise
    :raises AssertionError: if `binary_quantifier` is not an instance of :class:`BaseQuantifier`
    """
    assert isinstance(binary_quantifier, BaseQuantifier), \
        f'{binary_quantifier} does not seem to be a Quantifier'

    # Pick the wrapper class first, then instantiate it through a single call site.
    is_aggregative = isinstance(binary_quantifier, qp.method.aggregative.AggregativeQuantifier)
    wrapper_class = qp.method.aggregative.OneVsAllAggregative if is_aggregative else OneVsAllGeneric
    return wrapper_class(binary_quantifier, n_jobs)
2023-02-09 19:39:16 +01:00
class OneVsAllGeneric(OneVsAll, BaseQuantifier):
    """
    Allows any binary quantifier to perform quantification on single-label datasets. One copy of the binary
    quantifier is kept for each class, and the per-class outputs are then normalized
    (via `qp.functional.normalize_prevalence`) so that the class prevalence values sum up to 1.
    """

    def __init__(self, binary_quantifier, n_jobs=None):
        """
        :param binary_quantifier: an instance of :class:`BaseQuantifier` that is copied once per class in
            :meth:`fit`
        :param n_jobs: number of parallel workers, resolved through `qp._get_njobs`
        :raises AssertionError: if `binary_quantifier` is not an instance of :class:`BaseQuantifier`
        """
        assert isinstance(binary_quantifier, BaseQuantifier), \
            f'{binary_quantifier} does not seem to be a Quantifier'

        if isinstance(binary_quantifier, qp.method.aggregative.AggregativeQuantifier):
            # Not an error: only hints that a dedicated wrapper exists for aggregative quantifiers.
            suggested = qp.method.aggregative.OneVsAllAggregative.__name__
            print('[warning] the quantifier seems to be an instance of qp.method.aggregative.AggregativeQuantifier; '
                  f'you might prefer instantiating {suggested}')

        self.binary_quantifier = binary_quantifier
        self.n_jobs = qp._get_njobs(n_jobs)

    def fit(self, data: LabelledCollection, fit_classifier=True):
        """
        Trains one independent copy of the binary quantifier for each class in `data.classes_`.

        :param data: a non-binary :class:`quapy.data.base.LabelledCollection` with the training data
        :param fit_classifier: must be True
        :return: self
        :raises AssertionError: if `data` is binary or `fit_classifier` is not True
        """
        assert not data.binary, f'{self.__class__.__name__} expect non-binary data'
        assert fit_classifier == True, 'fit_classifier must be True'

        # Every class gets its own deep copy, so that fitting one never affects the others.
        per_class = {}
        for c in data.classes_:
            per_class[c] = deepcopy(self.binary_quantifier)
        self.dict_binary_quantifiers = per_class

        self._parallel(self._delayed_binary_fit, data)
        return self

    def _parallel(self, func, *args, **kwargs):
        """
        Calls `func(c, *args, **kwargs)` for every class `c` in :attr:`classes_` through joblib, using the
        'threading' backend and `self.n_jobs` workers.

        :param func: callable taking the class label as its first argument
        :return: `np.ndarray` built from the list of per-class results
        """
        pool = Parallel(n_jobs=self.n_jobs, backend='threading')
        outcomes = pool(delayed(func)(c, *args, **kwargs) for c in self.classes_)
        return np.asarray(outcomes)

    def quantify(self, instances):
        """
        Estimates class prevalence values by querying each per-class binary quantifier and normalizing.

        :param instances: array-like
        :return: the result of `qp.functional.normalize_prevalence` applied to the per-class estimates,
            which follow the order of :attr:`classes_`
        """
        raw_estimates = self._parallel(self._delayed_binary_predict, instances)
        return qp.functional.normalize_prevalence(raw_estimates)

    @property
    def classes_(self):
        """
        Sorted list of the class labels seen in :meth:`fit` (the keys of `dict_binary_quantifiers`).
        """
        return sorted(self.dict_binary_quantifiers)

    def _delayed_binary_predict(self, c, X):
        """
        Returns the element at index 1 of the estimate produced by the binary quantifier of class `c`.
        """
        # presumably index 1 is the positive (True) class, given classes=[False, True] in _delayed_binary_fit
        estimate = self.dict_binary_quantifiers[c].quantify(X)
        return estimate[1]

    def _delayed_binary_fit(self, c, data):
        """
        Fits the binary quantifier of class `c` on a copy of `data` relabelled as "is class c" / "is not".
        """
        is_class_c = data.labels == c
        bindata = LabelledCollection(data.instances, is_class_c, classes=[False, True])
        self.dict_binary_quantifiers[c].fit(bindata)