gsq method fixed

This commit is contained in:
Lorenzo Volpi 2023-11-06 21:27:49 +01:00
parent ba09d7efbf
commit a19e444592
2 changed files with 19 additions and 5 deletions

View File

@ -7,12 +7,13 @@ from quapy.method.aggregative import SLD
from quapy.protocol import APP, UPP
from sklearn.linear_model import LogisticRegression
import quacc as qc
from quacc.dataset import Dataset
from quacc.error import acc
from quacc.evaluation.baseline import ref
from quacc.evaluation.method import mulmc_sld
from quacc.evaluation.report import CompReport, EvaluationReport
from quacc.method.base import BinaryQuantifierAccuracyEstimator
from quacc.method.base import MCAE, BinaryQuantifierAccuracyEstimator
from quacc.method.model_selection import GridSearchAE
@ -101,5 +102,19 @@ def test_mc():
f.write(cr.data().to_markdown())
def test_et():
d = Dataset(name="imdb", prevs=[0.5]).get()[0]
classifier = LogisticRegression().fit(*d.train.Xy)
estimator = MCAE(
classifier,
SLD(LogisticRegression(), exact_train_prev=False),
confidence="max_conf",
).fit(d.validation)
e_test = estimator.extend(d.test)
ep = estimator.estimate(e_test.X, ext=True)
print(f"{qc.error.acc(ep) = }")
print(f"{qc.error.acc(e_test.prevalence()) = }")
if __name__ == "__main__":
test_mc()
test_et()

View File

@ -107,10 +107,9 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
e_inst = instances if ext else self._extend_instances(instances)
estim_prev = self.quantifier.quantify(e_inst)
return self._check_prevalence_classes(estim_prev)
return self._check_prevalence_classes(estim_prev, self.quantifier.classes_)
def _check_prevalence_classes(self, estim_prev) -> np.ndarray:
estim_classes = self.quantifier.classes_
def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray:
true_classes = self.e_train.classes_
for _cls in true_classes:
if _cls not in estim_classes: