From 4bc3e0871147e284a5d6ebaa492d7e7ff258ef8c Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Mon, 6 Nov 2023 21:28:04 +0100 Subject: [PATCH] gsq method fixed --- quacc/method/model_selection.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/quacc/method/model_selection.py b/quacc/method/model_selection.py index 2db3f67..ebc2fb8 100644 --- a/quacc/method/model_selection.py +++ b/quacc/method/model_selection.py @@ -2,6 +2,7 @@ import itertools from copy import deepcopy from time import time from typing import Callable, Union +import numpy as np import quapy as qp from quapy.data import LabelledCollection @@ -189,7 +190,7 @@ class GridSearchAE(BaseAccuracyEstimator): by the model selection process. """ - assert hasattr(self, "best_model_"), "quantify called before fit" + assert hasattr(self, "best_model_"), "estimate called before fit" return self.best_model().estimate(instances, ext=ext) def set_params(self, **parameters): @@ -219,6 +220,7 @@ class GridSearchAE(BaseAccuracyEstimator): raise ValueError("best_model called before fit") + class MCAEgsq(MultiClassAccuracyEstimator): def __init__( self, @@ -255,6 +257,11 @@ class MCAEgsq(MultiClassAccuracyEstimator): return self + def estimate(self, instances, ext=False) -> np.ndarray: + e_inst = instances if ext else self._extend_instances(instances) + estim_prev = self.quantifier.quantify(e_inst) + return self._check_prevalence_classes(estim_prev, self.quantifier.best_model().classes_) + class BQAEgsq(BinaryQuantifierAccuracyEstimator): def __init__(