forked from moreo/QuaPy
Added classes_ property to all quantifiers.
This commit is contained in:
parent
70a3d4bd0f
commit
bfbfe08116
2
TODO.txt
2
TODO.txt
|
@ -17,14 +17,12 @@ Current issues:
|
||||||
In binary quantification (hp, kindle, imdb) we used F1 in the minority class (which in kindle and hp happens to be the
|
In binary quantification (hp, kindle, imdb) we used F1 in the minority class (which in kindle and hp happens to be the
|
||||||
negative class). This is not covered in this new implementation, in which the binary case is not treated as such, but as
|
negative class). This is not covered in this new implementation, in which the binary case is not treated as such, but as
|
||||||
an instance of single-label with 2 labels. Check
|
an instance of single-label with 2 labels. Check
|
||||||
Add classnames to LabelledCollection? This should improve visualization of reports
|
|
||||||
Add automatic reindex of class labels in LabelledCollection (currently, class indexes should be ordered and with no gaps)
|
Add automatic reindex of class labels in LabelledCollection (currently, class indexes should be ordered and with no gaps)
|
||||||
OVR I believe is currently tied to aggregative methods. We should provide a general interface also for general quantifiers
|
OVR I believe is currently tied to aggregative methods. We should provide a general interface also for general quantifiers
|
||||||
Currently, being "binary" only adds one checker; we should figure out how to impose the check to be automatically performed
|
Currently, being "binary" only adds one checker; we should figure out how to impose the check to be automatically performed
|
||||||
|
|
||||||
Improvements:
|
Improvements:
|
||||||
==========================================
|
==========================================
|
||||||
Clarify whether QuaNet is an aggregative method or not.
|
|
||||||
Explore the hyperparameter "number of bins" in HDy
|
Explore the hyperparameter "number of bins" in HDy
|
||||||
Rename EMQ to SLD ?
|
Rename EMQ to SLD ?
|
||||||
Parallelize the kFCV in ACC and PACC?
|
Parallelize the kFCV in ACC and PACC?
|
||||||
|
|
|
@ -53,10 +53,10 @@ class AggregativeQuantifier(BaseQuantifier):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_classes(self):
|
def n_classes(self):
|
||||||
return len(self.classes)
|
return len(self.classes_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def classes(self):
|
def classes_(self):
|
||||||
return self.learner.classes_
|
return self.learner.classes_
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -19,6 +19,10 @@ class BaseQuantifier(metaclass=ABCMeta):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_params(self, deep=True): ...
|
def get_params(self, deep=True): ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
@property
|
||||||
|
def classes_(self): ...
|
||||||
|
|
||||||
# these methods allows meta-learners to reimplement the decision based on their constituents, and not
|
# these methods allows meta-learners to reimplement the decision based on their constituents, and not
|
||||||
# based on class structure
|
# based on class structure
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -186,6 +186,10 @@ class Ensemble(BaseQuantifier):
|
||||||
order = np.argsort(dist)
|
order = np.argsort(dist)
|
||||||
return _select_k(predictions, order, k=self.red_size)
|
return _select_k(predictions, order, k=self.red_size)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classes_(self):
|
||||||
|
return self.base_quantifier.classes_
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def binary(self):
|
def binary(self):
|
||||||
return self.base_quantifier.binary
|
return self.base_quantifier.binary
|
||||||
|
|
|
@ -58,6 +58,7 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
|
|
||||||
self.__check_params_colision(self.quanet_params, self.learner.get_params())
|
self.__check_params_colision(self.quanet_params, self.learner.get_params())
|
||||||
|
self._classes_ = None
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, fit_learner=True):
|
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||||
"""
|
"""
|
||||||
|
@ -67,6 +68,7 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
:param fit_learner: if true, trains the classifier on a split containing 40% of the data
|
:param fit_learner: if true, trains the classifier on a split containing 40% of the data
|
||||||
:return: self
|
:return: self
|
||||||
"""
|
"""
|
||||||
|
self._classes_ = data.classes_
|
||||||
classifier_data, unused_data = data.split_stratified(0.4)
|
classifier_data, unused_data = data.split_stratified(0.4)
|
||||||
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
|
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
|
||||||
|
|
||||||
|
@ -256,6 +258,10 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
import shutil
|
import shutil
|
||||||
shutil.rmtree(self.checkpointdir, ignore_errors=True)
|
shutil.rmtree(self.checkpointdir, ignore_errors=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classes_(self):
|
||||||
|
return self._classes_
|
||||||
|
|
||||||
|
|
||||||
def mae_loss(output, target):
|
def mae_loss(output, target):
|
||||||
return torch.mean(torch.abs(output - target))
|
return torch.mean(torch.abs(output - target))
|
||||||
|
|
|
@ -2,18 +2,22 @@ from quapy.data import LabelledCollection
|
||||||
from .base import BaseQuantifier
|
from .base import BaseQuantifier
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
|
class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
pass
|
self._classes_ = None
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, *args):
|
def fit(self, data: LabelledCollection, *args):
|
||||||
|
self._classes_ = data.classes_
|
||||||
self.estimated_prevalence = data.prevalence()
|
self.estimated_prevalence = data.prevalence()
|
||||||
|
|
||||||
def quantify(self, documents, *args):
|
def quantify(self, documents, *args):
|
||||||
return self.estimated_prevalence
|
return self.estimated_prevalence
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classes_(self):
|
||||||
|
return self._classes_
|
||||||
|
|
||||||
def get_params(self):
|
def get_params(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,6 @@ from copy import deepcopy
|
||||||
from typing import Union, Callable
|
from typing import Union, Callable
|
||||||
|
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
import quapy.functional as F
|
|
||||||
from quapy.data.base import LabelledCollection
|
from quapy.data.base import LabelledCollection
|
||||||
from quapy.evaluation import artificial_sampling_prediction
|
from quapy.evaluation import artificial_sampling_prediction
|
||||||
from quapy.method.aggregative import BaseQuantifier
|
from quapy.method.aggregative import BaseQuantifier
|
||||||
|
@ -80,7 +79,7 @@ class GridSearchQ(BaseQuantifier):
|
||||||
return training, validation
|
return training, validation
|
||||||
elif isinstance(validation, float):
|
elif isinstance(validation, float):
|
||||||
assert 0. < validation < 1., 'validation proportion should be in (0,1)'
|
assert 0. < validation < 1., 'validation proportion should be in (0,1)'
|
||||||
training, validation = training.split_stratified(train_prop=1-validation)
|
training, validation = training.split_stratified(train_prop=1 - validation)
|
||||||
return training, validation
|
return training, validation
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'"validation" must either be a LabelledCollection or a float in (0,1) indicating the'
|
raise ValueError(f'"validation" must either be a LabelledCollection or a float in (0,1) indicating the'
|
||||||
|
@ -97,7 +96,7 @@ class GridSearchQ(BaseQuantifier):
|
||||||
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
||||||
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
||||||
|
|
||||||
def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float]=None):
|
def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float] = None):
|
||||||
"""
|
"""
|
||||||
:param training: the training set on which to optimize the hyperparameters
|
:param training: the training set on which to optimize the hyperparameters
|
||||||
:param val_split: either a LabelledCollection on which to test the performance of the different settings, or
|
:param val_split: either a LabelledCollection on which to test the performance of the different settings, or
|
||||||
|
@ -118,6 +117,7 @@ class GridSearchQ(BaseQuantifier):
|
||||||
def handler(signum, frame):
|
def handler(signum, frame):
|
||||||
self.sout('timeout reached')
|
self.sout('timeout reached')
|
||||||
raise TimeoutError()
|
raise TimeoutError()
|
||||||
|
|
||||||
signal.signal(signal.SIGALRM, handler)
|
signal.signal(signal.SIGALRM, handler)
|
||||||
|
|
||||||
self.sout(f'starting optimization with n_jobs={n_jobs}')
|
self.sout(f'starting optimization with n_jobs={n_jobs}')
|
||||||
|
@ -175,6 +175,10 @@ class GridSearchQ(BaseQuantifier):
|
||||||
def quantify(self, instances):
|
def quantify(self, instances):
|
||||||
return self.best_model_.quantify(instances)
|
return self.best_model_.quantify(instances)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classes_(self):
|
||||||
|
return self.best_model_.classes_
|
||||||
|
|
||||||
def set_params(self, **parameters):
|
def set_params(self, **parameters):
|
||||||
self.param_grid = parameters
|
self.param_grid = parameters
|
||||||
|
|
||||||
|
@ -185,4 +189,3 @@ class GridSearchQ(BaseQuantifier):
|
||||||
if hasattr(self, 'best_model_'):
|
if hasattr(self, 'best_model_'):
|
||||||
return self.best_model_
|
return self.best_model_
|
||||||
raise ValueError('best_model called before fit')
|
raise ValueError('best_model called before fit')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue