Added classes_ property to all quantifiers.

This commit is contained in:
Andrea Esuli 2021-05-04 17:09:13 +02:00
parent 70a3d4bd0f
commit bfbfe08116
7 changed files with 29 additions and 10 deletions

View File

@ -17,14 +17,12 @@ Current issues:
In binary quantification (hp, kindle, imdb) we used F1 in the minority class (which in kindle and hp happens to be the In binary quantification (hp, kindle, imdb) we used F1 in the minority class (which in kindle and hp happens to be the
negative class). This is not covered in this new implementation, in which the binary case is not treated as such, but as negative class). This is not covered in this new implementation, in which the binary case is not treated as such, but as
an instance of single-label with 2 labels. Check an instance of single-label with 2 labels. Check
Add classnames to LabelledCollection? This should improve visualization of reports
Add automatic reindex of class labels in LabelledCollection (currently, class indexes should be ordered and with no gaps) Add automatic reindex of class labels in LabelledCollection (currently, class indexes should be ordered and with no gaps)
OVR I believe is currently tied to aggregative methods. We should provide a general interface also for general quantifiers OVR I believe is currently tied to aggregative methods. We should provide a general interface also for general quantifiers
Currently, being "binary" only adds one checker; we should figure out how to impose the check to be automatically performed Currently, being "binary" only adds one checker; we should figure out how to impose the check to be automatically performed
Improvements: Improvements:
========================================== ==========================================
Clarify whether QuaNet is an aggregative method or not.
Explore the hyperparameter "number of bins" in HDy Explore the hyperparameter "number of bins" in HDy
Rename EMQ to SLD ? Rename EMQ to SLD ?
Parallelize the kFCV in ACC and PACC? Parallelize the kFCV in ACC and PACC?

View File

@ -53,10 +53,10 @@ class AggregativeQuantifier(BaseQuantifier):
@property @property
def n_classes(self): def n_classes(self):
return len(self.classes) return len(self.classes_)
@property @property
def classes(self): def classes_(self):
return self.learner.classes_ return self.learner.classes_
@property @property

View File

@ -19,6 +19,10 @@ class BaseQuantifier(metaclass=ABCMeta):
@abstractmethod @abstractmethod
def get_params(self, deep=True): ... def get_params(self, deep=True): ...
@property
@abstractmethod
def classes_(self):
    """Abstract read-only property: the class labels this quantifier operates on.

    Concrete quantifiers must override this (e.g., delegating to the wrapped
    learner's ``classes_`` or to labels seen during ``fit``).

    NOTE: ``@abstractmethod`` must be the *innermost* decorator (below
    ``@property``); the original order (``@abstractmethod`` outermost) does not
    produce a working abstract property, since ``property.__isabstractmethod__``
    cannot be set on an already-built property object.
    """
    ...
# these methods allows meta-learners to reimplement the decision based on their constituents, and not # these methods allows meta-learners to reimplement the decision based on their constituents, and not
# based on class structure # based on class structure
@property @property

View File

@ -186,6 +186,10 @@ class Ensemble(BaseQuantifier):
order = np.argsort(dist) order = np.argsort(dist)
return _select_k(predictions, order, k=self.red_size) return _select_k(predictions, order, k=self.red_size)
@property
def classes_(self):
    """Class labels of this ensemble, as exposed by its base quantifier."""
    delegate = self.base_quantifier
    return delegate.classes_
@property @property
def binary(self): def binary(self):
return self.base_quantifier.binary return self.base_quantifier.binary

View File

@ -58,6 +58,7 @@ class QuaNetTrainer(BaseQuantifier):
self.device = torch.device(device) self.device = torch.device(device)
self.__check_params_colision(self.quanet_params, self.learner.get_params()) self.__check_params_colision(self.quanet_params, self.learner.get_params())
self._classes_ = None
def fit(self, data: LabelledCollection, fit_learner=True): def fit(self, data: LabelledCollection, fit_learner=True):
""" """
@ -67,6 +68,7 @@ class QuaNetTrainer(BaseQuantifier):
:param fit_learner: if true, trains the classifier on a split containing 40% of the data :param fit_learner: if true, trains the classifier on a split containing 40% of the data
:return: self :return: self
""" """
self._classes_ = data.classes_
classifier_data, unused_data = data.split_stratified(0.4) classifier_data, unused_data = data.split_stratified(0.4)
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20% train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
@ -256,6 +258,10 @@ class QuaNetTrainer(BaseQuantifier):
import shutil import shutil
shutil.rmtree(self.checkpointdir, ignore_errors=True) shutil.rmtree(self.checkpointdir, ignore_errors=True)
@property
def classes_(self):
    """Class labels recorded during ``fit`` (``None`` until ``fit`` is called)."""
    known = self._classes_
    return known
def mae_loss(output, target): def mae_loss(output, target):
return torch.mean(torch.abs(output - target)) return torch.mean(torch.abs(output - target))

View File

@ -2,18 +2,22 @@ from quapy.data import LabelledCollection
from .base import BaseQuantifier from .base import BaseQuantifier
class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier): class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
def __init__(self, **kwargs): def __init__(self, **kwargs):
pass self._classes_ = None
def fit(self, data: LabelledCollection, *args): def fit(self, data: LabelledCollection, *args):
self._classes_ = data.classes_
self.estimated_prevalence = data.prevalence() self.estimated_prevalence = data.prevalence()
def quantify(self, documents, *args): def quantify(self, documents, *args):
return self.estimated_prevalence return self.estimated_prevalence
@property
def classes_(self):
    """Class labels observed in the training data (``None`` before ``fit``)."""
    seen = self._classes_
    return seen
def get_params(self): def get_params(self):
pass pass

View File

@ -4,7 +4,6 @@ from copy import deepcopy
from typing import Union, Callable from typing import Union, Callable
import quapy as qp import quapy as qp
import quapy.functional as F
from quapy.data.base import LabelledCollection from quapy.data.base import LabelledCollection
from quapy.evaluation import artificial_sampling_prediction from quapy.evaluation import artificial_sampling_prediction
from quapy.method.aggregative import BaseQuantifier from quapy.method.aggregative import BaseQuantifier
@ -80,7 +79,7 @@ class GridSearchQ(BaseQuantifier):
return training, validation return training, validation
elif isinstance(validation, float): elif isinstance(validation, float):
assert 0. < validation < 1., 'validation proportion should be in (0,1)' assert 0. < validation < 1., 'validation proportion should be in (0,1)'
training, validation = training.split_stratified(train_prop=1-validation) training, validation = training.split_stratified(train_prop=1 - validation)
return training, validation return training, validation
else: else:
raise ValueError(f'"validation" must either be a LabelledCollection or a float in (0,1) indicating the' raise ValueError(f'"validation" must either be a LabelledCollection or a float in (0,1) indicating the'
@ -97,7 +96,7 @@ class GridSearchQ(BaseQuantifier):
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n' raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}') f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float]=None): def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float] = None):
""" """
:param training: the training set on which to optimize the hyperparameters :param training: the training set on which to optimize the hyperparameters
:param val_split: either a LabelledCollection on which to test the performance of the different settings, or :param val_split: either a LabelledCollection on which to test the performance of the different settings, or
@ -118,6 +117,7 @@ class GridSearchQ(BaseQuantifier):
def handler(signum, frame): def handler(signum, frame):
self.sout('timeout reached') self.sout('timeout reached')
raise TimeoutError() raise TimeoutError()
signal.signal(signal.SIGALRM, handler) signal.signal(signal.SIGALRM, handler)
self.sout(f'starting optimization with n_jobs={n_jobs}') self.sout(f'starting optimization with n_jobs={n_jobs}')
@ -175,6 +175,10 @@ class GridSearchQ(BaseQuantifier):
def quantify(self, instances): def quantify(self, instances):
return self.best_model_.quantify(instances) return self.best_model_.quantify(instances)
@property
def classes_(self):
    """Class labels, delegated to the best model found by the grid search.

    NOTE(review): accessing this before ``fit`` raises ``AttributeError``
    because ``best_model_`` does not exist yet — confirm whether the
    ``best_model()`` accessor's ``ValueError`` would be preferable.
    """
    chosen = self.best_model_
    return chosen.classes_
def set_params(self, **parameters): def set_params(self, **parameters):
self.param_grid = parameters self.param_grid = parameters
@ -185,4 +189,3 @@ class GridSearchQ(BaseQuantifier):
if hasattr(self, 'best_model_'): if hasattr(self, 'best_model_'):
return self.best_model_ return self.best_model_
raise ValueError('best_model called before fit') raise ValueError('best_model called before fit')