1
0
Fork 0
QuaPy/quapy/method/base.py

202 lines
5.8 KiB
Python

from abc import ABCMeta, abstractmethod
from quapy.data import LabelledCollection
# Base Quantifier abstract class
# ------------------------------------
class BaseQuantifier(metaclass=ABCMeta):
    """
    Abstract base class of all quantifiers. A concrete quantifier has to implement :meth:`fit` (training on a
    :class:`quapy.data.base.LabelledCollection`), :meth:`quantify`, the pair :meth:`set_params` /
    :meth:`get_params` used for model selection (see :meth:`quapy.model_selection.GridSearchQ`), and the
    :attr:`classes_` property.
    """

    @abstractmethod
    def fit(self, data: LabelledCollection):
        """
        Train the quantifier.

        :param data: a :class:`quapy.data.base.LabelledCollection` holding the training data
        :return: self
        """
        pass

    @abstractmethod
    def quantify(self, instances):
        """
        Estimate the class prevalence values of the sample made up of the given instances.

        :param instances: array-like
        :return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
        """
        pass

    @abstractmethod
    def set_params(self, **parameters):
        """
        Set the parameters of the quantifier.

        :param parameters: param-value pairs, given as keyword arguments
        """
        pass

    @abstractmethod
    def get_params(self, deep=True):
        """
        Get the current parameters of the quantifier.

        :param deep: for compatibility with sklearn
        :return: a dictionary of param-value pairs
        """
        pass

    @property
    @abstractmethod
    def classes_(self):
        """
        Class labels, listed in the same order in which the class prevalence values are computed.

        :return: array-like
        """
        pass

    @property
    def n_classes(self):
        """
        Number of classes, i.e., the length of :attr:`classes_`.

        :return: integer
        """
        labels = self.classes_
        return len(labels)

    # The properties below default to False and are meant to be overridden; exposing them as properties (instead
    # of relying on the class hierarchy) allows meta-learners to answer based on their constituents.

    @property
    def binary(self):
        """
        Tells whether the quantifier is binary or not.

        :return: False (to be overridden)
        """
        return False

    @property
    def aggregative(self):
        """
        Tells whether the quantifier is of type aggregative or not.

        :return: False (to be overridden)
        """
        return False

    @property
    def probabilistic(self):
        """
        Tells whether the quantifier is of type probabilistic or not.

        :return: False (to be overridden)
        """
        return False
class BinaryQuantifier(BaseQuantifier):
    """
    Abstract class of binary quantifiers, i.e., quantifiers estimating class prevalence values for only two classes
    (typically, to be interpreted as one class and its complement).
    """

    def _check_binary(self, data: LabelledCollection, quantifier_name):
        """
        Checks that the given data is binary, and fails otherwise.

        :param data: the collection to check; its `binary` attribute is inspected
        :param quantifier_name: name of the quantifier, used only to compose the error message
        :raises AssertionError: if `data.binary` is falsy
        """
        # Raised explicitly instead of via an `assert` statement, so that the check is not removed when Python
        # runs with optimizations enabled (-O). The exception type stays AssertionError so that existing callers
        # that catch it keep working.
        if not data.binary:
            raise AssertionError(
                f'{quantifier_name} works only on problems of binary classification. '
                f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'
            )

    @property
    def binary(self):
        """
        Informs that the quantifier is binary

        :return: True
        """
        return True
def isbinary(model: BaseQuantifier):
    """
    Functional shortcut for reading the property `binary` of a model.

    :param model: the quantifier to inspect
    :return: the value of `model.binary`, i.e., True if the model is binary, False otherwise
    """
    binary_flag = model.binary
    return binary_flag
def isaggregative(model: BaseQuantifier):
    """
    Functional shortcut for reading the property `aggregative` of a model.

    :param model: the quantifier to inspect
    :return: the value of `model.aggregative`, i.e., True if the model is aggregative, False otherwise
    """
    aggregative_flag = model.aggregative
    return aggregative_flag
def isprobabilistic(model: BaseQuantifier):
    """
    Functional shortcut for reading the property `probabilistic` of a model.

    :param model: the quantifier to inspect
    :return: the value of `model.probabilistic`, i.e., True if the model is probabilistic, False otherwise
    """
    probabilistic_flag = model.probabilistic
    return probabilistic_flag
# class OneVsAll:
# """
# Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
# quantifier for each class, and then l1-normalizes the outputs so that the class prevalences sum up to 1.
# """
#
# def __init__(self, binary_method, n_jobs=-1):
# self.binary_method = binary_method
# self.n_jobs = n_jobs
#
# def fit(self, data: LabelledCollection, **kwargs):
# assert not data.binary, f'{self.__class__.__name__} expect non-binary data'
# assert isinstance(self.binary_method, BaseQuantifier), f'{self.binary_method} does not seem to be a Quantifier'
# self.class_method = {c: deepcopy(self.binary_method) for c in data.classes_}
# Parallel(n_jobs=self.n_jobs, backend='threading')(
# delayed(self._delayed_binary_fit)(c, self.class_method, data, **kwargs) for c in data.classes_
# )
# return self
#
# def quantify(self, X, *args):
# prevalences = np.asarray(
# Parallel(n_jobs=self.n_jobs, backend='threading')(
# delayed(self._delayed_binary_predict)(c, self.class_method, X) for c in self.classes
# )
# )
# return F.normalize_prevalence(prevalences)
#
# @property
# def classes(self):
# return sorted(self.class_method.keys())
#
# def set_params(self, **parameters):
# self.binary_method.set_params(**parameters)
#
# def get_params(self, deep=True):
# return self.binary_method.get_params()
#
# def _delayed_binary_predict(self, c, learners, X):
# return learners[c].quantify(X)[:,1] # the mean is the estimation for the positive class prevalence
#
# def _delayed_binary_fit(self, c, learners, data, **kwargs):
# bindata = LabelledCollection(data.instances, data.labels == c, n_classes=2)
# learners[c].fit(bindata, **kwargs)