Added classes_ property to all quantifiers.

This commit is contained in:
Andrea Esuli 2021-05-04 17:09:13 +02:00
parent 70a3d4bd0f
commit bfbfe08116
7 changed files with 29 additions and 10 deletions

View File

@ -17,14 +17,12 @@ Current issues:
In binary quantification (hp, kindle, imdb) we used the F1 of the minority class (which in kindle and hp happens to be the
negative class). This is not covered by the new implementation, in which the binary case is not treated as such, but as
an instance of single-label quantification with 2 labels. Check this.
Add classnames to LabelledCollection? This should improve visualization of reports
Add automatic reindex of class labels in LabelledCollection (currently, class indexes should be ordered and with no gaps)
I believe OVR is currently tied to aggregative methods. We should also provide a general interface for general quantifiers
Currently, being "binary" only adds one checker; we should figure out how to have the check performed automatically
Improvements:
==========================================
Clarify whether QuaNet is an aggregative method or not.
Explore the hyperparameter "number of bins" in HDy
Rename EMQ to SLD ?
Parallelize the kFCV in ACC and PACC?

View File

@ -53,10 +53,10 @@ class AggregativeQuantifier(BaseQuantifier):
@property
def n_classes(self):
return len(self.classes)
return len(self.classes_)
@property
def classes(self):
def classes_(self):
return self.learner.classes_
@property

View File

@ -19,6 +19,10 @@ class BaseQuantifier(metaclass=ABCMeta):
@abstractmethod
def get_params(self, deep=True): ...
@property
@abstractmethod
def classes_(self):
    """Abstract read-only property exposing the quantifier's ``classes_``.

    Every concrete quantifier has to provide its own implementation.
    """
    # abstractmethod must be the innermost decorator: applying it on top of an
    # already-built property object raises AttributeError, because
    # property.__isabstractmethod__ is a read-only attribute.
    ...
# these methods allow meta-learners to reimplement the decision based on their constituents, and not
# based on class structure
@property

View File

@ -186,6 +186,10 @@ class Ensemble(BaseQuantifier):
order = np.argsort(dist)
return _select_k(predictions, order, k=self.red_size)
@property
def classes_(self):
    """Expose the ``classes_`` of the wrapped ``base_quantifier``."""
    base = self.base_quantifier
    return base.classes_
@property
def binary(self):
    """Expose the ``binary`` flag of the wrapped ``base_quantifier``."""
    base = self.base_quantifier
    return base.binary

View File

@ -58,6 +58,7 @@ class QuaNetTrainer(BaseQuantifier):
self.device = torch.device(device)
self.__check_params_colision(self.quanet_params, self.learner.get_params())
self._classes_ = None
def fit(self, data: LabelledCollection, fit_learner=True):
"""
@ -67,6 +68,7 @@ class QuaNetTrainer(BaseQuantifier):
:param fit_learner: if true, trains the classifier on a split containing 40% of the data
:return: self
"""
self._classes_ = data.classes_
classifier_data, unused_data = data.split_stratified(0.4)
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
@ -256,6 +258,10 @@ class QuaNetTrainer(BaseQuantifier):
import shutil
shutil.rmtree(self.checkpointdir, ignore_errors=True)
@property
def classes_(self):
    """Return the value currently held in the private ``_classes_`` attribute."""
    stored_classes = self._classes_
    return stored_classes
def mae_loss(output, target):
    """Return the mean of the element-wise absolute difference of ``output`` and ``target``."""
    abs_diff = (output - target).abs()
    return abs_diff.mean()

View File

@ -2,18 +2,22 @@ from quapy.data import LabelledCollection
from .base import BaseQuantifier
class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
def __init__(self, **kwargs):
pass
self._classes_ = None
def fit(self, data: LabelledCollection, *args):
    """
    Store the class labels and the class prevalence of the training collection.

    :param data: the training collection; its ``classes_`` and ``prevalence()`` are recorded
    :param args: extra positional arguments, accepted and ignored
    :return: self
    """
    self._classes_ = data.classes_
    self.estimated_prevalence = data.prevalence()
    # return the fitted quantifier so that calls can be chained, e.g. fit(data).quantify(docs)
    return self
def quantify(self, documents, *args):
    """Return the prevalence stored in ``estimated_prevalence``; ``documents`` is not inspected."""
    prevalence = self.estimated_prevalence
    return prevalence
@property
def classes_(self):
    """Return the value currently held in the private ``_classes_`` attribute."""
    stored_classes = self._classes_
    return stored_classes
def get_params(self, deep=True):
    """
    Return the hyperparameters of this quantifier; there are none, so the result is None.

    :param deep: accepted so the signature matches the abstract ``get_params(self, deep=True)``; unused
    :return: None
    """
    return None

View File

@ -4,7 +4,6 @@ from copy import deepcopy
from typing import Union, Callable
import quapy as qp
import quapy.functional as F
from quapy.data.base import LabelledCollection
from quapy.evaluation import artificial_sampling_prediction
from quapy.method.aggregative import BaseQuantifier
@ -80,7 +79,7 @@ class GridSearchQ(BaseQuantifier):
return training, validation
elif isinstance(validation, float):
assert 0. < validation < 1., 'validation proportion should be in (0,1)'
training, validation = training.split_stratified(train_prop=1-validation)
training, validation = training.split_stratified(train_prop=1 - validation)
return training, validation
else:
raise ValueError(f'"validation" must either be a LabelledCollection or a float in (0,1) indicating the'
@ -97,7 +96,7 @@ class GridSearchQ(BaseQuantifier):
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float]=None):
def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float] = None):
"""
:param training: the training set on which to optimize the hyperparameters
:param val_split: either a LabelledCollection on which to test the performance of the different settings, or
@ -118,6 +117,7 @@ class GridSearchQ(BaseQuantifier):
def handler(signum, frame):
self.sout('timeout reached')
raise TimeoutError()
signal.signal(signal.SIGALRM, handler)
self.sout(f'starting optimization with n_jobs={n_jobs}')
@ -175,6 +175,10 @@ class GridSearchQ(BaseQuantifier):
def quantify(self, instances):
    """Delegate the quantification of ``instances`` to the model held in ``best_model_``."""
    selected_model = self.best_model_
    return selected_model.quantify(instances)
@property
def classes_(self):
    """Expose the ``classes_`` of the model held in ``best_model_``."""
    selected_model = self.best_model_
    return selected_model.classes_
def set_params(self, **parameters):
self.param_grid = parameters
@ -185,4 +189,3 @@ class GridSearchQ(BaseQuantifier):
if hasattr(self, 'best_model_'):
return self.best_model_
raise ValueError('best_model called before fit')