bin adapted to grid search
This commit is contained in:
parent
eccd818719
commit
d1be2b72e8
|
@ -1,14 +1,5 @@
|
|||
{
|
||||
"todo": [
|
||||
{
|
||||
"assignedTo": {
|
||||
"name": "Lorenzo Volpi"
|
||||
},
|
||||
"creation_time": "2023-10-28T14:34:46.226Z",
|
||||
"id": "4",
|
||||
"references": [],
|
||||
"title": "Aggingere estimator basati su PACC (quantificatore)"
|
||||
},
|
||||
{
|
||||
"assignedTo": {
|
||||
"name": "Lorenzo Volpi"
|
||||
|
@ -18,15 +9,6 @@
|
|||
"references": [],
|
||||
"title": "Creare plot avg con training prevalence sull'asse x e media rispetto a test prevalence"
|
||||
},
|
||||
{
|
||||
"assignedTo": {
|
||||
"name": "Lorenzo Volpi"
|
||||
},
|
||||
"creation_time": "2023-10-28T14:34:23.217Z",
|
||||
"id": "3",
|
||||
"references": [],
|
||||
"title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
|
||||
},
|
||||
{
|
||||
"assignedTo": {
|
||||
"name": "Lorenzo Volpi"
|
||||
|
@ -38,6 +20,27 @@
|
|||
}
|
||||
],
|
||||
"in-progress": [
|
||||
{
|
||||
"assignedTo": {
|
||||
"name": "Lorenzo Volpi"
|
||||
},
|
||||
"creation_time": "2023-10-28T14:34:23.217Z",
|
||||
"id": "3",
|
||||
"references": [],
|
||||
"title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
|
||||
},
|
||||
{
|
||||
"assignedTo": {
|
||||
"name": "Lorenzo Volpi"
|
||||
},
|
||||
"creation_time": "2023-10-28T14:34:46.226Z",
|
||||
"id": "4",
|
||||
"references": [],
|
||||
"title": "Aggingere estimator basati su PACC (quantificatore)"
|
||||
}
|
||||
],
|
||||
"testing": [],
|
||||
"done": [
|
||||
{
|
||||
"assignedTo": {
|
||||
"name": "Lorenzo Volpi"
|
||||
|
@ -47,7 +50,5 @@
|
|||
"references": [],
|
||||
"title": "Rework rappresentazione dati di report"
|
||||
}
|
||||
],
|
||||
"testing": [],
|
||||
"done": []
|
||||
]
|
||||
}
|
29
conf.yaml
29
conf.yaml
|
@ -12,8 +12,9 @@ debug_conf: &debug_conf
|
|||
plot_confs:
|
||||
debug:
|
||||
PLOT_ESTIMATORS:
|
||||
- mul_sld_gs
|
||||
- mul_sld
|
||||
- ref
|
||||
- atc_mc
|
||||
PLOT_STDEV: true
|
||||
|
||||
test_conf: &test_conf
|
||||
|
@ -21,24 +22,28 @@ test_conf: &test_conf
|
|||
METRICS:
|
||||
- acc
|
||||
- f1
|
||||
DATASET_N_PREVS: 2
|
||||
DATASET_PREVS:
|
||||
- 0.5
|
||||
- 0.1
|
||||
DATASET_N_PREVS: 9
|
||||
|
||||
confs:
|
||||
# - DATASET_NAME: rcv1
|
||||
# DATASET_TARGET: CCAT
|
||||
- DATASET_NAME: imdb
|
||||
- DATASET_NAME: rcv1
|
||||
DATASET_TARGET: CCAT
|
||||
# - DATASET_NAME: imdb
|
||||
|
||||
plot_confs:
|
||||
best_vs_atc:
|
||||
2gs_vs_atc:
|
||||
PLOT_ESTIMATORS:
|
||||
- bin_sld_gs
|
||||
- bin_sld_qgs
|
||||
- mul_sld_gs
|
||||
- mul_sld_qgs
|
||||
- ref
|
||||
- atc_mc
|
||||
- atc_ne
|
||||
sld_vs_pacc:
|
||||
PLOT_ESTIMATORS:
|
||||
- bin_sld
|
||||
- bin_sld_bcts
|
||||
- bin_sld_gs
|
||||
- mul_sld
|
||||
- mul_sld_bcts
|
||||
- mul_sld_gs
|
||||
- ref
|
||||
- atc_mc
|
||||
|
@ -102,4 +107,4 @@ main_conf: &main_conf
|
|||
- atc_ne
|
||||
- doc_feat
|
||||
|
||||
exec: *debug_conf
|
||||
exec: *test_conf
|
94
quacc.log
94
quacc.log
|
@ -1494,3 +1494,97 @@
|
|||
01/11/23 13:07:27| INFO Dataset sample 0.50 of dataset imdb_1prevs started
|
||||
01/11/23 13:07:27| ERROR Evaluation over imdb_1prevs failed. Exception: 'Invalid estimator: estimator mul_sld_gs does not exist'
|
||||
01/11/23 13:07:27| ERROR Failed while saving configuration imdb_debug of imdb_1prevs. Exception: cannot access local variable 'dr' where it is not associated with a value
|
||||
----------------------------------------------------------------------------------------------------
|
||||
03/11/23 20:54:19| INFO dataset rcv1_CCAT_9prevs
|
||||
03/11/23 20:54:24| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 20:54:28| WARNING Method mul_sld_gs failed. Exception: Invalid parameter 'quantifier' for estimator EMQ(classifier=LogisticRegression()). Valid parameters are: ['classifier', 'exact_train_prev', 'recalib'].
|
||||
03/11/23 20:54:29| WARNING Method mul_sld failed. Exception: evaluation_report() got an unexpected keyword argument 'protocor'
|
||||
03/11/23 20:54:30| WARNING Method bin_sld_gs failed. Exception: Invalid parameter 'quantifier' for estimator EMQ(classifier=LogisticRegression()). Valid parameters are: ['classifier', 'exact_train_prev', 'recalib'].
|
||||
03/11/23 20:55:09| INFO ref finished [took 38.5179s]
|
||||
----------------------------------------------------------------------------------------------------
|
||||
03/11/23 21:28:36| INFO dataset rcv1_CCAT_9prevs
|
||||
03/11/23 21:28:41| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 21:28:45| WARNING Method mul_sld failed. Exception: evaluation_report() got an unexpected keyword argument 'protocor'
|
||||
----------------------------------------------------------------------------------------------------
|
||||
03/11/23 21:31:03| INFO dataset rcv1_CCAT_9prevs
|
||||
03/11/23 21:31:08| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 21:31:59| INFO ref finished [took 45.6616s]
|
||||
03/11/23 21:32:03| INFO atc_mc finished [took 48.4360s]
|
||||
03/11/23 21:32:07| INFO atc_ne finished [took 51.0515s]
|
||||
03/11/23 21:32:23| INFO mul_sld finished [took 72.9229s]
|
||||
03/11/23 21:34:43| INFO bin_sld finished [took 213.9538s]
|
||||
03/11/23 21:36:27| INFO mul_sld_gs finished [took 314.9357s]
|
||||
03/11/23 21:40:50| INFO bin_sld_gs finished [took 579.2530s]
|
||||
03/11/23 21:40:50| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 582.5876s]
|
||||
03/11/23 21:40:50| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 21:41:39| INFO ref finished [took 43.7409s]
|
||||
03/11/23 21:41:43| INFO atc_mc finished [took 46.4580s]
|
||||
03/11/23 21:41:44| INFO atc_ne finished [took 46.4267s]
|
||||
03/11/23 21:41:54| INFO mul_sld finished [took 61.3005s]
|
||||
03/11/23 21:44:18| INFO bin_sld finished [took 206.3680s]
|
||||
03/11/23 21:45:59| INFO mul_sld_gs finished [took 304.4726s]
|
||||
03/11/23 21:50:33| INFO bin_sld_gs finished [took 579.3455s]
|
||||
03/11/23 21:50:33| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 582.4808s]
|
||||
03/11/23 21:50:33| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 21:51:22| INFO ref finished [took 43.6853s]
|
||||
03/11/23 21:51:26| INFO atc_mc finished [took 47.1366s]
|
||||
03/11/23 21:51:30| INFO atc_ne finished [took 49.4868s]
|
||||
03/11/23 21:51:34| INFO mul_sld finished [took 59.0964s]
|
||||
03/11/23 21:53:59| INFO bin_sld finished [took 205.0248s]
|
||||
03/11/23 21:55:50| INFO mul_sld_gs finished [took 312.5630s]
|
||||
03/11/23 22:00:27| INFO bin_sld_gs finished [took 591.1460s]
|
||||
03/11/23 22:00:27| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 594.3163s]
|
||||
03/11/23 22:00:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 22:01:15| INFO ref finished [took 43.3806s]
|
||||
03/11/23 22:01:19| INFO atc_mc finished [took 46.6674s]
|
||||
03/11/23 22:01:21| INFO atc_ne finished [took 47.1220s]
|
||||
03/11/23 22:01:28| INFO mul_sld finished [took 58.6799s]
|
||||
03/11/23 22:03:53| INFO bin_sld finished [took 204.7659s]
|
||||
03/11/23 22:05:39| INFO mul_sld_gs finished [took 307.8811s]
|
||||
03/11/23 22:10:32| INFO bin_sld_gs finished [took 601.9995s]
|
||||
03/11/23 22:10:32| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 604.8406s]
|
||||
03/11/23 22:10:32| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 22:11:20| INFO ref finished [took 42.8256s]
|
||||
03/11/23 22:11:25| INFO atc_mc finished [took 46.9203s]
|
||||
03/11/23 22:11:28| INFO atc_ne finished [took 49.3042s]
|
||||
03/11/23 22:11:34| INFO mul_sld finished [took 60.2744s]
|
||||
03/11/23 22:13:59| INFO bin_sld finished [took 205.7078s]
|
||||
03/11/23 22:15:45| INFO mul_sld_gs finished [took 309.0888s]
|
||||
03/11/23 22:20:32| INFO bin_sld_gs finished [took 596.5102s]
|
||||
03/11/23 22:20:32| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 599.5067s]
|
||||
03/11/23 22:20:32| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 22:21:20| INFO ref finished [took 43.1698s]
|
||||
03/11/23 22:21:24| INFO atc_mc finished [took 46.5768s]
|
||||
03/11/23 22:21:25| INFO atc_ne finished [took 46.3408s]
|
||||
03/11/23 22:21:34| INFO mul_sld finished [took 60.8070s]
|
||||
03/11/23 22:23:58| INFO bin_sld finished [took 205.3362s]
|
||||
03/11/23 22:25:44| INFO mul_sld_gs finished [took 308.1859s]
|
||||
03/11/23 22:30:44| INFO bin_sld_gs finished [took 609.5468s]
|
||||
03/11/23 22:30:44| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 612.5803s]
|
||||
03/11/23 22:30:44| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 22:31:32| INFO ref finished [took 43.2949s]
|
||||
03/11/23 22:31:37| INFO atc_mc finished [took 46.3686s]
|
||||
03/11/23 22:31:40| INFO atc_ne finished [took 49.2242s]
|
||||
03/11/23 22:31:47| INFO mul_sld finished [took 60.9437s]
|
||||
03/11/23 22:34:11| INFO bin_sld finished [took 205.9299s]
|
||||
03/11/23 22:35:56| INFO mul_sld_gs finished [took 308.2738s]
|
||||
03/11/23 22:40:36| INFO bin_sld_gs finished [took 588.7918s]
|
||||
03/11/23 22:40:36| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 591.8830s]
|
||||
03/11/23 22:40:36| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 22:41:24| INFO ref finished [took 43.3321s]
|
||||
03/11/23 22:41:29| INFO atc_mc finished [took 46.8041s]
|
||||
03/11/23 22:41:29| INFO atc_ne finished [took 46.5810s]
|
||||
03/11/23 22:41:38| INFO mul_sld finished [took 60.2962s]
|
||||
03/11/23 22:44:07| INFO bin_sld finished [took 209.6435s]
|
||||
03/11/23 22:45:44| INFO mul_sld_gs finished [took 304.4809s]
|
||||
03/11/23 22:50:39| INFO bin_sld_gs finished [took 599.5588s]
|
||||
03/11/23 22:50:39| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 602.5720s]
|
||||
03/11/23 22:50:39| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
|
||||
03/11/23 22:51:26| INFO ref finished [took 42.4313s]
|
||||
03/11/23 22:51:30| INFO atc_mc finished [took 45.5261s]
|
||||
03/11/23 22:51:34| INFO atc_ne finished [took 48.4488s]
|
||||
03/11/23 22:51:47| INFO mul_sld finished [took 66.4801s]
|
||||
03/11/23 22:54:08| INFO bin_sld finished [took 208.4272s]
|
||||
03/11/23 22:55:49| INFO mul_sld_gs finished [took 306.4505s]
|
||||
03/11/23 23:00:15| INFO bin_sld_gs finished [took 573.7761s]
|
||||
03/11/23 23:00:15| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 576.7586s]
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
||||
|
||||
import quacc as qc
|
||||
|
||||
from ..method.base import BaseAccuracyEstimator
|
||||
|
||||
|
||||
def evaluate(
    estimator: BaseAccuracyEstimator,
    protocol: AbstractProtocol,
    error_metric: Union[Callable, str],
) -> float:
    """Score an accuracy estimator over the samples generated by a protocol.

    For every sample yielded by ``protocol()``, the sample is extended through
    ``estimator.extend``, the estimator is asked for a prevalence estimate on
    the extended instances, and the estimate is paired with the extended
    sample's own prevalence. The two stacks are then passed to
    ``error_metric``.

    Args:
        estimator: the estimator being evaluated; must expose ``extend`` and
            ``estimate(..., ext=True)``.
        protocol: sample-generation protocol. Its ``collator`` attribute is
            temporarily replaced while sampling and restored before returning.
        error_metric: either a callable taking ``(true_prevs, estim_prevs)``
            or the name of an error function resolved through
            ``qc.error.from_name``.

    Returns:
        Whatever ``error_metric`` returns for the collected prevalences.
    """
    if isinstance(error_metric, str):
        error_metric = qc.error.from_name(error_metric)

    # The protocol must yield whole labelled collections here, so its collator
    # is swapped for the duration of the loop and put back afterwards.
    collator_bck_ = protocol.collator
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    estim_prevs, true_prevs = [], []
    try:
        for sample in protocol():
            e_sample = estimator.extend(sample)
            estim_prev = estimator.estimate(e_sample.X, ext=True)
            estim_prevs.append(estim_prev)
            true_prevs.append(e_sample.prevalence())
    finally:
        # Restore the caller's collator even when extending/estimating raises;
        # otherwise a caller that catches the error and reuses the protocol
        # would keep sampling with the temporary collator.
        protocol.collator = collator_bck_

    true_prevs = np.array(true_prevs)
    estim_prevs = np.array(estim_prevs)

    return error_metric(true_prevs, estim_prevs)
|
|
@ -1,9 +1,9 @@
|
|||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
from quapy.method.aggregative import SLD
|
||||
from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol
|
||||
from quapy.protocol import UPP, AbstractProtocol
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
import quacc as qc
|
||||
|
@ -25,38 +25,12 @@ def method(func):
|
|||
return wrapper
|
||||
|
||||
|
||||
def evaluate(
|
||||
estimator: BaseAccuracyEstimator,
|
||||
protocol: AbstractProtocol,
|
||||
error_metric: Union[Callable | str],
|
||||
) -> float:
|
||||
if isinstance(error_metric, str):
|
||||
error_metric = qc.error.from_name(error_metric)
|
||||
|
||||
collator_bck_ = protocol.collator
|
||||
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
|
||||
|
||||
estim_prevs, true_prevs = [], []
|
||||
for sample in protocol():
|
||||
e_sample = estimator.extend(sample)
|
||||
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
||||
estim_prevs.append(estim_prev)
|
||||
true_prevs.append(e_sample.prevalence())
|
||||
|
||||
protocol.collator = collator_bck_
|
||||
|
||||
true_prevs = np.array(true_prevs)
|
||||
estim_prevs = np.array(estim_prevs)
|
||||
|
||||
return error_metric(true_prevs, estim_prevs)
|
||||
|
||||
|
||||
def evaluation_report(
|
||||
estimator: BaseAccuracyEstimator,
|
||||
protocol: AbstractProtocol,
|
||||
method: str,
|
||||
) -> EvaluationReport:
|
||||
report = EvaluationReport(name=method)
|
||||
method_name = inspect.stack()[1].function
|
||||
report = EvaluationReport(name=method_name)
|
||||
for sample in protocol():
|
||||
e_sample = estimator.extend(sample)
|
||||
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
||||
|
@ -80,7 +54,6 @@ def bin_sld(c_model, validation, protocol) -> EvaluationReport:
|
|||
return evaluation_report(
|
||||
estimator=est,
|
||||
protocol=protocol,
|
||||
method="bin_sld",
|
||||
)
|
||||
|
||||
|
||||
|
@ -90,8 +63,7 @@ def mul_sld(c_model, validation, protocol) -> EvaluationReport:
|
|||
est.fit(validation)
|
||||
return evaluation_report(
|
||||
estimator=est,
|
||||
protocor=protocol,
|
||||
method="mul_sld",
|
||||
protocol=protocol,
|
||||
)
|
||||
|
||||
|
||||
|
@ -102,7 +74,6 @@ def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
|
|||
return evaluation_report(
|
||||
estimator=est,
|
||||
protocol=protocol,
|
||||
method="bin_sld_bcts",
|
||||
)
|
||||
|
||||
|
||||
|
@ -113,14 +84,13 @@ def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
|
|||
return evaluation_report(
|
||||
estimator=est,
|
||||
protocol=protocol,
|
||||
method="mul_sld_bcts",
|
||||
)
|
||||
|
||||
|
||||
@method
|
||||
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
||||
def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
||||
v_train, v_val = validation.split_stratified(0.6, random_state=0)
|
||||
model = SLD(LogisticRegression())
|
||||
model = BQAE(c_model, SLD(LogisticRegression()))
|
||||
est = GridSearchAE(
|
||||
model=model,
|
||||
param_grid={
|
||||
|
@ -130,10 +100,30 @@ def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
|||
},
|
||||
refit=False,
|
||||
protocol=UPP(v_val, repeats=100),
|
||||
verbose=True,
|
||||
verbose=False,
|
||||
).fit(v_train)
|
||||
return evaluation_report(
|
||||
estimator=est,
|
||||
protocol=protocol,
|
||||
)
|
||||
|
||||
|
||||
@method
|
||||
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
|
||||
v_train, v_val = validation.split_stratified(0.6, random_state=0)
|
||||
model = MCAE(c_model, SLD(LogisticRegression()))
|
||||
est = GridSearchAE(
|
||||
model=model,
|
||||
param_grid={
|
||||
"q__classifier__C": np.logspace(-3, 3, 7),
|
||||
"q__classifier__class_weight": [None, "balanced"],
|
||||
"q__recalib": [None, "bcts", "vs"],
|
||||
},
|
||||
refit=False,
|
||||
protocol=UPP(v_val, repeats=100),
|
||||
verbose=False,
|
||||
).fit(v_train)
|
||||
return evaluation_report(
|
||||
estimator=est,
|
||||
protocol=protocol,
|
||||
method="mul_sld_gs",
|
||||
)
|
||||
|
|
|
@ -10,7 +10,7 @@ from sklearn.linear_model import LogisticRegression
|
|||
from quacc.dataset import Dataset
|
||||
from quacc.error import acc
|
||||
from quacc.evaluation.report import CompReport, EvaluationReport
|
||||
from quacc.method.base import MultiClassAccuracyEstimator
|
||||
from quacc.method.base import BinaryQuantifierAccuracyEstimator
|
||||
from quacc.method.model_selection import GridSearchAE
|
||||
|
||||
|
||||
|
@ -21,8 +21,8 @@ def test_gs():
|
|||
classifier.fit(*d.train.Xy)
|
||||
|
||||
quantifier = SLD(LogisticRegression())
|
||||
estimator = MultiClassAccuracyEstimator(classifier, quantifier)
|
||||
estimator.fit(d.validation)
|
||||
# estimator = MultiClassAccuracyEstimator(classifier, quantifier)
|
||||
estimator = BinaryQuantifierAccuracyEstimator(classifier, quantifier)
|
||||
|
||||
v_train, v_val = d.validation.split_stratified(0.6, random_state=0)
|
||||
gs_protocol = UPP(v_val, sample_size=1000, repeats=100)
|
||||
|
@ -31,13 +31,15 @@ def test_gs():
|
|||
param_grid={
|
||||
"q__classifier__C": np.logspace(-3, 3, 7),
|
||||
"q__classifier__class_weight": [None, "balanced"],
|
||||
"q__recalib": [None, "bcts", "vs"],
|
||||
"q__recalib": [None, "bcts", "ts"],
|
||||
},
|
||||
refit=False,
|
||||
protocol=gs_protocol,
|
||||
verbose=True,
|
||||
).fit(v_train)
|
||||
|
||||
estimator.fit(d.validation)
|
||||
|
||||
tstart = time()
|
||||
erb, ergs = EvaluationReport("base"), EvaluationReport("gs")
|
||||
protocol = APP(
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -1,15 +1,13 @@
|
|||
import math
|
||||
from abc import abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import quapy as qp
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.method.aggregative import CC, SLD, BaseQuantifier
|
||||
from quapy.model_selection import GridSearchQ
|
||||
from quapy.protocol import UPP
|
||||
from quapy.method.aggregative import BaseQuantifier
|
||||
from scipy.sparse import csr_matrix
|
||||
from sklearn.base import BaseEstimator
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import cross_val_predict
|
||||
|
||||
from quacc.data import ExtendedCollection
|
||||
|
||||
|
@ -20,9 +18,7 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
|||
classifier: BaseEstimator,
|
||||
quantifier: BaseQuantifier,
|
||||
):
|
||||
self.fit_score = None
|
||||
self.__check_classifier(classifier)
|
||||
self.classifier = classifier
|
||||
self.quantifier = quantifier
|
||||
|
||||
def __check_classifier(self, classifier):
|
||||
|
@ -30,21 +26,7 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
|||
raise ValueError(
|
||||
f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities."
|
||||
)
|
||||
|
||||
def _gs_params(self, t_val: LabelledCollection):
|
||||
return {
|
||||
"param_grid": {
|
||||
"classifier__C": np.logspace(-3, 3, 7),
|
||||
"classifier__class_weight": [None, "balanced"],
|
||||
"recalib": [None, "bcts"],
|
||||
},
|
||||
"protocol": UPP(t_val, repeats=1000),
|
||||
"error": qp.error.mae,
|
||||
"refit": False,
|
||||
"timeout": -1,
|
||||
"n_jobs": None,
|
||||
"verbose": True,
|
||||
}
|
||||
self.classifier = classifier
|
||||
|
||||
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
||||
if not pred_proba:
|
||||
|
@ -67,6 +49,7 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
|
|||
quantifier: BaseQuantifier,
|
||||
):
|
||||
super().__init__(classifier, quantifier)
|
||||
self.e_train = None
|
||||
|
||||
def fit(self, train: LabelledCollection):
|
||||
pred_probs = self.classifier.predict_proba(train.X)
|
||||
|
@ -95,84 +78,52 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
|
|||
|
||||
|
||||
class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
|
||||
def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
|
||||
super().__init__()
|
||||
self.c_model = c_model
|
||||
self._q_model_name = q_model.upper()
|
||||
self.q_models = []
|
||||
self.gs = gs
|
||||
self.recalib = recalib
|
||||
self.e_train = None
|
||||
def __init__(self, classifier: BaseEstimator, quantifier: BaseAccuracyEstimator):
|
||||
super().__init__(classifier, quantifier)
|
||||
self.quantifiers = []
|
||||
self.e_trains = []
|
||||
|
||||
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||
# check if model is fit
|
||||
# self.model.fit(*train.Xy)
|
||||
if isinstance(train, LabelledCollection):
|
||||
pred_prob_train = cross_val_predict(
|
||||
self.c_model, *train.Xy, method="predict_proba"
|
||||
)
|
||||
|
||||
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
|
||||
elif isinstance(train, ExtendedCollection):
|
||||
self.e_train = train
|
||||
pred_probs = self.classifier.predict_proba(train.X)
|
||||
self.e_train = ExtendedCollection.extend_collection(train, pred_probs)
|
||||
|
||||
self.n_classes = self.e_train.n_classes
|
||||
e_trains = self.e_train.split_by_pred()
|
||||
self.e_trains = self.e_train.split_by_pred()
|
||||
self.quantifiers = [deepcopy(self.quantifier) for _ in self.e_trains]
|
||||
|
||||
if self._q_model_name == "SLD":
|
||||
fit_scores = []
|
||||
for e_train in e_trains:
|
||||
if self.gs:
|
||||
t_train, t_val = e_train.split_stratified(0.6, random_state=0)
|
||||
gs_params = self._gs_params(t_val)
|
||||
q_model = GridSearchQ(
|
||||
SLD(LogisticRegression()),
|
||||
**gs_params,
|
||||
)
|
||||
q_model.fit(t_train)
|
||||
fit_scores.append(q_model.best_score_)
|
||||
self.q_models.append(q_model)
|
||||
else:
|
||||
q_model = SLD(LogisticRegression(), recalib=self.recalib)
|
||||
q_model.fit(e_train)
|
||||
self.q_models.append(q_model)
|
||||
|
||||
if self.gs:
|
||||
self.fit_score = np.mean(fit_scores)
|
||||
|
||||
elif self._q_model_name == "CC":
|
||||
for e_train in e_trains:
|
||||
q_model = CC(LogisticRegression())
|
||||
q_model.fit(e_train)
|
||||
self.q_models.append(q_model)
|
||||
self.quantifiers = []
|
||||
for train in self.e_trains:
|
||||
quant = deepcopy(self.quantifier)
|
||||
quant.fit(train)
|
||||
self.quantifiers.append(quant)
|
||||
|
||||
def estimate(self, instances, ext=False):
|
||||
# TODO: test
|
||||
e_inst = instances
|
||||
if not ext:
|
||||
pred_prob = self.c_model.predict_proba(instances)
|
||||
pred_prob = self.classifier.predict_proba(instances)
|
||||
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
|
||||
else:
|
||||
e_inst = instances
|
||||
|
||||
_ncl = int(math.sqrt(self.n_classes))
|
||||
s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
|
||||
estim_prevs = [
|
||||
self._quantify_helper(inst, norm, q_model)
|
||||
for (inst, norm, q_model) in zip(s_inst, norms, self.q_models)
|
||||
]
|
||||
estim_prevs = self._quantify_helper(s_inst, norms)
|
||||
|
||||
estim_prev = []
|
||||
for prev_row in zip(*estim_prevs):
|
||||
for prev in prev_row:
|
||||
estim_prev.append(prev)
|
||||
estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten()
|
||||
return estim_prev
|
||||
|
||||
return np.asarray(estim_prev)
|
||||
def _quantify_helper(
|
||||
self,
|
||||
s_inst: List[np.ndarray | csr_matrix],
|
||||
norms: List[float],
|
||||
):
|
||||
estim_prevs = []
|
||||
for quant, inst, norm in zip(self.quantifiers, s_inst, norms):
|
||||
if inst.shape[0] > 0:
|
||||
estim_prevs.append(quant.quantify(inst) * norm)
|
||||
else:
|
||||
estim_prevs.append(np.asarray([0.0, 0.0]))
|
||||
|
||||
def _quantify_helper(self, inst, norm, q_model):
|
||||
if inst.shape[0] > 0:
|
||||
return np.asarray(list(map(lambda p: p * norm, q_model.quantify(inst))))
|
||||
else:
|
||||
return np.asarray([0.0, 0.0])
|
||||
return estim_prevs
|
||||
|
||||
|
||||
BAE = BaseAccuracyEstimator
|
||||
|
|
|
@ -7,8 +7,9 @@ from quapy.data import LabelledCollection
|
|||
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
||||
|
||||
import quacc as qc
|
||||
import quacc.evaluation.method as evaluation
|
||||
import quacc.error
|
||||
from quacc.data import ExtendedCollection
|
||||
from quacc.evaluation import evaluate
|
||||
from quacc.method.base import BaseAccuracyEstimator
|
||||
|
||||
|
||||
|
@ -138,8 +139,9 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
model = deepcopy(self.model)
|
||||
# overrides default parameters with the parameters being explored at this iteration
|
||||
model.set_params(**params)
|
||||
# print({k: v for k, v in model.get_params().items() if k in params})
|
||||
model.fit(training)
|
||||
score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
|
||||
score = evaluate(model, protocol=protocol, error_metric=error)
|
||||
|
||||
ttime = time() - tinit
|
||||
self._sout(
|
||||
|
@ -157,7 +159,6 @@ class GridSearchAE(BaseAccuracyEstimator):
|
|||
except Exception as e:
|
||||
self._sout(f"something went wrong for config {params}; skipping:")
|
||||
self._sout(f"\tException: {e}")
|
||||
# traceback(e)
|
||||
score = None
|
||||
|
||||
return params, score, model
|
||||
|
|
Loading…
Reference in New Issue