From a19e444592b8f164c173132ccce173cc25a85eb8 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Mon, 6 Nov 2023 21:27:49 +0100 Subject: [PATCH] gsq method fixed --- quacc/main_test.py | 19 +++++++++++++++++-- quacc/method/base.py | 5 ++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/quacc/main_test.py b/quacc/main_test.py index e80a264..6e47891 100644 --- a/quacc/main_test.py +++ b/quacc/main_test.py @@ -7,12 +7,13 @@ from quapy.method.aggregative import SLD from quapy.protocol import APP, UPP from sklearn.linear_model import LogisticRegression +import quacc as qc from quacc.dataset import Dataset from quacc.error import acc from quacc.evaluation.baseline import ref from quacc.evaluation.method import mulmc_sld from quacc.evaluation.report import CompReport, EvaluationReport -from quacc.method.base import BinaryQuantifierAccuracyEstimator +from quacc.method.base import MCAE, BinaryQuantifierAccuracyEstimator from quacc.method.model_selection import GridSearchAE @@ -101,5 +102,19 @@ def test_mc(): f.write(cr.data().to_markdown()) +def test_et(): + d = Dataset(name="imdb", prevs=[0.5]).get()[0] + classifier = LogisticRegression().fit(*d.train.Xy) + estimator = MCAE( + classifier, + SLD(LogisticRegression(), exact_train_prev=False), + confidence="max_conf", + ).fit(d.validation) + e_test = estimator.extend(d.test) + ep = estimator.estimate(e_test.X, ext=True) + print(f"{qc.error.acc(ep) = }") + print(f"{qc.error.acc(e_test.prevalence()) = }") + + if __name__ == "__main__": - test_mc() + test_et() diff --git a/quacc/method/base.py b/quacc/method/base.py index a7389f4..670abb7 100644 --- a/quacc/method/base.py +++ b/quacc/method/base.py @@ -107,10 +107,9 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator): e_inst = instances if ext else self._extend_instances(instances) estim_prev = self.quantifier.quantify(e_inst) - return self._check_prevalence_classes(estim_prev) + return self._check_prevalence_classes(estim_prev, self.quantifier.classes_) - def _check_prevalence_classes(self, estim_prev) -> np.ndarray: - estim_classes = self.quantifier.classes_ + def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray: true_classes = self.e_train.classes_ for _cls in true_classes: if _cls not in estim_classes: