forked from moreo/QuaPy
Added classes_ property to all quantifiers.
This commit is contained in:
parent
70a3d4bd0f
commit
bfbfe08116
2
TODO.txt
2
TODO.txt
|
@ -17,14 +17,12 @@ Current issues:
|
|||
In binary quantification (hp, kindle, imdb) we used F1 in the minority class (which in kindle and hp happens to be the
|
||||
negative class). This is not covered in this new implementation, in which the binary case is not treated as such, but as
|
||||
an instance of single-label with 2 labels. Check
|
||||
Add classnames to LabelledCollection? This should improve visualization of reports
|
||||
Add automatic reindex of class labels in LabelledCollection (currently, class indexes should be ordered and with no gaps)
|
||||
OVR I believe is currently tied to aggregative methods. We should provide a general interface also for general quantifiers
|
||||
Currently, being "binary" only adds one checker; we should figure out how to have that check performed automatically
|
||||
|
||||
Improvements:
|
||||
==========================================
|
||||
Clarify whether QuaNet is an aggregative method or not.
|
||||
Explore the hyperparameter "number of bins" in HDy
|
||||
Rename EMQ to SLD ?
|
||||
Parallelize the kFCV in ACC and PACC?
|
||||
|
|
|
@ -53,10 +53,10 @@ class AggregativeQuantifier(BaseQuantifier):
|
|||
|
||||
@property
|
||||
def n_classes(self):
|
||||
return len(self.classes)
|
||||
return len(self.classes_)
|
||||
|
||||
@property
|
||||
def classes(self):
|
||||
def classes_(self):
|
||||
return self.learner.classes_
|
||||
|
||||
@property
|
||||
|
|
|
@ -19,6 +19,10 @@ class BaseQuantifier(metaclass=ABCMeta):
|
|||
@abstractmethod
|
||||
def get_params(self, deep=True): ...
|
||||
|
||||
@property
@abstractmethod
def classes_(self):
    """Abstract read-only property that every concrete quantifier must override.

    ``abstractmethod`` has to be the innermost decorator: applying it on top
    of an already-built ``property`` object fails, because a property's
    ``__isabstractmethod__`` attribute is derived from its accessor functions
    and cannot be assigned directly.
    """
    ...
|
||||
|
||||
# these methods allow meta-learners to reimplement the decision based on their constituents, and not
|
||||
# based on class structure
|
||||
@property
|
||||
|
|
|
@ -186,6 +186,10 @@ class Ensemble(BaseQuantifier):
|
|||
order = np.argsort(dist)
|
||||
return _select_k(predictions, order, k=self.red_size)
|
||||
|
||||
@property
def classes_(self):
    """Expose the ``classes_`` attribute of the wrapped ``base_quantifier``."""
    wrapped = self.base_quantifier
    return wrapped.classes_
|
||||
|
||||
@property
def binary(self):
    """Expose the ``binary`` attribute of the wrapped ``base_quantifier``."""
    wrapped = self.base_quantifier
    return wrapped.binary
|
||||
|
|
|
@ -58,6 +58,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
self.device = torch.device(device)
|
||||
|
||||
self.__check_params_colision(self.quanet_params, self.learner.get_params())
|
||||
self._classes_ = None
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||
"""
|
||||
|
@ -67,6 +68,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
:param fit_learner: if true, trains the classifier on a split containing 40% of the data
|
||||
:return: self
|
||||
"""
|
||||
self._classes_ = data.classes_
|
||||
classifier_data, unused_data = data.split_stratified(0.4)
|
||||
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
|
||||
|
||||
|
@ -256,6 +258,10 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
import shutil
|
||||
shutil.rmtree(self.checkpointdir, ignore_errors=True)
|
||||
|
||||
@property
def classes_(self):
    """Read-only access to the value stored in ``self._classes_``."""
    stored = self._classes_
    return stored
|
||||
|
||||
|
||||
def mae_loss(output, target):
    """Return the mean of the absolute element-wise differences between ``output`` and ``target``."""
    abs_diff = torch.abs(output - target)
    return abs_diff.mean()
|
||||
|
|
|
@ -2,18 +2,22 @@ from quapy.data import LabelledCollection
|
|||
from .base import BaseQuantifier
|
||||
|
||||
|
||||
|
||||
class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
self._classes_ = None
|
||||
|
||||
def fit(self, data: LabelledCollection, *args):
    """Store the class labels and the prevalence of the training collection.

    :param data: labelled collection; its ``classes_`` attribute and the result of
        its ``prevalence()`` method are stored on the instance
    :param args: accepted and ignored
    :return: self, so that calls can be chained as with the other quantifiers' ``fit``
    """
    self._classes_ = data.classes_
    self.estimated_prevalence = data.prevalence()
    return self
|
||||
|
||||
def quantify(self, documents, *args):
    """Return ``self.estimated_prevalence``; ``documents`` and ``args`` are not used."""
    prevalence = self.estimated_prevalence
    return prevalence
|
||||
|
||||
@property
def classes_(self):
    """Read-only access to the value stored in ``self._classes_``."""
    stored = self._classes_
    return stored
|
||||
|
||||
def get_params(self, deep=True):
    """Return ``None``; no parameters are reported by this method.

    :param deep: accepted for compatibility with the abstract
        ``get_params(self, deep=True)`` signature; it is not used
    :return: None
    """
    return None
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ from copy import deepcopy
|
|||
from typing import Union, Callable
|
||||
|
||||
import quapy as qp
|
||||
import quapy.functional as F
|
||||
from quapy.data.base import LabelledCollection
|
||||
from quapy.evaluation import artificial_sampling_prediction
|
||||
from quapy.method.aggregative import BaseQuantifier
|
||||
|
@ -80,7 +79,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
return training, validation
|
||||
elif isinstance(validation, float):
|
||||
assert 0. < validation < 1., 'validation proportion should be in (0,1)'
|
||||
training, validation = training.split_stratified(train_prop=1-validation)
|
||||
training, validation = training.split_stratified(train_prop=1 - validation)
|
||||
return training, validation
|
||||
else:
|
||||
raise ValueError(f'"validation" must either be a LabelledCollection or a float in (0,1) indicating the'
|
||||
|
@ -97,7 +96,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
||||
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
||||
|
||||
def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float]=None):
|
||||
def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float] = None):
|
||||
"""
|
||||
:param training: the training set on which to optimize the hyperparameters
|
||||
:param val_split: either a LabelledCollection on which to test the performance of the different settings, or
|
||||
|
@ -118,6 +117,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
def handler(signum, frame):
|
||||
self.sout('timeout reached')
|
||||
raise TimeoutError()
|
||||
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
|
||||
self.sout(f'starting optimization with n_jobs={n_jobs}')
|
||||
|
@ -175,6 +175,10 @@ class GridSearchQ(BaseQuantifier):
|
|||
def quantify(self, instances):
    """Delegate quantification of ``instances`` to ``self.best_model_``."""
    model = self.best_model_
    return model.quantify(instances)
|
||||
|
||||
@property
def classes_(self):
    """Expose the ``classes_`` attribute of ``self.best_model_``."""
    model = self.best_model_
    return model.classes_
|
||||
|
||||
def set_params(self, **parameters):
    """Store the received keyword arguments, as a dict, in ``self.param_grid``."""
    grid = parameters
    self.param_grid = grid
|
||||
|
||||
|
@ -185,4 +189,3 @@ class GridSearchQ(BaseQuantifier):
|
|||
if hasattr(self, 'best_model_'):
|
||||
return self.best_model_
|
||||
raise ValueError('best_model called before fit')
|
||||
|
||||
|
|
Loading…
Reference in New Issue