grid search base implementation, MCAE adapted
This commit is contained in:
parent
19b99900e5
commit
eccd818719
|
@ -5,22 +5,21 @@
|
||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
|
|
||||||
|
|
||||||
{
|
{
|
||||||
"name": "main",
|
"name": "main",
|
||||||
"type": "python",
|
"type": "python",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\quacc\\main.py",
|
"program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\quacc\\main.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": true
|
"justMyCode": false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "models",
|
"name": "main_test",
|
||||||
"type": "python",
|
"type": "python",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\baselines\\models.py",
|
"program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\quacc\\main_test.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": true
|
"justMyCode": true
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
11
conf.yaml
11
conf.yaml
|
@ -5,7 +5,6 @@ debug_conf: &debug_conf
|
||||||
DATASET_N_PREVS: 5
|
DATASET_N_PREVS: 5
|
||||||
DATASET_PREVS:
|
DATASET_PREVS:
|
||||||
- 0.5
|
- 0.5
|
||||||
- 0.1
|
|
||||||
|
|
||||||
confs:
|
confs:
|
||||||
- DATASET_NAME: imdb
|
- DATASET_NAME: imdb
|
||||||
|
@ -13,17 +12,9 @@ debug_conf: &debug_conf
|
||||||
plot_confs:
|
plot_confs:
|
||||||
debug:
|
debug:
|
||||||
PLOT_ESTIMATORS:
|
PLOT_ESTIMATORS:
|
||||||
|
- mul_sld_gs
|
||||||
- ref
|
- ref
|
||||||
- atc_mc
|
|
||||||
- atc_ne
|
|
||||||
PLOT_STDEV: true
|
PLOT_STDEV: true
|
||||||
debug_plus:
|
|
||||||
PLOT_ESTIMATORS:
|
|
||||||
- mul_sld_bcts
|
|
||||||
- mul_sld
|
|
||||||
- ref
|
|
||||||
- atc_mc
|
|
||||||
- atc_ne
|
|
||||||
|
|
||||||
test_conf: &test_conf
|
test_conf: &test_conf
|
||||||
global:
|
global:
|
||||||
|
|
21
quacc.log
21
quacc.log
|
@ -1473,3 +1473,24 @@
|
||||||
31/10/23 17:05:50| INFO mul_sld finished [took 29.3523s]
|
31/10/23 17:05:50| INFO mul_sld finished [took 29.3523s]
|
||||||
31/10/23 17:06:00| INFO mul_sld_bcts finished [took 39.8376s]
|
31/10/23 17:06:00| INFO mul_sld_bcts finished [took 39.8376s]
|
||||||
31/10/23 17:06:00| INFO Dataset sample 0.50 of dataset imdb_2prevs finished [took 41.2888s]
|
31/10/23 17:06:00| INFO Dataset sample 0.50 of dataset imdb_2prevs finished [took 41.2888s]
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
31/10/23 20:19:37| INFO dataset imdb_1prevs
|
||||||
|
31/10/23 20:19:48| INFO Dataset sample 0.50 of dataset imdb_1prevs started
|
||||||
|
31/10/23 20:20:07| INFO ref finished [took 17.4125s]
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
31/10/23 20:20:50| INFO dataset imdb_1prevs
|
||||||
|
31/10/23 20:21:01| INFO Dataset sample 0.50 of dataset imdb_1prevs started
|
||||||
|
31/10/23 20:21:19| INFO ref finished [took 17.0717s]
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
31/10/23 20:22:05| INFO dataset imdb_1prevs
|
||||||
|
31/10/23 20:22:15| INFO Dataset sample 0.50 of dataset imdb_1prevs started
|
||||||
|
31/10/23 20:22:35| INFO ref finished [took 18.4752s]
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
31/10/23 20:23:38| INFO dataset imdb_1prevs
|
||||||
|
31/10/23 20:23:48| INFO Dataset sample 0.50 of dataset imdb_1prevs started
|
||||||
|
31/10/23 20:24:08| INFO ref finished [took 18.3216s]
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
01/11/23 13:07:19| INFO dataset imdb_1prevs
|
||||||
|
01/11/23 13:07:27| INFO Dataset sample 0.50 of dataset imdb_1prevs started
|
||||||
|
01/11/23 13:07:27| ERROR Evaluation over imdb_1prevs failed. Exception: 'Invalid estimator: estimator mul_sld_gs does not exist'
|
||||||
|
01/11/23 13:07:27| ERROR Failed while saving configuration imdb_debug of imdb_1prevs. Exception: cannot access local variable 'dr' where it is not associated with a value
|
||||||
|
|
|
@ -71,18 +71,16 @@ class Dataset:
|
||||||
|
|
||||||
return all_train, test
|
return all_train, test
|
||||||
|
|
||||||
def get_raw(self, validation=True) -> DatasetSample:
|
def get_raw(self) -> DatasetSample:
|
||||||
all_train, test = {
|
all_train, test = {
|
||||||
"spambase": self.__spambase,
|
"spambase": self.__spambase,
|
||||||
"imdb": self.__imdb,
|
"imdb": self.__imdb,
|
||||||
"rcv1": self.__rcv1,
|
"rcv1": self.__rcv1,
|
||||||
}[self._name]()
|
}[self._name]()
|
||||||
|
|
||||||
train, val = all_train, None
|
train, val = all_train.split_stratified(
|
||||||
if validation:
|
train_prop=TRAIN_VAL_PROP, random_state=0
|
||||||
train, val = all_train.split_stratified(
|
)
|
||||||
train_prop=TRAIN_VAL_PROP, random_state=0
|
|
||||||
)
|
|
||||||
|
|
||||||
return DatasetSample(train, val, test)
|
return DatasetSample(train, val, test)
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,10 @@
|
||||||
import quapy as qp
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def from_name(err_name):
|
def from_name(err_name):
|
||||||
if err_name == "f1e":
|
assert err_name in ERROR_NAMES, f"unknown error {err_name}"
|
||||||
return f1e
|
callable_error = globals()[err_name]
|
||||||
elif err_name == "f1":
|
return callable_error
|
||||||
return f1
|
|
||||||
else:
|
|
||||||
return qp.error.from_name(err_name)
|
|
||||||
|
|
||||||
|
|
||||||
# def f1(prev):
|
# def f1(prev):
|
||||||
|
@ -36,5 +33,23 @@ def f1e(prev):
|
||||||
return 1 - f1(prev)
|
return 1 - f1(prev)
|
||||||
|
|
||||||
|
|
||||||
def acc(prev):
|
def acc(prev: np.ndarray) -> float:
|
||||||
return (prev[0] + prev[3]) / sum(prev)
|
return (prev[0] + prev[3]) / np.sum(prev)
|
||||||
|
|
||||||
|
|
||||||
|
def accd(true_prevs: np.ndarray, estim_prevs: np.ndarray) -> np.ndarray:
|
||||||
|
vacc = np.vectorize(acc, signature="(m)->()")
|
||||||
|
a_tp = vacc(true_prevs)
|
||||||
|
a_ep = vacc(estim_prevs)
|
||||||
|
return np.abs(a_tp - a_ep)
|
||||||
|
|
||||||
|
|
||||||
|
def maccd(true_prevs: np.ndarray, estim_prevs: np.ndarray) -> float:
|
||||||
|
return accd(true_prevs, estim_prevs).mean()
|
||||||
|
|
||||||
|
|
||||||
|
ACCURACY_ERROR = {maccd}
|
||||||
|
ACCURACY_ERROR_SINGLE = {accd}
|
||||||
|
ACCURACY_ERROR_NAMES = {func.__name__ for func in ACCURACY_ERROR}
|
||||||
|
ACCURACY_ERROR_SINGLE_NAMES = {func.__name__ for func in ACCURACY_ERROR_SINGLE}
|
||||||
|
ERROR_NAMES = ACCURACY_ERROR_NAMES | ACCURACY_ERROR_SINGLE_NAMES
|
||||||
|
|
|
@ -1,19 +1,16 @@
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import Callable, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sklearn.metrics as metrics
|
from quapy.method.aggregative import SLD
|
||||||
from quapy.data import LabelledCollection
|
from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol
|
||||||
from quapy.protocol import AbstractStochasticSeededProtocol
|
from sklearn.linear_model import LogisticRegression
|
||||||
from sklearn.base import BaseEstimator
|
|
||||||
|
|
||||||
import quacc.error as error
|
import quacc as qc
|
||||||
from quacc.evaluation.report import EvaluationReport
|
from quacc.evaluation.report import EvaluationReport
|
||||||
|
from quacc.method.model_selection import GridSearchAE
|
||||||
|
|
||||||
from ..estimator import (
|
from ..method.base import BQAE, MCAE, BaseAccuracyEstimator
|
||||||
AccuracyEstimator,
|
|
||||||
BinaryQuantifierAccuracyEstimator,
|
|
||||||
MulticlassAccuracyEstimator,
|
|
||||||
)
|
|
||||||
|
|
||||||
_methods = {}
|
_methods = {}
|
||||||
|
|
||||||
|
@ -28,108 +25,115 @@ def method(func):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def estimate(
|
def evaluate(
|
||||||
estimator: AccuracyEstimator,
|
estimator: BaseAccuracyEstimator,
|
||||||
protocol: AbstractStochasticSeededProtocol,
|
protocol: AbstractProtocol,
|
||||||
):
|
error_metric: Union[Callable | str],
|
||||||
base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], []
|
) -> float:
|
||||||
for sample in protocol():
|
if isinstance(error_metric, str):
|
||||||
e_sample, pred_proba = estimator.extend(sample)
|
error_metric = qc.error.from_name(error_metric)
|
||||||
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
|
||||||
base_prevs.append(sample.prevalence())
|
|
||||||
true_prevs.append(e_sample.prevalence())
|
|
||||||
estim_prevs.append(estim_prev)
|
|
||||||
pred_probas.append(pred_proba)
|
|
||||||
labels.append(sample.y)
|
|
||||||
|
|
||||||
return base_prevs, true_prevs, estim_prevs, pred_probas, labels
|
collator_bck_ = protocol.collator
|
||||||
|
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
|
||||||
|
|
||||||
|
estim_prevs, true_prevs = [], []
|
||||||
|
for sample in protocol():
|
||||||
|
e_sample = estimator.extend(sample)
|
||||||
|
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
||||||
|
estim_prevs.append(estim_prev)
|
||||||
|
true_prevs.append(e_sample.prevalence())
|
||||||
|
|
||||||
|
protocol.collator = collator_bck_
|
||||||
|
|
||||||
|
true_prevs = np.array(true_prevs)
|
||||||
|
estim_prevs = np.array(estim_prevs)
|
||||||
|
|
||||||
|
return error_metric(true_prevs, estim_prevs)
|
||||||
|
|
||||||
|
|
||||||
def evaluation_report(
|
def evaluation_report(
|
||||||
estimator: AccuracyEstimator,
|
estimator: BaseAccuracyEstimator,
|
||||||
protocol: AbstractStochasticSeededProtocol,
|
protocol: AbstractProtocol,
|
||||||
method: str,
|
method: str,
|
||||||
) -> EvaluationReport:
|
) -> EvaluationReport:
|
||||||
base_prevs, true_prevs, estim_prevs, pred_probas, labels = estimate(
|
|
||||||
estimator, protocol
|
|
||||||
)
|
|
||||||
report = EvaluationReport(name=method)
|
report = EvaluationReport(name=method)
|
||||||
|
for sample in protocol():
|
||||||
for base_prev, true_prev, estim_prev, pred_proba, label in zip(
|
e_sample = estimator.extend(sample)
|
||||||
base_prevs, true_prevs, estim_prevs, pred_probas, labels
|
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
||||||
):
|
acc_score = qc.error.acc(estim_prev)
|
||||||
pred = np.argmax(pred_proba, axis=-1)
|
f1_score = qc.error.f1(estim_prev)
|
||||||
acc_score = error.acc(estim_prev)
|
|
||||||
f1_score = error.f1(estim_prev)
|
|
||||||
report.append_row(
|
report.append_row(
|
||||||
base_prev,
|
sample.prevalence(),
|
||||||
acc_score=acc_score,
|
acc_score=acc_score,
|
||||||
acc=abs(metrics.accuracy_score(label, pred) - acc_score),
|
acc=abs(qc.error.acc(e_sample.prevalence()) - acc_score),
|
||||||
f1_score=f1_score,
|
f1_score=f1_score,
|
||||||
f1=abs(error.f1(true_prev) - f1_score),
|
f1=abs(qc.error.f1(e_sample.prevalence()) - f1_score),
|
||||||
)
|
)
|
||||||
|
|
||||||
report.fit_score = estimator.fit_score
|
|
||||||
|
|
||||||
return report
|
return report
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
|
||||||
c_model: BaseEstimator,
|
|
||||||
validation: LabelledCollection,
|
|
||||||
protocol: AbstractStochasticSeededProtocol,
|
|
||||||
method: str,
|
|
||||||
q_model: str,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
estimator: AccuracyEstimator = {
|
|
||||||
"bin": BinaryQuantifierAccuracyEstimator,
|
|
||||||
"mul": MulticlassAccuracyEstimator,
|
|
||||||
}[method](c_model, q_model=q_model.upper(), **kwargs)
|
|
||||||
estimator.fit(validation)
|
|
||||||
_method = f"{method}_{q_model}"
|
|
||||||
if "recalib" in kwargs:
|
|
||||||
_method += f"_{kwargs['recalib']}"
|
|
||||||
if ("gs", True) in kwargs.items():
|
|
||||||
_method += "_gs"
|
|
||||||
return evaluation_report(estimator, protocol, _method)
|
|
||||||
|
|
||||||
|
|
||||||
@method
|
@method
|
||||||
def bin_sld(c_model, validation, protocol) -> EvaluationReport:
|
def bin_sld(c_model, validation, protocol) -> EvaluationReport:
|
||||||
return evaluate(c_model, validation, protocol, "bin", "sld")
|
est = BQAE(c_model, SLD(LogisticRegression()))
|
||||||
|
est.fit(validation)
|
||||||
|
return evaluation_report(
|
||||||
|
estimator=est,
|
||||||
|
protocol=protocol,
|
||||||
|
method="bin_sld",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@method
|
@method
|
||||||
def mul_sld(c_model, validation, protocol) -> EvaluationReport:
|
def mul_sld(c_model, validation, protocol) -> EvaluationReport:
|
||||||
return evaluate(c_model, validation, protocol, "mul", "sld")
|
est = MCAE(c_model, SLD(LogisticRegression()))
|
||||||
|
est.fit(validation)
|
||||||
|
return evaluation_report(
|
||||||
|
estimator=est,
|
||||||
|
protocor=protocol,
|
||||||
|
method="mul_sld",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@method
|
@method
|
||||||
def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
|
def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
|
||||||
return evaluate(c_model, validation, protocol, "bin", "sld", recalib="bcts")
|
est = BQAE(c_model, SLD(LogisticRegression(), recalib="bcts"))
|
||||||
|
est.fit(validation)
|
||||||
|
return evaluation_report(
|
||||||
|
estimator=est,
|
||||||
|
protocol=protocol,
|
||||||
|
method="bin_sld_bcts",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@method
|
@method
|
||||||
def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
|
def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
|
||||||
return evaluate(c_model, validation, protocol, "mul", "sld", recalib="bcts")
|
est = MCAE(c_model, SLD(LogisticRegression(), recalib="bcts"))
|
||||||
|
est.fit(validation)
|
||||||
|
return evaluation_report(
|
||||||
@method
|
estimator=est,
|
||||||
def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
protocol=protocol,
|
||||||
return evaluate(c_model, validation, protocol, "bin", "sld", gs=True)
|
method="mul_sld_bcts",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@method
|
@method
|
||||||
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
||||||
return evaluate(c_model, validation, protocol, "mul", "sld", gs=True)
|
v_train, v_val = validation.split_stratified(0.6, random_state=0)
|
||||||
|
model = SLD(LogisticRegression())
|
||||||
|
est = GridSearchAE(
|
||||||
@method
|
model=model,
|
||||||
def bin_cc(c_model, validation, protocol) -> EvaluationReport:
|
param_grid={
|
||||||
return evaluate(c_model, validation, protocol, "bin", "cc")
|
"q__classifier__C": np.logspace(-3, 3, 7),
|
||||||
|
"q__classifier__class_weight": [None, "balanced"],
|
||||||
|
"q__recalib": [None, "bcts", "vs"],
|
||||||
@method
|
},
|
||||||
def mul_cc(c_model, validation, protocol) -> EvaluationReport:
|
refit=False,
|
||||||
return evaluate(c_model, validation, protocol, "mul", "cc")
|
protocol=UPP(v_val, repeats=100),
|
||||||
|
verbose=True,
|
||||||
|
).fit(v_train)
|
||||||
|
return evaluation_report(
|
||||||
|
estimator=est,
|
||||||
|
protocol=protocol,
|
||||||
|
method="mul_sld_gs",
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
from copy import deepcopy
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import win11toast
|
||||||
|
from quapy.method.aggregative import SLD
|
||||||
|
from quapy.protocol import APP, UPP
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
|
||||||
|
from quacc.dataset import Dataset
|
||||||
|
from quacc.error import acc
|
||||||
|
from quacc.evaluation.report import CompReport, EvaluationReport
|
||||||
|
from quacc.method.base import MultiClassAccuracyEstimator
|
||||||
|
from quacc.method.model_selection import GridSearchAE
|
||||||
|
|
||||||
|
|
||||||
|
def test_gs():
|
||||||
|
d = Dataset(name="rcv1", target="CCAT", n_prevalences=1).get_raw()
|
||||||
|
|
||||||
|
classifier = LogisticRegression()
|
||||||
|
classifier.fit(*d.train.Xy)
|
||||||
|
|
||||||
|
quantifier = SLD(LogisticRegression())
|
||||||
|
estimator = MultiClassAccuracyEstimator(classifier, quantifier)
|
||||||
|
estimator.fit(d.validation)
|
||||||
|
|
||||||
|
v_train, v_val = d.validation.split_stratified(0.6, random_state=0)
|
||||||
|
gs_protocol = UPP(v_val, sample_size=1000, repeats=100)
|
||||||
|
gs_estimator = GridSearchAE(
|
||||||
|
model=deepcopy(estimator),
|
||||||
|
param_grid={
|
||||||
|
"q__classifier__C": np.logspace(-3, 3, 7),
|
||||||
|
"q__classifier__class_weight": [None, "balanced"],
|
||||||
|
"q__recalib": [None, "bcts", "vs"],
|
||||||
|
},
|
||||||
|
refit=False,
|
||||||
|
protocol=gs_protocol,
|
||||||
|
verbose=True,
|
||||||
|
).fit(v_train)
|
||||||
|
|
||||||
|
tstart = time()
|
||||||
|
erb, ergs = EvaluationReport("base"), EvaluationReport("gs")
|
||||||
|
protocol = APP(
|
||||||
|
d.test,
|
||||||
|
sample_size=1000,
|
||||||
|
n_prevalences=21,
|
||||||
|
repeats=100,
|
||||||
|
return_type="labelled_collection",
|
||||||
|
)
|
||||||
|
for sample in protocol():
|
||||||
|
e_sample = gs_estimator.extend(sample)
|
||||||
|
estim_prev_b = estimator.estimate(e_sample.X, ext=True)
|
||||||
|
estim_prev_gs = gs_estimator.estimate(e_sample.X, ext=True)
|
||||||
|
erb.append_row(
|
||||||
|
sample.prevalence(),
|
||||||
|
acc=abs(acc(e_sample.prevalence()) - acc(estim_prev_b)),
|
||||||
|
)
|
||||||
|
ergs.append_row(
|
||||||
|
sample.prevalence(),
|
||||||
|
acc=abs(acc(e_sample.prevalence()) - acc(estim_prev_gs)),
|
||||||
|
)
|
||||||
|
|
||||||
|
cr = CompReport(
|
||||||
|
[erb, ergs],
|
||||||
|
"test",
|
||||||
|
train_prev=d.train_prev,
|
||||||
|
valid_prev=d.validation_prev,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(cr.table())
|
||||||
|
print(f"[took {time() - tstart:.3f}s]")
|
||||||
|
win11toast.notify("Test", "completed")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_gs()
|
Binary file not shown.
Binary file not shown.
|
@ -4,7 +4,7 @@ from abc import abstractmethod
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.method.aggregative import CC, SLD
|
from quapy.method.aggregative import CC, SLD, BaseQuantifier
|
||||||
from quapy.model_selection import GridSearchQ
|
from quapy.model_selection import GridSearchQ
|
||||||
from quapy.protocol import UPP
|
from quapy.protocol import UPP
|
||||||
from sklearn.base import BaseEstimator
|
from sklearn.base import BaseEstimator
|
||||||
|
@ -14,9 +14,22 @@ from sklearn.model_selection import cross_val_predict
|
||||||
from quacc.data import ExtendedCollection
|
from quacc.data import ExtendedCollection
|
||||||
|
|
||||||
|
|
||||||
class AccuracyEstimator:
|
class BaseAccuracyEstimator(BaseQuantifier):
|
||||||
def __init__(self):
|
def __init__(
|
||||||
|
self,
|
||||||
|
classifier: BaseEstimator,
|
||||||
|
quantifier: BaseQuantifier,
|
||||||
|
):
|
||||||
self.fit_score = None
|
self.fit_score = None
|
||||||
|
self.__check_classifier(classifier)
|
||||||
|
self.classifier = classifier
|
||||||
|
self.quantifier = quantifier
|
||||||
|
|
||||||
|
def __check_classifier(self, classifier):
|
||||||
|
if not hasattr(classifier, "predict_proba"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities."
|
||||||
|
)
|
||||||
|
|
||||||
def _gs_params(self, t_val: LabelledCollection):
|
def _gs_params(self, t_val: LabelledCollection):
|
||||||
return {
|
return {
|
||||||
|
@ -33,85 +46,55 @@ class AccuracyEstimator:
|
||||||
"verbose": True,
|
"verbose": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
||||||
if not pred_proba:
|
if not pred_proba:
|
||||||
pred_proba = self.c_model.predict_proba(base.X)
|
pred_proba = self.classifier.predict_proba(coll.X)
|
||||||
return ExtendedCollection.extend_collection(base, pred_proba), pred_proba
|
return ExtendedCollection.extend_collection(coll, pred_proba)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit(self, train: LabelledCollection | ExtendedCollection):
|
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def estimate(self, instances, ext=False):
|
def estimate(self, instances, ext=False) -> np.ndarray:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
AE = AccuracyEstimator
|
class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
classifier: BaseEstimator,
|
||||||
|
quantifier: BaseQuantifier,
|
||||||
|
):
|
||||||
|
super().__init__(classifier, quantifier)
|
||||||
|
|
||||||
|
def fit(self, train: LabelledCollection):
|
||||||
|
pred_probs = self.classifier.predict_proba(train.X)
|
||||||
|
self.e_train = ExtendedCollection.extend_collection(train, pred_probs)
|
||||||
|
|
||||||
class MulticlassAccuracyEstimator(AccuracyEstimator):
|
self.quantifier.fit(self.e_train)
|
||||||
def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
|
|
||||||
super().__init__()
|
|
||||||
self.c_model = c_model
|
|
||||||
self._q_model_name = q_model.upper()
|
|
||||||
self.e_train = None
|
|
||||||
self.gs = gs
|
|
||||||
self.recalib = recalib
|
|
||||||
|
|
||||||
def fit(self, train: LabelledCollection | ExtendedCollection):
|
return self
|
||||||
# check if model is fit
|
|
||||||
# self.model.fit(*train.Xy)
|
|
||||||
if isinstance(train, LabelledCollection):
|
|
||||||
pred_prob_train = cross_val_predict(
|
|
||||||
self.c_model, *train.Xy, method="predict_proba"
|
|
||||||
)
|
|
||||||
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
|
|
||||||
else:
|
|
||||||
self.e_train = train
|
|
||||||
|
|
||||||
if self._q_model_name == "SLD":
|
def estimate(self, instances, ext=False) -> np.ndarray:
|
||||||
if self.gs:
|
e_inst = instances
|
||||||
t_train, t_val = self.e_train.split_stratified(0.6, random_state=0)
|
|
||||||
gs_params = self._gs_params(t_val)
|
|
||||||
self.q_model = GridSearchQ(
|
|
||||||
SLD(LogisticRegression()),
|
|
||||||
**gs_params,
|
|
||||||
)
|
|
||||||
self.q_model.fit(t_train)
|
|
||||||
self.fit_score = self.q_model.best_score_
|
|
||||||
else:
|
|
||||||
self.q_model = SLD(LogisticRegression(), recalib=self.recalib)
|
|
||||||
self.q_model.fit(self.e_train)
|
|
||||||
elif self._q_model_name == "CC":
|
|
||||||
self.q_model = CC(LogisticRegression())
|
|
||||||
self.q_model.fit(self.e_train)
|
|
||||||
|
|
||||||
def estimate(self, instances, ext=False):
|
|
||||||
if not ext:
|
if not ext:
|
||||||
pred_prob = self.c_model.predict_proba(instances)
|
pred_prob = self.classifier.predict_proba(instances)
|
||||||
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
|
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
|
||||||
else:
|
|
||||||
e_inst = instances
|
|
||||||
|
|
||||||
estim_prev = self.q_model.quantify(e_inst)
|
estim_prev = self.quantifier.quantify(e_inst)
|
||||||
|
return self._check_prevalence_classes(estim_prev)
|
||||||
|
|
||||||
return self._check_prevalence_classes(
|
def _check_prevalence_classes(self, estim_prev) -> np.ndarray:
|
||||||
self.e_train.classes_, self.q_model, estim_prev
|
estim_classes = self.quantifier.classes_
|
||||||
)
|
true_classes = self.e_train.classes_
|
||||||
|
|
||||||
def _check_prevalence_classes(self, true_classes, q_model, estim_prev):
|
|
||||||
if isinstance(q_model, GridSearchQ):
|
|
||||||
estim_classes = q_model.best_model().classes_
|
|
||||||
else:
|
|
||||||
estim_classes = q_model.classes_
|
|
||||||
for _cls in true_classes:
|
for _cls in true_classes:
|
||||||
if _cls not in estim_classes:
|
if _cls not in estim_classes:
|
||||||
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
|
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
|
||||||
return estim_prev
|
return estim_prev
|
||||||
|
|
||||||
|
|
||||||
class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
|
class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
|
||||||
def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
|
def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.c_model = c_model
|
self.c_model = c_model
|
||||||
|
@ -190,3 +173,8 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
|
||||||
return np.asarray(list(map(lambda p: p * norm, q_model.quantify(inst))))
|
return np.asarray(list(map(lambda p: p * norm, q_model.quantify(inst))))
|
||||||
else:
|
else:
|
||||||
return np.asarray([0.0, 0.0])
|
return np.asarray([0.0, 0.0])
|
||||||
|
|
||||||
|
|
||||||
|
BAE = BaseAccuracyEstimator
|
||||||
|
MCAE = MultiClassAccuracyEstimator
|
||||||
|
BQAE = BinaryQuantifierAccuracyEstimator
|
|
@ -0,0 +1,204 @@
|
||||||
|
import itertools
|
||||||
|
from copy import deepcopy
|
||||||
|
from time import time
|
||||||
|
from typing import Callable, Union
|
||||||
|
|
||||||
|
from quapy.data import LabelledCollection
|
||||||
|
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
||||||
|
|
||||||
|
import quacc as qc
|
||||||
|
import quacc.evaluation.method as evaluation
|
||||||
|
from quacc.data import ExtendedCollection
|
||||||
|
from quacc.method.base import BaseAccuracyEstimator
|
||||||
|
|
||||||
|
|
||||||
|
class GridSearchAE(BaseAccuracyEstimator):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: BaseAccuracyEstimator,
|
||||||
|
param_grid: dict,
|
||||||
|
protocol: AbstractProtocol,
|
||||||
|
error: Union[Callable, str] = qc.error.maccd,
|
||||||
|
refit=True,
|
||||||
|
# timeout=-1,
|
||||||
|
# n_jobs=None,
|
||||||
|
verbose=False,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.param_grid = self.__normalize_params(param_grid)
|
||||||
|
self.protocol = protocol
|
||||||
|
self.refit = refit
|
||||||
|
# self.timeout = timeout
|
||||||
|
# self.n_jobs = qp._get_njobs(n_jobs)
|
||||||
|
self.verbose = verbose
|
||||||
|
self.__check_error(error)
|
||||||
|
assert isinstance(protocol, AbstractProtocol), "unknown protocol"
|
||||||
|
|
||||||
|
def _sout(self, msg):
|
||||||
|
if self.verbose:
|
||||||
|
print(f"[{self.__class__.__name__}]: {msg}")
|
||||||
|
|
||||||
|
def __normalize_params(self, params):
|
||||||
|
__remap = {}
|
||||||
|
for key in params.keys():
|
||||||
|
k, delim, sub_key = key.partition("__")
|
||||||
|
if delim and k == "q":
|
||||||
|
__remap[key] = f"quantifier__{sub_key}"
|
||||||
|
|
||||||
|
return {(__remap[k] if k in __remap else k): v for k, v in params.items()}
|
||||||
|
|
||||||
|
def __check_error(self, error):
|
||||||
|
if error in qc.error.ACCURACY_ERROR:
|
||||||
|
self.error = error
|
||||||
|
elif isinstance(error, str):
|
||||||
|
self.error = qc.error.from_name(error)
|
||||||
|
elif hasattr(error, "__call__"):
|
||||||
|
self.error = error
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"unexpected error type; must either be a callable function or a str representing\n"
|
||||||
|
f"the name of an error function in {qc.error.ACCURACY_ERROR_NAMES}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def fit(self, training: LabelledCollection):
|
||||||
|
"""Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
|
||||||
|
the error metric.
|
||||||
|
|
||||||
|
:param training: the training set on which to optimize the hyperparameters
|
||||||
|
:return: self
|
||||||
|
"""
|
||||||
|
params_keys = list(self.param_grid.keys())
|
||||||
|
params_values = list(self.param_grid.values())
|
||||||
|
|
||||||
|
protocol = self.protocol
|
||||||
|
|
||||||
|
self.param_scores_ = {}
|
||||||
|
self.best_score_ = None
|
||||||
|
|
||||||
|
tinit = time()
|
||||||
|
|
||||||
|
hyper = [
|
||||||
|
dict(zip(params_keys, val)) for val in itertools.product(*params_values)
|
||||||
|
]
|
||||||
|
|
||||||
|
# self._sout(f"starting model selection with {self.n_jobs =}")
|
||||||
|
self._sout("starting model selection")
|
||||||
|
|
||||||
|
scores = [self.__params_eval(params, training) for params in hyper]
|
||||||
|
|
||||||
|
for params, score, model in scores:
|
||||||
|
if score is not None:
|
||||||
|
if self.best_score_ is None or score < self.best_score_:
|
||||||
|
self.best_score_ = score
|
||||||
|
self.best_params_ = params
|
||||||
|
self.best_model_ = model
|
||||||
|
self.param_scores_[str(params)] = score
|
||||||
|
else:
|
||||||
|
self.param_scores_[str(params)] = "timeout"
|
||||||
|
|
||||||
|
tend = time() - tinit
|
||||||
|
|
||||||
|
if self.best_score_ is None:
|
||||||
|
raise TimeoutError("no combination of hyperparameters seem to work")
|
||||||
|
|
||||||
|
self._sout(
|
||||||
|
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
|
||||||
|
f"[took {tend:.4f}s]"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.refit:
|
||||||
|
if isinstance(protocol, OnLabelledCollectionProtocol):
|
||||||
|
self._sout("refitting on the whole development set")
|
||||||
|
self.best_model_.fit(training + protocol.get_labelled_collection())
|
||||||
|
else:
|
||||||
|
raise RuntimeWarning(
|
||||||
|
f'"refit" was requested, but the protocol does not '
|
||||||
|
f"implement the {OnLabelledCollectionProtocol.__name__} interface"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __params_eval(self, params, training):
|
||||||
|
protocol = self.protocol
|
||||||
|
error = self.error
|
||||||
|
|
||||||
|
# if self.timeout > 0:
|
||||||
|
|
||||||
|
# def handler(signum, frame):
|
||||||
|
# raise TimeoutError()
|
||||||
|
|
||||||
|
# signal.signal(signal.SIGALRM, handler)
|
||||||
|
|
||||||
|
tinit = time()
|
||||||
|
|
||||||
|
# if self.timeout > 0:
|
||||||
|
# signal.alarm(self.timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = deepcopy(self.model)
|
||||||
|
# overrides default parameters with the parameters being explored at this iteration
|
||||||
|
model.set_params(**params)
|
||||||
|
model.fit(training)
|
||||||
|
score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
|
||||||
|
|
||||||
|
ttime = time() - tinit
|
||||||
|
self._sout(
|
||||||
|
f"hyperparams={params}\t got score {score:.5f} [took {ttime:.4f}s]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# if self.timeout > 0:
|
||||||
|
# signal.alarm(0)
|
||||||
|
# except TimeoutError:
|
||||||
|
# self._sout(f"timeout ({self.timeout}s) reached for config {params}")
|
||||||
|
# score = None
|
||||||
|
except ValueError as e:
|
||||||
|
self._sout(f"the combination of hyperparameters {params} is invalid")
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
self._sout(f"something went wrong for config {params}; skipping:")
|
||||||
|
self._sout(f"\tException: {e}")
|
||||||
|
# traceback(e)
|
||||||
|
score = None
|
||||||
|
|
||||||
|
return params, score, model
|
||||||
|
|
||||||
|
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
||||||
|
assert hasattr(self, "best_model_"), "quantify called before fit"
|
||||||
|
return self.best_model().extend(coll, pred_proba=pred_proba)
|
||||||
|
|
||||||
|
def estimate(self, instances, ext=False):
|
||||||
|
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
|
||||||
|
|
||||||
|
:param instances: sample contanining the instances
|
||||||
|
:return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
|
||||||
|
by the model selection process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert hasattr(self, "best_model_"), "quantify called before fit"
|
||||||
|
return self.best_model().estimate(instances, ext=ext)
|
||||||
|
|
||||||
|
def set_params(self, **parameters):
|
||||||
|
"""Sets the hyper-parameters to explore.
|
||||||
|
|
||||||
|
:param parameters: a dictionary with keys the parameter names and values the list of values to explore
|
||||||
|
"""
|
||||||
|
self.param_grid = parameters
|
||||||
|
|
||||||
|
def get_params(self, deep=True):
|
||||||
|
"""Returns the dictionary of hyper-parameters to explore (`param_grid`)
|
||||||
|
|
||||||
|
:param deep: Unused
|
||||||
|
:return: the dictionary `param_grid`
|
||||||
|
"""
|
||||||
|
return self.param_grid
|
||||||
|
|
||||||
|
def best_model(self):
|
||||||
|
"""
|
||||||
|
Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
|
||||||
|
of hyper-parameters that minimized the error function.
|
||||||
|
|
||||||
|
:return: a trained quantifier
|
||||||
|
"""
|
||||||
|
if hasattr(self, "best_model_"):
|
||||||
|
return self.best_model_
|
||||||
|
raise ValueError("best_model called before fit")
|
|
@ -1,138 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
import scipy as sp
|
|
||||||
import quapy as qp
|
|
||||||
from quapy.data import LabelledCollection
|
|
||||||
from quapy.method.aggregative import SLD
|
|
||||||
from quapy.protocol import APP, AbstractStochasticSeededProtocol
|
|
||||||
from sklearn.linear_model import LogisticRegression
|
|
||||||
from sklearn.model_selection import cross_val_predict
|
|
||||||
|
|
||||||
from .data import get_dataset
|
|
||||||
|
|
||||||
# Extended classes
|
|
||||||
#
|
|
||||||
# 0 ~ True 0
|
|
||||||
# 1 ~ False 1
|
|
||||||
# 2 ~ False 0
|
|
||||||
# 3 ~ True 1
|
|
||||||
# _____________________
|
|
||||||
# | | |
|
|
||||||
# | True 0 | False 1 |
|
|
||||||
# |__________|__________|
|
|
||||||
# | | |
|
|
||||||
# | False 0 | True 1 |
|
|
||||||
# |__________|__________|
|
|
||||||
#
|
|
||||||
def get_ex_class(classes, true_class, pred_class):
|
|
||||||
return true_class * classes + pred_class
|
|
||||||
|
|
||||||
|
|
||||||
def extend_collection(coll, pred_prob):
|
|
||||||
n_classes = coll.n_classes
|
|
||||||
|
|
||||||
# n_X = [ X | predicted probs. ]
|
|
||||||
if isinstance(coll.X, sp.csr_matrix):
|
|
||||||
pred_prob_csr = sp.csr_matrix(pred_prob)
|
|
||||||
n_x = sp.hstack([coll.X, pred_prob_csr])
|
|
||||||
elif isinstance(coll.X, np.ndarray):
|
|
||||||
n_x = np.concatenate((coll.X, pred_prob), axis=1)
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported matrix format")
|
|
||||||
|
|
||||||
# n_y = (exptected y, predicted y)
|
|
||||||
n_y = []
|
|
||||||
for i, true_class in enumerate(coll.y):
|
|
||||||
pred_class = pred_prob[i].argmax(axis=0)
|
|
||||||
n_y.append(get_ex_class(n_classes, true_class, pred_class))
|
|
||||||
|
|
||||||
return LabelledCollection(n_x, np.asarray(n_y), [*range(0, n_classes * n_classes)])
|
|
||||||
|
|
||||||
|
|
||||||
def qf1e_binary(prev):
|
|
||||||
recall = prev[0] / (prev[0] + prev[1])
|
|
||||||
precision = prev[0] / (prev[0] + prev[2])
|
|
||||||
|
|
||||||
return 1 - 2 * (precision * recall) / (precision + recall)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_errors(true_prev, estim_prev, n_instances):
|
|
||||||
errors = {}
|
|
||||||
_eps = 1 / (2 * n_instances)
|
|
||||||
errors = {
|
|
||||||
"mae": qp.error.mae(true_prev, estim_prev),
|
|
||||||
"rae": qp.error.rae(true_prev, estim_prev, eps=_eps),
|
|
||||||
"mrae": qp.error.mrae(true_prev, estim_prev, eps=_eps),
|
|
||||||
"kld": qp.error.kld(true_prev, estim_prev, eps=_eps),
|
|
||||||
"nkld": qp.error.nkld(true_prev, estim_prev, eps=_eps),
|
|
||||||
"true_f1e": qf1e_binary(true_prev),
|
|
||||||
"estim_f1e": qf1e_binary(estim_prev),
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
def extend_and_quantify(
|
|
||||||
model,
|
|
||||||
q_model,
|
|
||||||
train,
|
|
||||||
test: LabelledCollection | AbstractStochasticSeededProtocol,
|
|
||||||
):
|
|
||||||
model.fit(*train.Xy)
|
|
||||||
|
|
||||||
pred_prob_train = cross_val_predict(model, *train.Xy, method="predict_proba")
|
|
||||||
_train = extend_collection(train, pred_prob_train)
|
|
||||||
|
|
||||||
q_model.fit(_train)
|
|
||||||
|
|
||||||
def quantify_extended(test):
|
|
||||||
pred_prob_test = model.predict_proba(test.X)
|
|
||||||
_test = extend_collection(test, pred_prob_test)
|
|
||||||
_estim_prev = q_model.quantify(_test.instances)
|
|
||||||
# check that _estim_prev has all the classes and eventually fill the missing
|
|
||||||
# ones with 0
|
|
||||||
for _cls in _test.classes_:
|
|
||||||
if _cls not in q_model.classes_:
|
|
||||||
_estim_prev = np.insert(_estim_prev, _cls, [0.0], axis=0)
|
|
||||||
print(_estim_prev)
|
|
||||||
return _test.prevalence(), _estim_prev
|
|
||||||
|
|
||||||
if isinstance(test, LabelledCollection):
|
|
||||||
_true_prev, _estim_prev = quantify_extended(test)
|
|
||||||
_errors = compute_errors(_true_prev, _estim_prev, test.X.shape[0])
|
|
||||||
return ([test.prevalence()], [_true_prev], [_estim_prev], [_errors])
|
|
||||||
|
|
||||||
elif isinstance(test, AbstractStochasticSeededProtocol):
|
|
||||||
orig_prevs, true_prevs, estim_prevs, errors = [], [], [], []
|
|
||||||
for index in test.samples_parameters():
|
|
||||||
sample = test.sample(index)
|
|
||||||
_true_prev, _estim_prev = quantify_extended(sample)
|
|
||||||
|
|
||||||
orig_prevs.append(sample.prevalence())
|
|
||||||
true_prevs.append(_true_prev)
|
|
||||||
estim_prevs.append(_estim_prev)
|
|
||||||
errors.append(compute_errors(_true_prev, _estim_prev, sample.X.shape[0]))
|
|
||||||
|
|
||||||
return orig_prevs, true_prevs, estim_prevs, errors
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_1(dataset_name):
|
|
||||||
train, test = get_dataset(dataset_name)
|
|
||||||
|
|
||||||
orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify(
|
|
||||||
LogisticRegression(),
|
|
||||||
SLD(LogisticRegression()),
|
|
||||||
train,
|
|
||||||
APP(test, n_prevalences=11, repeats=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
for orig_prev, true_prev, estim_prev, _errors in zip(
|
|
||||||
orig_prevs, true_prevs, estim_prevs, errors
|
|
||||||
):
|
|
||||||
print(f"original prevalence:\t{orig_prev}")
|
|
||||||
print(f"true prevalence:\t{true_prev}")
|
|
||||||
print(f"estimated prevalence:\t{estim_prev}")
|
|
||||||
for name, err in _errors.items():
|
|
||||||
print(f"{name}={err:.3f}")
|
|
||||||
print()
|
|
|
@ -1,20 +0,0 @@
|
||||||
|
|
||||||
from sklearn.linear_model import LogisticRegression
|
|
||||||
from quacc.evaluation.baseline import kfcv, trust_score
|
|
||||||
from quacc.dataset import get_spambase
|
|
||||||
|
|
||||||
|
|
||||||
class TestBaseline:
|
|
||||||
|
|
||||||
def test_kfcv(self):
|
|
||||||
train, validation, _ = get_spambase()
|
|
||||||
c_model = LogisticRegression()
|
|
||||||
c_model.fit(train.X, train.y)
|
|
||||||
assert "f1_score" in kfcv(c_model, validation)
|
|
||||||
|
|
||||||
def test_trust_score(self):
|
|
||||||
train, validation, test = get_spambase()
|
|
||||||
c_model = LogisticRegression()
|
|
||||||
c_model.fit(train.X, train.y)
|
|
||||||
trustscore = trust_score(c_model, train, test)
|
|
||||||
assert len(trustscore) == len(test.y)
|
|
Binary file not shown.
|
@ -0,0 +1,12 @@
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
|
||||||
|
from quacc.dataset import Dataset
|
||||||
|
from quacc.evaluation.baseline import kfcv
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseline:
|
||||||
|
def test_kfcv(self):
|
||||||
|
spambase = Dataset("spambase", n_prevalences=1).get_raw()
|
||||||
|
c_model = LogisticRegression()
|
||||||
|
c_model.fit(spambase.train.X, spambase.train.y)
|
||||||
|
assert "f1_score" in kfcv(c_model, spambase.validation)
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,12 +1,12 @@
|
||||||
import pytest
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
import scipy.sparse as sp
|
import scipy.sparse as sp
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
|
||||||
from quacc.estimator import BinaryQuantifierAccuracyEstimator
|
from quacc.method.base import BinaryQuantifierAccuracyEstimator
|
||||||
|
|
||||||
|
|
||||||
class TestBinaryQuantifierAccuracyEstimator:
|
class TestBQAE:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"instances,preds0,preds1,result",
|
"instances,preds0,preds1,result",
|
||||||
[
|
[
|
|
@ -0,0 +1,2 @@
|
||||||
|
class TestMCAE:
|
||||||
|
pass
|
Loading…
Reference in New Issue