1
0
Fork 0

hierarchical class problem?

This commit is contained in:
Alejandro Moreo Fernandez 2023-11-13 12:42:57 +01:00
parent 44bfc7921f
commit c9c4511c0d
1 changed files with 9 additions and 6 deletions

View File

@ -7,6 +7,8 @@ from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import confusion_matrix from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict from sklearn.model_selection import cross_val_predict
from typing_extensions import override
import quapy as qp import quapy as qp
import quapy.functional as F import quapy.functional as F
from functional import get_divergence from functional import get_divergence
@ -19,7 +21,7 @@ from quapy.method.base import BaseQuantifier, BinaryQuantifier, OneVsAllGeneric
# Abstract classes # Abstract classes
# ------------------------------------ # ------------------------------------
class AggregativeQuantifier(ABC, BaseQuantifier): class AggregativeQuantifier(BaseQuantifier, ABC):
""" """
Abstract class for quantification methods that base their estimations on the aggregation of classification Abstract class for quantification methods that base their estimations on the aggregation of classification
results. Aggregative quantifiers implement a pipeline that consists of generating classification predictions results. Aggregative quantifiers implement a pipeline that consists of generating classification predictions
@ -65,7 +67,8 @@ class AggregativeQuantifier(ABC, BaseQuantifier):
""" """
assert isinstance(fit_classifier, bool), 'unexpected type for "fit_classifier", must be boolean' assert isinstance(fit_classifier, bool), 'unexpected type for "fit_classifier", must be boolean'
self.__check_classifier(adapt_if_necessary=(self.__classifier_method=='predict_proba')) print(type(self))
self.__check_classifier(adapt_if_necessary=(self.__classifier_method()=='predict_proba'))
if predict_on is None: if predict_on is None:
if fit_classifier: if fit_classifier:
@ -149,12 +152,12 @@ class AggregativeQuantifier(ABC, BaseQuantifier):
""" """
return self.classifier.predict(instances) return self.classifier.predict(instances)
@property
def __classifier_method(self): def __classifier_method(self):
print('using predict')
return 'predict' return 'predict'
def __check_classifier(self, adapt_if_necessary=False): def __check_classifier(self, adapt_if_necessary=False):
assert hasattr(self.classifier, 'predict') assert hasattr(self.classifier, self.__classifier_method())
def quantify(self, instances): def quantify(self, instances):
""" """
@ -199,12 +202,12 @@ class AggregativeProbabilisticQuantifier(AggregativeQuantifier, ABC):
def classify(self, instances): def classify(self, instances):
return self.classifier.predict_proba(instances) return self.classifier.predict_proba(instances)
@property
def __classifier_method(self): def __classifier_method(self):
print('using predict_proba')
return 'predict_proba' return 'predict_proba'
def __check_classifier(self, adapt_if_necessary=False): def __check_classifier(self, adapt_if_necessary=False):
if not hasattr(self.classifier, 'predict_proba'): if not hasattr(self.classifier, self.__check_classifier()):
if adapt_if_necessary: if adapt_if_necessary:
print(f'warning: The learner {self.classifier.__class__.__name__} does not seem to be ' print(f'warning: The learner {self.classifier.__class__.__name__} does not seem to be '
f'probabilistic. The learner will be calibrated (using CalibratedClassifierCV).') f'probabilistic. The learner will be calibrated (using CalibratedClassifierCV).')