bin adapted to grid search
This commit is contained in:
parent
eccd818719
commit
d1be2b72e8
|
@ -1,14 +1,5 @@
|
||||||
{
|
{
|
||||||
"todo": [
|
"todo": [
|
||||||
{
|
|
||||||
"assignedTo": {
|
|
||||||
"name": "Lorenzo Volpi"
|
|
||||||
},
|
|
||||||
"creation_time": "2023-10-28T14:34:46.226Z",
|
|
||||||
"id": "4",
|
|
||||||
"references": [],
|
|
||||||
"title": "Aggingere estimator basati su PACC (quantificatore)"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"assignedTo": {
|
"assignedTo": {
|
||||||
"name": "Lorenzo Volpi"
|
"name": "Lorenzo Volpi"
|
||||||
|
@ -18,15 +9,6 @@
|
||||||
"references": [],
|
"references": [],
|
||||||
"title": "Creare plot avg con training prevalence sull'asse x e media rispetto a test prevalence"
|
"title": "Creare plot avg con training prevalence sull'asse x e media rispetto a test prevalence"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"assignedTo": {
|
|
||||||
"name": "Lorenzo Volpi"
|
|
||||||
},
|
|
||||||
"creation_time": "2023-10-28T14:34:23.217Z",
|
|
||||||
"id": "3",
|
|
||||||
"references": [],
|
|
||||||
"title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"assignedTo": {
|
"assignedTo": {
|
||||||
"name": "Lorenzo Volpi"
|
"name": "Lorenzo Volpi"
|
||||||
|
@ -38,6 +20,27 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"in-progress": [
|
"in-progress": [
|
||||||
|
{
|
||||||
|
"assignedTo": {
|
||||||
|
"name": "Lorenzo Volpi"
|
||||||
|
},
|
||||||
|
"creation_time": "2023-10-28T14:34:23.217Z",
|
||||||
|
"id": "3",
|
||||||
|
"references": [],
|
||||||
|
"title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"assignedTo": {
|
||||||
|
"name": "Lorenzo Volpi"
|
||||||
|
},
|
||||||
|
"creation_time": "2023-10-28T14:34:46.226Z",
|
||||||
|
"id": "4",
|
||||||
|
"references": [],
|
||||||
|
"title": "Aggingere estimator basati su PACC (quantificatore)"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"testing": [],
|
||||||
|
"done": [
|
||||||
{
|
{
|
||||||
"assignedTo": {
|
"assignedTo": {
|
||||||
"name": "Lorenzo Volpi"
|
"name": "Lorenzo Volpi"
|
||||||
|
@ -47,7 +50,5 @@
|
||||||
"references": [],
|
"references": [],
|
||||||
"title": "Rework rappresentazione dati di report"
|
"title": "Rework rappresentazione dati di report"
|
||||||
}
|
}
|
||||||
],
|
]
|
||||||
"testing": [],
|
|
||||||
"done": []
|
|
||||||
}
|
}
|
29
conf.yaml
29
conf.yaml
|
@ -12,8 +12,9 @@ debug_conf: &debug_conf
|
||||||
plot_confs:
|
plot_confs:
|
||||||
debug:
|
debug:
|
||||||
PLOT_ESTIMATORS:
|
PLOT_ESTIMATORS:
|
||||||
- mul_sld_gs
|
- mul_sld
|
||||||
- ref
|
- ref
|
||||||
|
- atc_mc
|
||||||
PLOT_STDEV: true
|
PLOT_STDEV: true
|
||||||
|
|
||||||
test_conf: &test_conf
|
test_conf: &test_conf
|
||||||
|
@ -21,24 +22,28 @@ test_conf: &test_conf
|
||||||
METRICS:
|
METRICS:
|
||||||
- acc
|
- acc
|
||||||
- f1
|
- f1
|
||||||
DATASET_N_PREVS: 2
|
DATASET_N_PREVS: 9
|
||||||
DATASET_PREVS:
|
|
||||||
- 0.5
|
|
||||||
- 0.1
|
|
||||||
|
|
||||||
confs:
|
confs:
|
||||||
# - DATASET_NAME: rcv1
|
- DATASET_NAME: rcv1
|
||||||
# DATASET_TARGET: CCAT
|
DATASET_TARGET: CCAT
|
||||||
- DATASET_NAME: imdb
|
# - DATASET_NAME: imdb
|
||||||
|
|
||||||
plot_confs:
|
plot_confs:
|
||||||
best_vs_atc:
|
2gs_vs_atc:
|
||||||
|
PLOT_ESTIMATORS:
|
||||||
|
- bin_sld_gs
|
||||||
|
- bin_sld_qgs
|
||||||
|
- mul_sld_gs
|
||||||
|
- mul_sld_qgs
|
||||||
|
- ref
|
||||||
|
- atc_mc
|
||||||
|
- atc_ne
|
||||||
|
sld_vs_pacc:
|
||||||
PLOT_ESTIMATORS:
|
PLOT_ESTIMATORS:
|
||||||
- bin_sld
|
- bin_sld
|
||||||
- bin_sld_bcts
|
|
||||||
- bin_sld_gs
|
- bin_sld_gs
|
||||||
- mul_sld
|
- mul_sld
|
||||||
- mul_sld_bcts
|
|
||||||
- mul_sld_gs
|
- mul_sld_gs
|
||||||
- ref
|
- ref
|
||||||
- atc_mc
|
- atc_mc
|
||||||
|
@ -102,4 +107,4 @@ main_conf: &main_conf
|
||||||
- atc_ne
|
- atc_ne
|
||||||
- doc_feat
|
- doc_feat
|
||||||
|
|
||||||
exec: *debug_conf
|
exec: *test_conf
|
94
quacc.log
94
quacc.log
|
@ -1494,3 +1494,97 @@
|
||||||
01/11/23 13:07:27| INFO Dataset sample 0.50 of dataset imdb_1prevs started
|
01/11/23 13:07:27| INFO Dataset sample 0.50 of dataset imdb_1prevs started
|
||||||
01/11/23 13:07:27| ERROR Evaluation over imdb_1prevs failed. Exception: 'Invalid estimator: estimator mul_sld_gs does not exist'
|
01/11/23 13:07:27| ERROR Evaluation over imdb_1prevs failed. Exception: 'Invalid estimator: estimator mul_sld_gs does not exist'
|
||||||
01/11/23 13:07:27| ERROR Failed while saving configuration imdb_debug of imdb_1prevs. Exception: cannot access local variable 'dr' where it is not associated with a value
|
01/11/23 13:07:27| ERROR Failed while saving configuration imdb_debug of imdb_1prevs. Exception: cannot access local variable 'dr' where it is not associated with a value
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
03/11/23 20:54:19| INFO dataset rcv1_CCAT_9prevs
|
||||||
|
03/11/23 20:54:24| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 20:54:28| WARNING Method mul_sld_gs failed. Exception: Invalid parameter 'quantifier' for estimator EMQ(classifier=LogisticRegression()). Valid parameters are: ['classifier', 'exact_train_prev', 'recalib'].
|
||||||
|
03/11/23 20:54:29| WARNING Method mul_sld failed. Exception: evaluation_report() got an unexpected keyword argument 'protocor'
|
||||||
|
03/11/23 20:54:30| WARNING Method bin_sld_gs failed. Exception: Invalid parameter 'quantifier' for estimator EMQ(classifier=LogisticRegression()). Valid parameters are: ['classifier', 'exact_train_prev', 'recalib'].
|
||||||
|
03/11/23 20:55:09| INFO ref finished [took 38.5179s]
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
03/11/23 21:28:36| INFO dataset rcv1_CCAT_9prevs
|
||||||
|
03/11/23 21:28:41| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 21:28:45| WARNING Method mul_sld failed. Exception: evaluation_report() got an unexpected keyword argument 'protocor'
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
03/11/23 21:31:03| INFO dataset rcv1_CCAT_9prevs
|
||||||
|
03/11/23 21:31:08| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 21:31:59| INFO ref finished [took 45.6616s]
|
||||||
|
03/11/23 21:32:03| INFO atc_mc finished [took 48.4360s]
|
||||||
|
03/11/23 21:32:07| INFO atc_ne finished [took 51.0515s]
|
||||||
|
03/11/23 21:32:23| INFO mul_sld finished [took 72.9229s]
|
||||||
|
03/11/23 21:34:43| INFO bin_sld finished [took 213.9538s]
|
||||||
|
03/11/23 21:36:27| INFO mul_sld_gs finished [took 314.9357s]
|
||||||
|
03/11/23 21:40:50| INFO bin_sld_gs finished [took 579.2530s]
|
||||||
|
03/11/23 21:40:50| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 582.5876s]
|
||||||
|
03/11/23 21:40:50| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 21:41:39| INFO ref finished [took 43.7409s]
|
||||||
|
03/11/23 21:41:43| INFO atc_mc finished [took 46.4580s]
|
||||||
|
03/11/23 21:41:44| INFO atc_ne finished [took 46.4267s]
|
||||||
|
03/11/23 21:41:54| INFO mul_sld finished [took 61.3005s]
|
||||||
|
03/11/23 21:44:18| INFO bin_sld finished [took 206.3680s]
|
||||||
|
03/11/23 21:45:59| INFO mul_sld_gs finished [took 304.4726s]
|
||||||
|
03/11/23 21:50:33| INFO bin_sld_gs finished [took 579.3455s]
|
||||||
|
03/11/23 21:50:33| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 582.4808s]
|
||||||
|
03/11/23 21:50:33| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 21:51:22| INFO ref finished [took 43.6853s]
|
||||||
|
03/11/23 21:51:26| INFO atc_mc finished [took 47.1366s]
|
||||||
|
03/11/23 21:51:30| INFO atc_ne finished [took 49.4868s]
|
||||||
|
03/11/23 21:51:34| INFO mul_sld finished [took 59.0964s]
|
||||||
|
03/11/23 21:53:59| INFO bin_sld finished [took 205.0248s]
|
||||||
|
03/11/23 21:55:50| INFO mul_sld_gs finished [took 312.5630s]
|
||||||
|
03/11/23 22:00:27| INFO bin_sld_gs finished [took 591.1460s]
|
||||||
|
03/11/23 22:00:27| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 594.3163s]
|
||||||
|
03/11/23 22:00:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 22:01:15| INFO ref finished [took 43.3806s]
|
||||||
|
03/11/23 22:01:19| INFO atc_mc finished [took 46.6674s]
|
||||||
|
03/11/23 22:01:21| INFO atc_ne finished [took 47.1220s]
|
||||||
|
03/11/23 22:01:28| INFO mul_sld finished [took 58.6799s]
|
||||||
|
03/11/23 22:03:53| INFO bin_sld finished [took 204.7659s]
|
||||||
|
03/11/23 22:05:39| INFO mul_sld_gs finished [took 307.8811s]
|
||||||
|
03/11/23 22:10:32| INFO bin_sld_gs finished [took 601.9995s]
|
||||||
|
03/11/23 22:10:32| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 604.8406s]
|
||||||
|
03/11/23 22:10:32| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 22:11:20| INFO ref finished [took 42.8256s]
|
||||||
|
03/11/23 22:11:25| INFO atc_mc finished [took 46.9203s]
|
||||||
|
03/11/23 22:11:28| INFO atc_ne finished [took 49.3042s]
|
||||||
|
03/11/23 22:11:34| INFO mul_sld finished [took 60.2744s]
|
||||||
|
03/11/23 22:13:59| INFO bin_sld finished [took 205.7078s]
|
||||||
|
03/11/23 22:15:45| INFO mul_sld_gs finished [took 309.0888s]
|
||||||
|
03/11/23 22:20:32| INFO bin_sld_gs finished [took 596.5102s]
|
||||||
|
03/11/23 22:20:32| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 599.5067s]
|
||||||
|
03/11/23 22:20:32| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 22:21:20| INFO ref finished [took 43.1698s]
|
||||||
|
03/11/23 22:21:24| INFO atc_mc finished [took 46.5768s]
|
||||||
|
03/11/23 22:21:25| INFO atc_ne finished [took 46.3408s]
|
||||||
|
03/11/23 22:21:34| INFO mul_sld finished [took 60.8070s]
|
||||||
|
03/11/23 22:23:58| INFO bin_sld finished [took 205.3362s]
|
||||||
|
03/11/23 22:25:44| INFO mul_sld_gs finished [took 308.1859s]
|
||||||
|
03/11/23 22:30:44| INFO bin_sld_gs finished [took 609.5468s]
|
||||||
|
03/11/23 22:30:44| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 612.5803s]
|
||||||
|
03/11/23 22:30:44| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 22:31:32| INFO ref finished [took 43.2949s]
|
||||||
|
03/11/23 22:31:37| INFO atc_mc finished [took 46.3686s]
|
||||||
|
03/11/23 22:31:40| INFO atc_ne finished [took 49.2242s]
|
||||||
|
03/11/23 22:31:47| INFO mul_sld finished [took 60.9437s]
|
||||||
|
03/11/23 22:34:11| INFO bin_sld finished [took 205.9299s]
|
||||||
|
03/11/23 22:35:56| INFO mul_sld_gs finished [took 308.2738s]
|
||||||
|
03/11/23 22:40:36| INFO bin_sld_gs finished [took 588.7918s]
|
||||||
|
03/11/23 22:40:36| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 591.8830s]
|
||||||
|
03/11/23 22:40:36| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 22:41:24| INFO ref finished [took 43.3321s]
|
||||||
|
03/11/23 22:41:29| INFO atc_mc finished [took 46.8041s]
|
||||||
|
03/11/23 22:41:29| INFO atc_ne finished [took 46.5810s]
|
||||||
|
03/11/23 22:41:38| INFO mul_sld finished [took 60.2962s]
|
||||||
|
03/11/23 22:44:07| INFO bin_sld finished [took 209.6435s]
|
||||||
|
03/11/23 22:45:44| INFO mul_sld_gs finished [took 304.4809s]
|
||||||
|
03/11/23 22:50:39| INFO bin_sld_gs finished [took 599.5588s]
|
||||||
|
03/11/23 22:50:39| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 602.5720s]
|
||||||
|
03/11/23 22:50:39| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
|
||||||
|
03/11/23 22:51:26| INFO ref finished [took 42.4313s]
|
||||||
|
03/11/23 22:51:30| INFO atc_mc finished [took 45.5261s]
|
||||||
|
03/11/23 22:51:34| INFO atc_ne finished [took 48.4488s]
|
||||||
|
03/11/23 22:51:47| INFO mul_sld finished [took 66.4801s]
|
||||||
|
03/11/23 22:54:08| INFO bin_sld finished [took 208.4272s]
|
||||||
|
03/11/23 22:55:49| INFO mul_sld_gs finished [took 306.4505s]
|
||||||
|
03/11/23 23:00:15| INFO bin_sld_gs finished [took 573.7761s]
|
||||||
|
03/11/23 23:00:15| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 576.7586s]
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
from typing import Callable, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
||||||
|
|
||||||
|
import quacc as qc
|
||||||
|
|
||||||
|
from ..method.base import BaseAccuracyEstimator
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
estimator: BaseAccuracyEstimator,
|
||||||
|
protocol: AbstractProtocol,
|
||||||
|
error_metric: Union[Callable | str],
|
||||||
|
) -> float:
|
||||||
|
if isinstance(error_metric, str):
|
||||||
|
error_metric = qc.error.from_name(error_metric)
|
||||||
|
|
||||||
|
collator_bck_ = protocol.collator
|
||||||
|
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
|
||||||
|
|
||||||
|
estim_prevs, true_prevs = [], []
|
||||||
|
for sample in protocol():
|
||||||
|
e_sample = estimator.extend(sample)
|
||||||
|
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
||||||
|
estim_prevs.append(estim_prev)
|
||||||
|
true_prevs.append(e_sample.prevalence())
|
||||||
|
|
||||||
|
protocol.collator = collator_bck_
|
||||||
|
|
||||||
|
true_prevs = np.array(true_prevs)
|
||||||
|
estim_prevs = np.array(estim_prevs)
|
||||||
|
|
||||||
|
return error_metric(true_prevs, estim_prevs)
|
|
@ -1,9 +1,9 @@
|
||||||
|
import inspect
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from quapy.method.aggregative import SLD
|
from quapy.method.aggregative import SLD
|
||||||
from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol
|
from quapy.protocol import UPP, AbstractProtocol
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
|
||||||
import quacc as qc
|
import quacc as qc
|
||||||
|
@ -25,38 +25,12 @@ def method(func):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
|
||||||
estimator: BaseAccuracyEstimator,
|
|
||||||
protocol: AbstractProtocol,
|
|
||||||
error_metric: Union[Callable | str],
|
|
||||||
) -> float:
|
|
||||||
if isinstance(error_metric, str):
|
|
||||||
error_metric = qc.error.from_name(error_metric)
|
|
||||||
|
|
||||||
collator_bck_ = protocol.collator
|
|
||||||
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
|
|
||||||
|
|
||||||
estim_prevs, true_prevs = [], []
|
|
||||||
for sample in protocol():
|
|
||||||
e_sample = estimator.extend(sample)
|
|
||||||
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
|
||||||
estim_prevs.append(estim_prev)
|
|
||||||
true_prevs.append(e_sample.prevalence())
|
|
||||||
|
|
||||||
protocol.collator = collator_bck_
|
|
||||||
|
|
||||||
true_prevs = np.array(true_prevs)
|
|
||||||
estim_prevs = np.array(estim_prevs)
|
|
||||||
|
|
||||||
return error_metric(true_prevs, estim_prevs)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluation_report(
|
def evaluation_report(
|
||||||
estimator: BaseAccuracyEstimator,
|
estimator: BaseAccuracyEstimator,
|
||||||
protocol: AbstractProtocol,
|
protocol: AbstractProtocol,
|
||||||
method: str,
|
|
||||||
) -> EvaluationReport:
|
) -> EvaluationReport:
|
||||||
report = EvaluationReport(name=method)
|
method_name = inspect.stack()[1].function
|
||||||
|
report = EvaluationReport(name=method_name)
|
||||||
for sample in protocol():
|
for sample in protocol():
|
||||||
e_sample = estimator.extend(sample)
|
e_sample = estimator.extend(sample)
|
||||||
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
||||||
|
@ -80,7 +54,6 @@ def bin_sld(c_model, validation, protocol) -> EvaluationReport:
|
||||||
return evaluation_report(
|
return evaluation_report(
|
||||||
estimator=est,
|
estimator=est,
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
method="bin_sld",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,8 +63,7 @@ def mul_sld(c_model, validation, protocol) -> EvaluationReport:
|
||||||
est.fit(validation)
|
est.fit(validation)
|
||||||
return evaluation_report(
|
return evaluation_report(
|
||||||
estimator=est,
|
estimator=est,
|
||||||
protocor=protocol,
|
protocol=protocol,
|
||||||
method="mul_sld",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,7 +74,6 @@ def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
|
||||||
return evaluation_report(
|
return evaluation_report(
|
||||||
estimator=est,
|
estimator=est,
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
method="bin_sld_bcts",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -113,14 +84,13 @@ def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
|
||||||
return evaluation_report(
|
return evaluation_report(
|
||||||
estimator=est,
|
estimator=est,
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
method="mul_sld_bcts",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@method
|
@method
|
||||||
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
||||||
v_train, v_val = validation.split_stratified(0.6, random_state=0)
|
v_train, v_val = validation.split_stratified(0.6, random_state=0)
|
||||||
model = SLD(LogisticRegression())
|
model = BQAE(c_model, SLD(LogisticRegression()))
|
||||||
est = GridSearchAE(
|
est = GridSearchAE(
|
||||||
model=model,
|
model=model,
|
||||||
param_grid={
|
param_grid={
|
||||||
|
@ -130,10 +100,30 @@ def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
||||||
},
|
},
|
||||||
refit=False,
|
refit=False,
|
||||||
protocol=UPP(v_val, repeats=100),
|
protocol=UPP(v_val, repeats=100),
|
||||||
verbose=True,
|
verbose=False,
|
||||||
|
).fit(v_train)
|
||||||
|
return evaluation_report(
|
||||||
|
estimator=est,
|
||||||
|
protocol=protocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@method
|
||||||
|
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
||||||
|
v_train, v_val = validation.split_stratified(0.6, random_state=0)
|
||||||
|
model = MCAE(c_model, SLD(LogisticRegression()))
|
||||||
|
est = GridSearchAE(
|
||||||
|
model=model,
|
||||||
|
param_grid={
|
||||||
|
"q__classifier__C": np.logspace(-3, 3, 7),
|
||||||
|
"q__classifier__class_weight": [None, "balanced"],
|
||||||
|
"q__recalib": [None, "bcts", "vs"],
|
||||||
|
},
|
||||||
|
refit=False,
|
||||||
|
protocol=UPP(v_val, repeats=100),
|
||||||
|
verbose=False,
|
||||||
).fit(v_train)
|
).fit(v_train)
|
||||||
return evaluation_report(
|
return evaluation_report(
|
||||||
estimator=est,
|
estimator=est,
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
method="mul_sld_gs",
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,7 +10,7 @@ from sklearn.linear_model import LogisticRegression
|
||||||
from quacc.dataset import Dataset
|
from quacc.dataset import Dataset
|
||||||
from quacc.error import acc
|
from quacc.error import acc
|
||||||
from quacc.evaluation.report import CompReport, EvaluationReport
|
from quacc.evaluation.report import CompReport, EvaluationReport
|
||||||
from quacc.method.base import MultiClassAccuracyEstimator
|
from quacc.method.base import BinaryQuantifierAccuracyEstimator
|
||||||
from quacc.method.model_selection import GridSearchAE
|
from quacc.method.model_selection import GridSearchAE
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,8 +21,8 @@ def test_gs():
|
||||||
classifier.fit(*d.train.Xy)
|
classifier.fit(*d.train.Xy)
|
||||||
|
|
||||||
quantifier = SLD(LogisticRegression())
|
quantifier = SLD(LogisticRegression())
|
||||||
estimator = MultiClassAccuracyEstimator(classifier, quantifier)
|
# estimator = MultiClassAccuracyEstimator(classifier, quantifier)
|
||||||
estimator.fit(d.validation)
|
estimator = BinaryQuantifierAccuracyEstimator(classifier, quantifier)
|
||||||
|
|
||||||
v_train, v_val = d.validation.split_stratified(0.6, random_state=0)
|
v_train, v_val = d.validation.split_stratified(0.6, random_state=0)
|
||||||
gs_protocol = UPP(v_val, sample_size=1000, repeats=100)
|
gs_protocol = UPP(v_val, sample_size=1000, repeats=100)
|
||||||
|
@ -31,13 +31,15 @@ def test_gs():
|
||||||
param_grid={
|
param_grid={
|
||||||
"q__classifier__C": np.logspace(-3, 3, 7),
|
"q__classifier__C": np.logspace(-3, 3, 7),
|
||||||
"q__classifier__class_weight": [None, "balanced"],
|
"q__classifier__class_weight": [None, "balanced"],
|
||||||
"q__recalib": [None, "bcts", "vs"],
|
"q__recalib": [None, "bcts", "ts"],
|
||||||
},
|
},
|
||||||
refit=False,
|
refit=False,
|
||||||
protocol=gs_protocol,
|
protocol=gs_protocol,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
).fit(v_train)
|
).fit(v_train)
|
||||||
|
|
||||||
|
estimator.fit(d.validation)
|
||||||
|
|
||||||
tstart = time()
|
tstart = time()
|
||||||
erb, ergs = EvaluationReport("base"), EvaluationReport("gs")
|
erb, ergs = EvaluationReport("base"), EvaluationReport("gs")
|
||||||
protocol = APP(
|
protocol = APP(
|
||||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -1,15 +1,13 @@
|
||||||
import math
|
import math
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import quapy as qp
|
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.method.aggregative import CC, SLD, BaseQuantifier
|
from quapy.method.aggregative import BaseQuantifier
|
||||||
from quapy.model_selection import GridSearchQ
|
from scipy.sparse import csr_matrix
|
||||||
from quapy.protocol import UPP
|
|
||||||
from sklearn.base import BaseEstimator
|
from sklearn.base import BaseEstimator
|
||||||
from sklearn.linear_model import LogisticRegression
|
|
||||||
from sklearn.model_selection import cross_val_predict
|
|
||||||
|
|
||||||
from quacc.data import ExtendedCollection
|
from quacc.data import ExtendedCollection
|
||||||
|
|
||||||
|
@ -20,9 +18,7 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
||||||
classifier: BaseEstimator,
|
classifier: BaseEstimator,
|
||||||
quantifier: BaseQuantifier,
|
quantifier: BaseQuantifier,
|
||||||
):
|
):
|
||||||
self.fit_score = None
|
|
||||||
self.__check_classifier(classifier)
|
self.__check_classifier(classifier)
|
||||||
self.classifier = classifier
|
|
||||||
self.quantifier = quantifier
|
self.quantifier = quantifier
|
||||||
|
|
||||||
def __check_classifier(self, classifier):
|
def __check_classifier(self, classifier):
|
||||||
|
@ -30,21 +26,7 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities."
|
f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities."
|
||||||
)
|
)
|
||||||
|
self.classifier = classifier
|
||||||
def _gs_params(self, t_val: LabelledCollection):
|
|
||||||
return {
|
|
||||||
"param_grid": {
|
|
||||||
"classifier__C": np.logspace(-3, 3, 7),
|
|
||||||
"classifier__class_weight": [None, "balanced"],
|
|
||||||
"recalib": [None, "bcts"],
|
|
||||||
},
|
|
||||||
"protocol": UPP(t_val, repeats=1000),
|
|
||||||
"error": qp.error.mae,
|
|
||||||
"refit": False,
|
|
||||||
"timeout": -1,
|
|
||||||
"n_jobs": None,
|
|
||||||
"verbose": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
||||||
if not pred_proba:
|
if not pred_proba:
|
||||||
|
@ -67,6 +49,7 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
|
||||||
quantifier: BaseQuantifier,
|
quantifier: BaseQuantifier,
|
||||||
):
|
):
|
||||||
super().__init__(classifier, quantifier)
|
super().__init__(classifier, quantifier)
|
||||||
|
self.e_train = None
|
||||||
|
|
||||||
def fit(self, train: LabelledCollection):
|
def fit(self, train: LabelledCollection):
|
||||||
pred_probs = self.classifier.predict_proba(train.X)
|
pred_probs = self.classifier.predict_proba(train.X)
|
||||||
|
@ -95,84 +78,52 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
|
||||||
|
|
||||||
|
|
||||||
class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
|
class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
|
||||||
def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
|
def __init__(self, classifier: BaseEstimator, quantifier: BaseAccuracyEstimator):
|
||||||
super().__init__()
|
super().__init__(classifier, quantifier)
|
||||||
self.c_model = c_model
|
self.quantifiers = []
|
||||||
self._q_model_name = q_model.upper()
|
self.e_trains = []
|
||||||
self.q_models = []
|
|
||||||
self.gs = gs
|
|
||||||
self.recalib = recalib
|
|
||||||
self.e_train = None
|
|
||||||
|
|
||||||
def fit(self, train: LabelledCollection | ExtendedCollection):
|
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||||
# check if model is fit
|
pred_probs = self.classifier.predict_proba(train.X)
|
||||||
# self.model.fit(*train.Xy)
|
self.e_train = ExtendedCollection.extend_collection(train, pred_probs)
|
||||||
if isinstance(train, LabelledCollection):
|
|
||||||
pred_prob_train = cross_val_predict(
|
|
||||||
self.c_model, *train.Xy, method="predict_proba"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
|
|
||||||
elif isinstance(train, ExtendedCollection):
|
|
||||||
self.e_train = train
|
|
||||||
|
|
||||||
self.n_classes = self.e_train.n_classes
|
self.n_classes = self.e_train.n_classes
|
||||||
e_trains = self.e_train.split_by_pred()
|
self.e_trains = self.e_train.split_by_pred()
|
||||||
|
self.quantifiers = [deepcopy(self.quantifier) for _ in self.e_trains]
|
||||||
|
|
||||||
if self._q_model_name == "SLD":
|
self.quantifiers = []
|
||||||
fit_scores = []
|
for train in self.e_trains:
|
||||||
for e_train in e_trains:
|
quant = deepcopy(self.quantifier)
|
||||||
if self.gs:
|
quant.fit(train)
|
||||||
t_train, t_val = e_train.split_stratified(0.6, random_state=0)
|
self.quantifiers.append(quant)
|
||||||
gs_params = self._gs_params(t_val)
|
|
||||||
q_model = GridSearchQ(
|
|
||||||
SLD(LogisticRegression()),
|
|
||||||
**gs_params,
|
|
||||||
)
|
|
||||||
q_model.fit(t_train)
|
|
||||||
fit_scores.append(q_model.best_score_)
|
|
||||||
self.q_models.append(q_model)
|
|
||||||
else:
|
|
||||||
q_model = SLD(LogisticRegression(), recalib=self.recalib)
|
|
||||||
q_model.fit(e_train)
|
|
||||||
self.q_models.append(q_model)
|
|
||||||
|
|
||||||
if self.gs:
|
|
||||||
self.fit_score = np.mean(fit_scores)
|
|
||||||
|
|
||||||
elif self._q_model_name == "CC":
|
|
||||||
for e_train in e_trains:
|
|
||||||
q_model = CC(LogisticRegression())
|
|
||||||
q_model.fit(e_train)
|
|
||||||
self.q_models.append(q_model)
|
|
||||||
|
|
||||||
def estimate(self, instances, ext=False):
|
def estimate(self, instances, ext=False):
|
||||||
# TODO: test
|
# TODO: test
|
||||||
if not ext:
|
|
||||||
pred_prob = self.c_model.predict_proba(instances)
|
|
||||||
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
|
|
||||||
else:
|
|
||||||
e_inst = instances
|
e_inst = instances
|
||||||
|
if not ext:
|
||||||
|
pred_prob = self.classifier.predict_proba(instances)
|
||||||
|
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
|
||||||
|
|
||||||
_ncl = int(math.sqrt(self.n_classes))
|
_ncl = int(math.sqrt(self.n_classes))
|
||||||
s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
|
s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
|
||||||
estim_prevs = [
|
estim_prevs = self._quantify_helper(s_inst, norms)
|
||||||
self._quantify_helper(inst, norm, q_model)
|
|
||||||
for (inst, norm, q_model) in zip(s_inst, norms, self.q_models)
|
|
||||||
]
|
|
||||||
|
|
||||||
estim_prev = []
|
estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten()
|
||||||
for prev_row in zip(*estim_prevs):
|
return estim_prev
|
||||||
for prev in prev_row:
|
|
||||||
estim_prev.append(prev)
|
|
||||||
|
|
||||||
return np.asarray(estim_prev)
|
def _quantify_helper(
|
||||||
|
self,
|
||||||
def _quantify_helper(self, inst, norm, q_model):
|
s_inst: List[np.ndarray | csr_matrix],
|
||||||
|
norms: List[float],
|
||||||
|
):
|
||||||
|
estim_prevs = []
|
||||||
|
for quant, inst, norm in zip(self.quantifiers, s_inst, norms):
|
||||||
if inst.shape[0] > 0:
|
if inst.shape[0] > 0:
|
||||||
return np.asarray(list(map(lambda p: p * norm, q_model.quantify(inst))))
|
estim_prevs.append(quant.quantify(inst) * norm)
|
||||||
else:
|
else:
|
||||||
return np.asarray([0.0, 0.0])
|
estim_prevs.append(np.asarray([0.0, 0.0]))
|
||||||
|
|
||||||
|
return estim_prevs
|
||||||
|
|
||||||
|
|
||||||
BAE = BaseAccuracyEstimator
|
BAE = BaseAccuracyEstimator
|
||||||
|
|
|
@ -7,8 +7,9 @@ from quapy.data import LabelledCollection
|
||||||
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
||||||
|
|
||||||
import quacc as qc
|
import quacc as qc
|
||||||
import quacc.evaluation.method as evaluation
|
import quacc.error
|
||||||
from quacc.data import ExtendedCollection
|
from quacc.data import ExtendedCollection
|
||||||
|
from quacc.evaluation import evaluate
|
||||||
from quacc.method.base import BaseAccuracyEstimator
|
from quacc.method.base import BaseAccuracyEstimator
|
||||||
|
|
||||||
|
|
||||||
|
@ -138,8 +139,9 @@ class GridSearchAE(BaseAccuracyEstimator):
|
||||||
model = deepcopy(self.model)
|
model = deepcopy(self.model)
|
||||||
# overrides default parameters with the parameters being explored at this iteration
|
# overrides default parameters with the parameters being explored at this iteration
|
||||||
model.set_params(**params)
|
model.set_params(**params)
|
||||||
|
# print({k: v for k, v in model.get_params().items() if k in params})
|
||||||
model.fit(training)
|
model.fit(training)
|
||||||
score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
|
score = evaluate(model, protocol=protocol, error_metric=error)
|
||||||
|
|
||||||
ttime = time() - tinit
|
ttime = time() - tinit
|
||||||
self._sout(
|
self._sout(
|
||||||
|
@ -157,7 +159,6 @@ class GridSearchAE(BaseAccuracyEstimator):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._sout(f"something went wrong for config {params}; skipping:")
|
self._sout(f"something went wrong for config {params}; skipping:")
|
||||||
self._sout(f"\tException: {e}")
|
self._sout(f"\tException: {e}")
|
||||||
# traceback(e)
|
|
||||||
score = None
|
score = None
|
||||||
|
|
||||||
return params, score, model
|
return params, score, model
|
||||||
|
|
Loading…
Reference in New Issue