# QuAcc/quacc/main_test.py
# (repository viewer metadata: 106 lines, 3.0 KiB, Python)

from copy import deepcopy
from time import time
import numpy as np
import win11toast
from quapy.method.aggregative import SLD
from quapy.protocol import APP, UPP
from sklearn.linear_model import LogisticRegression
from quacc.dataset import Dataset
from quacc.error import acc
2023-11-05 14:16:53 +01:00
from quacc.evaluation.baseline import ref
from quacc.evaluation.method import mulmc_sld
from quacc.evaluation.report import CompReport, EvaluationReport
2023-11-03 23:28:40 +01:00
from quacc.method.base import BinaryQuantifierAccuracyEstimator
from quacc.method.model_selection import GridSearchAE
def test_gs():
    """Compare a plain accuracy estimator with its grid-searched counterpart.

    Loads the raw ``rcv1``/``CCAT`` dataset, fits a ``LogisticRegression``
    classifier on the training split and wraps it, together with an ``SLD``
    quantifier, in a ``BinaryQuantifierAccuracyEstimator``. A deep copy of
    that (still unfitted) estimator is tuned with ``GridSearchAE`` on a
    stratified split of the validation set, while the original estimator is
    fitted on the whole validation set.

    Both estimators are then run on every sample drawn by an ``APP`` protocol
    over the test set; for each sample the absolute difference between the
    ``acc`` of the extended sample's prevalence and the ``acc`` of the
    estimated prevalence is appended to the matching report. Finally the
    comparison table and the elapsed evaluation time are printed and a
    ``win11toast`` notification is sent.
    """
    data = Dataset(name="rcv1", target="CCAT", n_prevalences=1).get_raw()

    classifier = LogisticRegression()
    classifier.fit(*data.train.Xy)

    quantifier = SLD(LogisticRegression())
    # The binary estimator is the one under test here; the alternative that
    # was left disabled is MultiClassAccuracyEstimator(classifier, quantifier).
    estimator = BinaryQuantifierAccuracyEstimator(classifier, quantifier)

    # Model selection runs on a stratified split of the validation data.
    v_train, v_val = data.validation.split_stratified(0.6, random_state=0)
    gs_protocol = UPP(v_val, sample_size=1000, repeats=100)
    param_grid = {
        "q__classifier__C": np.logspace(-3, 3, 7),
        "q__classifier__class_weight": [None, "balanced"],
        "q__recalib": [None, "bcts", "ts"],
    }
    gs_estimator = GridSearchAE(
        model=deepcopy(estimator),
        param_grid=param_grid,
        refit=False,
        protocol=gs_protocol,
        verbose=True,
    ).fit(v_train)

    # The baseline estimator is fitted only after its copy was handed to the
    # grid search, and it uses the full validation set.
    estimator.fit(data.validation)

    # Timing covers the evaluation phase only, not the fitting above.
    tstart = time()
    erb, ergs = EvaluationReport("base"), EvaluationReport("gs")

    protocol = APP(
        data.test,
        sample_size=1000,
        n_prevalences=21,
        repeats=100,
        return_type="labelled_collection",
    )
    for sample in protocol():
        e_sample = gs_estimator.extend(sample)
        estim_prev_b = estimator.estimate(e_sample.X, ext=True)
        estim_prev_gs = gs_estimator.estimate(e_sample.X, ext=True)

        # Same bookkeeping for both reports: baseline first, grid search second.
        for report, estim_prev in ((erb, estim_prev_b), (ergs, estim_prev_gs)):
            report.append_row(
                sample.prevalence(),
                acc=abs(acc(e_sample.prevalence()) - acc(estim_prev)),
            )

    cr = CompReport(
        [erb, ergs],
        "test",
        train_prev=data.train_prev,
        valid_prev=data.validation_prev,
    )

    print(cr.table())
    print(f"[took {time() - tstart:.3f}s]")
    win11toast.notify("Test", "completed")
2023-11-05 14:16:53 +01:00
def test_mc():
    """Compare the ``mulmc_sld`` method with the ``ref`` baseline and save the result.

    Loads the first ``rcv1``/``CCAT`` dataset generated for ``prevs=[0.9]``,
    fits a ``LogisticRegression`` classifier on its training split, and
    evaluates both ``ref`` and ``mulmc_sld`` with the same ``APP`` protocol
    over the test set. The two resulting reports are combined in a
    ``CompReport`` whose data is written, rendered as Markdown, to
    ``test_mc.md`` in the current working directory.
    """
    d = Dataset(name="rcv1", target="CCAT", prevs=[0.9]).get()[0]
    classifier = LogisticRegression().fit(*d.train.Xy)

    protocol = APP(
        d.test,
        sample_size=1000,
        repeats=100,
        n_prevalences=21,
        return_type="labelled_collection",
    )

    ref_er = ref(classifier, d.validation, protocol)
    mulmc_er = mulmc_sld(classifier, d.validation, protocol)

    cr = CompReport(
        [mulmc_er, ref_er],
        name="test_mc",
        train_prev=d.train_prev,
        valid_prev=d.validation_prev,
    )

    # Pin the encoding: without it open() falls back to the locale's default
    # (e.g. cp1252 on Windows), making the output file platform-dependent and
    # liable to fail on any non-ASCII character in the table.
    with open("test_mc.md", "w", encoding="utf-8") as f:
        f.write(cr.data().to_markdown())
if __name__ == "__main__":
    # Running the module as a script executes the test_mc comparison only.
    test_mc()