gsq method fixed
This commit is contained in:
parent
a19e444592
commit
4bc3e08711
|
@ -2,6 +2,7 @@ import itertools
|
|||
from copy import deepcopy
|
||||
from time import time
|
||||
from typing import Callable, Union
|
||||
import numpy as np
|
||||
|
||||
import quapy as qp
|
||||
from quapy.data import LabelledCollection
|
||||
|
@ -189,7 +190,7 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
by the model selection process.
|
||||
"""
|
||||
|
||||
assert hasattr(self, "best_model_"), "quantify called before fit"
|
||||
assert hasattr(self, "best_model_"), "estimate called before fit"
|
||||
return self.best_model().estimate(instances, ext=ext)
|
||||
|
||||
def set_params(self, **parameters):
|
||||
|
@ -219,6 +220,7 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
raise ValueError("best_model called before fit")
|
||||
|
||||
|
||||
|
||||
class MCAEgsq(MultiClassAccuracyEstimator):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -255,6 +257,11 @@ class MCAEgsq(MultiClassAccuracyEstimator):
|
|||
|
||||
return self
|
||||
|
||||
def estimate(self, instances, ext=False) -> np.ndarray:
|
||||
e_inst = instances if ext else self._extend_instances(instances)
|
||||
estim_prev = self.quantifier.quantify(e_inst)
|
||||
return self._check_prevalence_classes(estim_prev, self.quantifier.best_model().classes_)
|
||||
|
||||
|
||||
class BQAEgsq(BinaryQuantifierAccuracyEstimator):
|
||||
def __init__(
|
||||
|
|
Loading…
Reference in New Issue