gsq method fixed
This commit is contained in:
parent
ba09d7efbf
commit
a19e444592
|
@ -7,12 +7,13 @@ from quapy.method.aggregative import SLD
|
|||
from quapy.protocol import APP, UPP
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
import quacc as qc
|
||||
from quacc.dataset import Dataset
|
||||
from quacc.error import acc
|
||||
from quacc.evaluation.baseline import ref
|
||||
from quacc.evaluation.method import mulmc_sld
|
||||
from quacc.evaluation.report import CompReport, EvaluationReport
|
||||
from quacc.method.base import BinaryQuantifierAccuracyEstimator
|
||||
from quacc.method.base import MCAE, BinaryQuantifierAccuracyEstimator
|
||||
from quacc.method.model_selection import GridSearchAE
|
||||
|
||||
|
||||
|
@ -101,5 +102,19 @@ def test_mc():
|
|||
f.write(cr.data().to_markdown())
|
||||
|
||||
|
||||
def test_et():
|
||||
d = Dataset(name="imdb", prevs=[0.5]).get()[0]
|
||||
classifier = LogisticRegression().fit(*d.train.Xy)
|
||||
estimator = MCAE(
|
||||
classifier,
|
||||
SLD(LogisticRegression(), exact_train_prev=False),
|
||||
confidence="max_conf",
|
||||
).fit(d.validation)
|
||||
e_test = estimator.extend(d.test)
|
||||
ep = estimator.estimate(e_test.X, ext=True)
|
||||
print(f"{qc.error.acc(ep) = }")
|
||||
print(f"{qc.error.acc(e_test.prevalence()) = }")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mc()
|
||||
test_et()
|
||||
|
|
|
@ -107,10 +107,9 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
|
|||
e_inst = instances if ext else self._extend_instances(instances)
|
||||
|
||||
estim_prev = self.quantifier.quantify(e_inst)
|
||||
return self._check_prevalence_classes(estim_prev)
|
||||
return self._check_prevalence_classes(estim_prev, self.quantifier.classes_)
|
||||
|
||||
def _check_prevalence_classes(self, estim_prev) -> np.ndarray:
|
||||
estim_classes = self.quantifier.classes_
|
||||
def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray:
|
||||
true_classes = self.e_train.classes_
|
||||
for _cls in true_classes:
|
||||
if _cls not in estim_classes:
|
||||
|
|
Loading…
Reference in New Issue