gsq method fixed

This commit is contained in:
Lorenzo Volpi 2023-11-06 21:28:04 +01:00
parent a19e444592
commit 4bc3e08711
1 changed files with 8 additions and 1 deletions

View File

@ -2,6 +2,7 @@ import itertools
from copy import deepcopy from copy import deepcopy
from time import time from time import time
from typing import Callable, Union from typing import Callable, Union
import numpy as np
import quapy as qp import quapy as qp
from quapy.data import LabelledCollection from quapy.data import LabelledCollection
@ -189,7 +190,7 @@ class GridSearchAE(BaseAccuracyEstimator):
by the model selection process. by the model selection process.
""" """
assert hasattr(self, "best_model_"), "quantify called before fit" assert hasattr(self, "best_model_"), "estimate called before fit"
return self.best_model().estimate(instances, ext=ext) return self.best_model().estimate(instances, ext=ext)
def set_params(self, **parameters): def set_params(self, **parameters):
@ -219,6 +220,7 @@ class GridSearchAE(BaseAccuracyEstimator):
raise ValueError("best_model called before fit") raise ValueError("best_model called before fit")
class MCAEgsq(MultiClassAccuracyEstimator): class MCAEgsq(MultiClassAccuracyEstimator):
def __init__( def __init__(
self, self,
@ -255,6 +257,11 @@ class MCAEgsq(MultiClassAccuracyEstimator):
return self return self
def estimate(self, instances, ext=False) -> np.ndarray:
e_inst = instances if ext else self._extend_instances(instances)
estim_prev = self.quantifier.quantify(e_inst)
return self._check_prevalence_classes(estim_prev, self.quantifier.best_model().classes_)
class BQAEgsq(BinaryQuantifierAccuracyEstimator): class BQAEgsq(BinaryQuantifierAccuracyEstimator):
def __init__( def __init__(