forked from moreo/QuaPy
hierarchical class problem?
This commit is contained in:
parent
44bfc7921f
commit
c9c4511c0d
|
@ -7,6 +7,8 @@ from sklearn.base import BaseEstimator
|
||||||
from sklearn.calibration import CalibratedClassifierCV
|
from sklearn.calibration import CalibratedClassifierCV
|
||||||
from sklearn.metrics import confusion_matrix
|
from sklearn.metrics import confusion_matrix
|
||||||
from sklearn.model_selection import cross_val_predict
|
from sklearn.model_selection import cross_val_predict
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
import quapy.functional as F
|
import quapy.functional as F
|
||||||
from functional import get_divergence
|
from functional import get_divergence
|
||||||
|
@ -19,7 +21,7 @@ from quapy.method.base import BaseQuantifier, BinaryQuantifier, OneVsAllGeneric
|
||||||
# Abstract classes
|
# Abstract classes
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
|
|
||||||
class AggregativeQuantifier(ABC, BaseQuantifier):
|
class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
"""
|
"""
|
||||||
Abstract class for quantification methods that base their estimations on the aggregation of classification
|
Abstract class for quantification methods that base their estimations on the aggregation of classification
|
||||||
results. Aggregative quantifiers implement a pipeline that consists of generating classification predictions
|
results. Aggregative quantifiers implement a pipeline that consists of generating classification predictions
|
||||||
|
@ -65,7 +67,8 @@ class AggregativeQuantifier(ABC, BaseQuantifier):
|
||||||
"""
|
"""
|
||||||
assert isinstance(fit_classifier, bool), 'unexpected type for "fit_classifier", must be boolean'
|
assert isinstance(fit_classifier, bool), 'unexpected type for "fit_classifier", must be boolean'
|
||||||
|
|
||||||
self.__check_classifier(adapt_if_necessary=(self.__classifier_method=='predict_proba'))
|
print(type(self))
|
||||||
|
self.__check_classifier(adapt_if_necessary=(self.__classifier_method()=='predict_proba'))
|
||||||
|
|
||||||
if predict_on is None:
|
if predict_on is None:
|
||||||
if fit_classifier:
|
if fit_classifier:
|
||||||
|
@ -149,12 +152,12 @@ class AggregativeQuantifier(ABC, BaseQuantifier):
|
||||||
"""
|
"""
|
||||||
return self.classifier.predict(instances)
|
return self.classifier.predict(instances)
|
||||||
|
|
||||||
@property
|
|
||||||
def __classifier_method(self):
|
def __classifier_method(self):
|
||||||
|
print('using predict')
|
||||||
return 'predict'
|
return 'predict'
|
||||||
|
|
||||||
def __check_classifier(self, adapt_if_necessary=False):
|
def __check_classifier(self, adapt_if_necessary=False):
|
||||||
assert hasattr(self.classifier, 'predict')
|
assert hasattr(self.classifier, self.__classifier_method())
|
||||||
|
|
||||||
def quantify(self, instances):
|
def quantify(self, instances):
|
||||||
"""
|
"""
|
||||||
|
@ -199,12 +202,12 @@ class AggregativeProbabilisticQuantifier(AggregativeQuantifier, ABC):
|
||||||
def classify(self, instances):
|
def classify(self, instances):
|
||||||
return self.classifier.predict_proba(instances)
|
return self.classifier.predict_proba(instances)
|
||||||
|
|
||||||
@property
|
|
||||||
def __classifier_method(self):
|
def __classifier_method(self):
|
||||||
|
print('using predict_proba')
|
||||||
return 'predict_proba'
|
return 'predict_proba'
|
||||||
|
|
||||||
def __check_classifier(self, adapt_if_necessary=False):
|
def __check_classifier(self, adapt_if_necessary=False):
|
||||||
if not hasattr(self.classifier, 'predict_proba'):
|
if not hasattr(self.classifier, self.__check_classifier()):
|
||||||
if adapt_if_necessary:
|
if adapt_if_necessary:
|
||||||
print(f'warning: The learner {self.classifier.__class__.__name__} does not seem to be '
|
print(f'warning: The learner {self.classifier.__class__.__name__} does not seem to be '
|
||||||
f'probabilistic. The learner will be calibrated (using CalibratedClassifierCV).')
|
f'probabilistic. The learner will be calibrated (using CalibratedClassifierCV).')
|
||||||
|
|
Loading…
Reference in New Issue