gsq method fixed

This commit is contained in:
Lorenzo Volpi 2023-11-06 21:28:04 +01:00
parent a19e444592
commit 4bc3e08711
1 changed files with 8 additions and 1 deletions

View File

@ -2,6 +2,7 @@ import itertools
from copy import deepcopy
from time import time
from typing import Callable, Union
import numpy as np
import quapy as qp
from quapy.data import LabelledCollection
@ -189,7 +190,7 @@ class GridSearchAE(BaseAccuracyEstimator):
by the model selection process.
"""
assert hasattr(self, "best_model_"), "quantify called before fit"
assert hasattr(self, "best_model_"), "estimate called before fit"
return self.best_model().estimate(instances, ext=ext)
def set_params(self, **parameters):
@ -219,6 +220,7 @@ class GridSearchAE(BaseAccuracyEstimator):
raise ValueError("best_model called before fit")
class MCAEgsq(MultiClassAccuracyEstimator):
def __init__(
self,
@ -255,6 +257,11 @@ class MCAEgsq(MultiClassAccuracyEstimator):
return self
def estimate(self, instances, ext=False) -> np.ndarray:
e_inst = instances if ext else self._extend_instances(instances)
estim_prev = self.quantifier.quantify(e_inst)
return self._check_prevalence_classes(estim_prev, self.quantifier.best_model().classes_)
class BQAEgsq(BinaryQuantifierAccuracyEstimator):
def __init__(