gsq method fixed

This commit is contained in:
Lorenzo Volpi 2023-11-06 21:27:49 +01:00
parent ba09d7efbf
commit a19e444592
2 changed files with 19 additions and 5 deletions

View File

@ -7,12 +7,13 @@ from quapy.method.aggregative import SLD
from quapy.protocol import APP, UPP from quapy.protocol import APP, UPP
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
import quacc as qc
from quacc.dataset import Dataset from quacc.dataset import Dataset
from quacc.error import acc from quacc.error import acc
from quacc.evaluation.baseline import ref from quacc.evaluation.baseline import ref
from quacc.evaluation.method import mulmc_sld from quacc.evaluation.method import mulmc_sld
from quacc.evaluation.report import CompReport, EvaluationReport from quacc.evaluation.report import CompReport, EvaluationReport
from quacc.method.base import BinaryQuantifierAccuracyEstimator from quacc.method.base import MCAE, BinaryQuantifierAccuracyEstimator
from quacc.method.model_selection import GridSearchAE from quacc.method.model_selection import GridSearchAE
@ -101,5 +102,19 @@ def test_mc():
f.write(cr.data().to_markdown()) f.write(cr.data().to_markdown())
def test_et():
d = Dataset(name="imdb", prevs=[0.5]).get()[0]
classifier = LogisticRegression().fit(*d.train.Xy)
estimator = MCAE(
classifier,
SLD(LogisticRegression(), exact_train_prev=False),
confidence="max_conf",
).fit(d.validation)
e_test = estimator.extend(d.test)
ep = estimator.estimate(e_test.X, ext=True)
print(f"{qc.error.acc(ep) = }")
print(f"{qc.error.acc(e_test.prevalence()) = }")
if __name__ == "__main__": if __name__ == "__main__":
test_mc() test_et()

View File

@ -107,10 +107,9 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
e_inst = instances if ext else self._extend_instances(instances) e_inst = instances if ext else self._extend_instances(instances)
estim_prev = self.quantifier.quantify(e_inst) estim_prev = self.quantifier.quantify(e_inst)
return self._check_prevalence_classes(estim_prev) return self._check_prevalence_classes(estim_prev, self.quantifier.classes_)
def _check_prevalence_classes(self, estim_prev) -> np.ndarray: def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray:
estim_classes = self.quantifier.classes_
true_classes = self.e_train.classes_ true_classes = self.e_train.classes_
for _cls in true_classes: for _cls in true_classes:
if _cls not in estim_classes: if _cls not in estim_classes: