gsq method fixed
This commit is contained in:
parent
ba09d7efbf
commit
a19e444592
|
|
@ -7,12 +7,13 @@ from quapy.method.aggregative import SLD
|
||||||
from quapy.protocol import APP, UPP
|
from quapy.protocol import APP, UPP
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
|
||||||
|
import quacc as qc
|
||||||
from quacc.dataset import Dataset
|
from quacc.dataset import Dataset
|
||||||
from quacc.error import acc
|
from quacc.error import acc
|
||||||
from quacc.evaluation.baseline import ref
|
from quacc.evaluation.baseline import ref
|
||||||
from quacc.evaluation.method import mulmc_sld
|
from quacc.evaluation.method import mulmc_sld
|
||||||
from quacc.evaluation.report import CompReport, EvaluationReport
|
from quacc.evaluation.report import CompReport, EvaluationReport
|
||||||
from quacc.method.base import BinaryQuantifierAccuracyEstimator
|
from quacc.method.base import MCAE, BinaryQuantifierAccuracyEstimator
|
||||||
from quacc.method.model_selection import GridSearchAE
|
from quacc.method.model_selection import GridSearchAE
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -101,5 +102,19 @@ def test_mc():
|
||||||
f.write(cr.data().to_markdown())
|
f.write(cr.data().to_markdown())
|
||||||
|
|
||||||
|
|
||||||
|
def test_et():
|
||||||
|
d = Dataset(name="imdb", prevs=[0.5]).get()[0]
|
||||||
|
classifier = LogisticRegression().fit(*d.train.Xy)
|
||||||
|
estimator = MCAE(
|
||||||
|
classifier,
|
||||||
|
SLD(LogisticRegression(), exact_train_prev=False),
|
||||||
|
confidence="max_conf",
|
||||||
|
).fit(d.validation)
|
||||||
|
e_test = estimator.extend(d.test)
|
||||||
|
ep = estimator.estimate(e_test.X, ext=True)
|
||||||
|
print(f"{qc.error.acc(ep) = }")
|
||||||
|
print(f"{qc.error.acc(e_test.prevalence()) = }")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_mc()
|
test_et()
|
||||||
|
|
|
||||||
|
|
@ -107,10 +107,9 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
|
||||||
e_inst = instances if ext else self._extend_instances(instances)
|
e_inst = instances if ext else self._extend_instances(instances)
|
||||||
|
|
||||||
estim_prev = self.quantifier.quantify(e_inst)
|
estim_prev = self.quantifier.quantify(e_inst)
|
||||||
return self._check_prevalence_classes(estim_prev)
|
return self._check_prevalence_classes(estim_prev, self.quantifier.classes_)
|
||||||
|
|
||||||
def _check_prevalence_classes(self, estim_prev) -> np.ndarray:
|
def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray:
|
||||||
estim_classes = self.quantifier.classes_
|
|
||||||
true_classes = self.e_train.classes_
|
true_classes = self.e_train.classes_
|
||||||
for _cls in true_classes:
|
for _cls in true_classes:
|
||||||
if _cls not in estim_classes:
|
if _cls not in estim_classes:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue