bin adapted to grid search

Lorenzo Volpi 2023-11-03 23:28:40 +01:00
parent eccd818719
commit d1be2b72e8
10 changed files with 242 additions and 164 deletions

View File

@ -1,14 +1,5 @@
{
"todo": [
{
"assignedTo": {
"name": "Lorenzo Volpi"
},
"creation_time": "2023-10-28T14:34:46.226Z",
"id": "4",
"references": [],
"title": "Aggingere estimator basati su PACC (quantificatore)"
},
{
"assignedTo": {
"name": "Lorenzo Volpi"
@ -18,15 +9,6 @@
"references": [],
"title": "Creare plot avg con training prevalence sull'asse x e media rispetto a test prevalence"
},
{
"assignedTo": {
"name": "Lorenzo Volpi"
},
"creation_time": "2023-10-28T14:34:23.217Z",
"id": "3",
"references": [],
"title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
},
{
"assignedTo": {
"name": "Lorenzo Volpi"
@ -38,6 +20,27 @@
}
],
"in-progress": [
{
"assignedTo": {
"name": "Lorenzo Volpi"
},
"creation_time": "2023-10-28T14:34:23.217Z",
"id": "3",
"references": [],
"title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
},
{
"assignedTo": {
"name": "Lorenzo Volpi"
},
"creation_time": "2023-10-28T14:34:46.226Z",
"id": "4",
"references": [],
"title": "Aggingere estimator basati su PACC (quantificatore)"
}
],
"testing": [],
"done": [
{
"assignedTo": {
"name": "Lorenzo Volpi"
@ -47,7 +50,5 @@
"references": [],
"title": "Rework rappresentazione dati di report"
}
],
"testing": [],
"done": []
]
}

View File

@ -12,8 +12,9 @@ debug_conf: &debug_conf
plot_confs:
debug:
PLOT_ESTIMATORS:
- mul_sld_gs
- mul_sld
- ref
- atc_mc
PLOT_STDEV: true
test_conf: &test_conf
@ -21,24 +22,28 @@ test_conf: &test_conf
METRICS:
- acc
- f1
DATASET_N_PREVS: 2
DATASET_PREVS:
- 0.5
- 0.1
DATASET_N_PREVS: 9
confs:
# - DATASET_NAME: rcv1
# DATASET_TARGET: CCAT
- DATASET_NAME: imdb
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
# - DATASET_NAME: imdb
plot_confs:
best_vs_atc:
2gs_vs_atc:
PLOT_ESTIMATORS:
- bin_sld_gs
- bin_sld_qgs
- mul_sld_gs
- mul_sld_qgs
- ref
- atc_mc
- atc_ne
sld_vs_pacc:
PLOT_ESTIMATORS:
- bin_sld
- bin_sld_bcts
- bin_sld_gs
- mul_sld
- mul_sld_bcts
- mul_sld_gs
- ref
- atc_mc
@ -102,4 +107,4 @@ main_conf: &main_conf
- atc_ne
- doc_feat
exec: *debug_conf
exec: *test_conf
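
For reference, the exec key selects which anchored block is actually run; below is a minimal sketch of how the YAML alias resolves at load time (toy config and PyYAML usage, for illustration only, not part of this commit):

import yaml

# toy config mirroring the anchor/alias pattern used in the file above
doc = """
test_conf: &test_conf
  METRICS: [acc, f1]
  DATASET_N_PREVS: 9
exec: *test_conf
"""
conf = yaml.safe_load(doc)
print(conf["exec"]["DATASET_N_PREVS"])  # -> 9: exec resolves to the test_conf mapping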

View File

@ -1494,3 +1494,97 @@
01/11/23 13:07:27| INFO Dataset sample 0.50 of dataset imdb_1prevs started
01/11/23 13:07:27| ERROR Evaluation over imdb_1prevs failed. Exception: 'Invalid estimator: estimator mul_sld_gs does not exist'
01/11/23 13:07:27| ERROR Failed while saving configuration imdb_debug of imdb_1prevs. Exception: cannot access local variable 'dr' where it is not associated with a value
----------------------------------------------------------------------------------------------------
03/11/23 20:54:19| INFO dataset rcv1_CCAT_9prevs
03/11/23 20:54:24| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
03/11/23 20:54:28| WARNING Method mul_sld_gs failed. Exception: Invalid parameter 'quantifier' for estimator EMQ(classifier=LogisticRegression()). Valid parameters are: ['classifier', 'exact_train_prev', 'recalib'].
03/11/23 20:54:29| WARNING Method mul_sld failed. Exception: evaluation_report() got an unexpected keyword argument 'protocor'
03/11/23 20:54:30| WARNING Method bin_sld_gs failed. Exception: Invalid parameter 'quantifier' for estimator EMQ(classifier=LogisticRegression()). Valid parameters are: ['classifier', 'exact_train_prev', 'recalib'].
03/11/23 20:55:09| INFO ref finished [took 38.5179s]
----------------------------------------------------------------------------------------------------
03/11/23 21:28:36| INFO dataset rcv1_CCAT_9prevs
03/11/23 21:28:41| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
03/11/23 21:28:45| WARNING Method mul_sld failed. Exception: evaluation_report() got an unexpected keyword argument 'protocor'
----------------------------------------------------------------------------------------------------
03/11/23 21:31:03| INFO dataset rcv1_CCAT_9prevs
03/11/23 21:31:08| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
03/11/23 21:31:59| INFO ref finished [took 45.6616s]
03/11/23 21:32:03| INFO atc_mc finished [took 48.4360s]
03/11/23 21:32:07| INFO atc_ne finished [took 51.0515s]
03/11/23 21:32:23| INFO mul_sld finished [took 72.9229s]
03/11/23 21:34:43| INFO bin_sld finished [took 213.9538s]
03/11/23 21:36:27| INFO mul_sld_gs finished [took 314.9357s]
03/11/23 21:40:50| INFO bin_sld_gs finished [took 579.2530s]
03/11/23 21:40:50| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 582.5876s]
03/11/23 21:40:50| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started
03/11/23 21:41:39| INFO ref finished [took 43.7409s]
03/11/23 21:41:43| INFO atc_mc finished [took 46.4580s]
03/11/23 21:41:44| INFO atc_ne finished [took 46.4267s]
03/11/23 21:41:54| INFO mul_sld finished [took 61.3005s]
03/11/23 21:44:18| INFO bin_sld finished [took 206.3680s]
03/11/23 21:45:59| INFO mul_sld_gs finished [took 304.4726s]
03/11/23 21:50:33| INFO bin_sld_gs finished [took 579.3455s]
03/11/23 21:50:33| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 582.4808s]
03/11/23 21:50:33| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started
03/11/23 21:51:22| INFO ref finished [took 43.6853s]
03/11/23 21:51:26| INFO atc_mc finished [took 47.1366s]
03/11/23 21:51:30| INFO atc_ne finished [took 49.4868s]
03/11/23 21:51:34| INFO mul_sld finished [took 59.0964s]
03/11/23 21:53:59| INFO bin_sld finished [took 205.0248s]
03/11/23 21:55:50| INFO mul_sld_gs finished [took 312.5630s]
03/11/23 22:00:27| INFO bin_sld_gs finished [took 591.1460s]
03/11/23 22:00:27| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 594.3163s]
03/11/23 22:00:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
03/11/23 22:01:15| INFO ref finished [took 43.3806s]
03/11/23 22:01:19| INFO atc_mc finished [took 46.6674s]
03/11/23 22:01:21| INFO atc_ne finished [took 47.1220s]
03/11/23 22:01:28| INFO mul_sld finished [took 58.6799s]
03/11/23 22:03:53| INFO bin_sld finished [took 204.7659s]
03/11/23 22:05:39| INFO mul_sld_gs finished [took 307.8811s]
03/11/23 22:10:32| INFO bin_sld_gs finished [took 601.9995s]
03/11/23 22:10:32| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 604.8406s]
03/11/23 22:10:32| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
03/11/23 22:11:20| INFO ref finished [took 42.8256s]
03/11/23 22:11:25| INFO atc_mc finished [took 46.9203s]
03/11/23 22:11:28| INFO atc_ne finished [took 49.3042s]
03/11/23 22:11:34| INFO mul_sld finished [took 60.2744s]
03/11/23 22:13:59| INFO bin_sld finished [took 205.7078s]
03/11/23 22:15:45| INFO mul_sld_gs finished [took 309.0888s]
03/11/23 22:20:32| INFO bin_sld_gs finished [took 596.5102s]
03/11/23 22:20:32| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 599.5067s]
03/11/23 22:20:32| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
03/11/23 22:21:20| INFO ref finished [took 43.1698s]
03/11/23 22:21:24| INFO atc_mc finished [took 46.5768s]
03/11/23 22:21:25| INFO atc_ne finished [took 46.3408s]
03/11/23 22:21:34| INFO mul_sld finished [took 60.8070s]
03/11/23 22:23:58| INFO bin_sld finished [took 205.3362s]
03/11/23 22:25:44| INFO mul_sld_gs finished [took 308.1859s]
03/11/23 22:30:44| INFO bin_sld_gs finished [took 609.5468s]
03/11/23 22:30:44| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 612.5803s]
03/11/23 22:30:44| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
03/11/23 22:31:32| INFO ref finished [took 43.2949s]
03/11/23 22:31:37| INFO atc_mc finished [took 46.3686s]
03/11/23 22:31:40| INFO atc_ne finished [took 49.2242s]
03/11/23 22:31:47| INFO mul_sld finished [took 60.9437s]
03/11/23 22:34:11| INFO bin_sld finished [took 205.9299s]
03/11/23 22:35:56| INFO mul_sld_gs finished [took 308.2738s]
03/11/23 22:40:36| INFO bin_sld_gs finished [took 588.7918s]
03/11/23 22:40:36| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 591.8830s]
03/11/23 22:40:36| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
03/11/23 22:41:24| INFO ref finished [took 43.3321s]
03/11/23 22:41:29| INFO atc_mc finished [took 46.8041s]
03/11/23 22:41:29| INFO atc_ne finished [took 46.5810s]
03/11/23 22:41:38| INFO mul_sld finished [took 60.2962s]
03/11/23 22:44:07| INFO bin_sld finished [took 209.6435s]
03/11/23 22:45:44| INFO mul_sld_gs finished [took 304.4809s]
03/11/23 22:50:39| INFO bin_sld_gs finished [took 599.5588s]
03/11/23 22:50:39| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 602.5720s]
03/11/23 22:50:39| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
03/11/23 22:51:26| INFO ref finished [took 42.4313s]
03/11/23 22:51:30| INFO atc_mc finished [took 45.5261s]
03/11/23 22:51:34| INFO atc_ne finished [took 48.4488s]
03/11/23 22:51:47| INFO mul_sld finished [took 66.4801s]
03/11/23 22:54:08| INFO bin_sld finished [took 208.4272s]
03/11/23 22:55:49| INFO mul_sld_gs finished [took 306.4505s]
03/11/23 23:00:15| INFO bin_sld_gs finished [took 573.7761s]
03/11/23 23:00:15| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 576.7586s]

View File

@ -0,0 +1,34 @@
from typing import Callable, Union
import numpy as np
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
import quacc as qc
from ..method.base import BaseAccuracyEstimator
def evaluate(
estimator: BaseAccuracyEstimator,
protocol: AbstractProtocol,
error_metric: Union[Callable, str],
) -> float:
if isinstance(error_metric, str):
error_metric = qc.error.from_name(error_metric)
collator_bck_ = protocol.collator
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
estim_prevs, true_prevs = [], []
for sample in protocol():
e_sample = estimator.extend(sample)
estim_prev = estimator.estimate(e_sample.X, ext=True)
estim_prevs.append(estim_prev)
true_prevs.append(e_sample.prevalence())
protocol.collator = collator_bck_
true_prevs = np.array(true_prevs)
estim_prevs = np.array(estim_prevs)
return error_metric(true_prevs, estim_prevs)
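
A minimal usage sketch for the new evaluate helper (the dataset wiring, sample size, and the choice of quapy's MAE are illustrative assumptions; only evaluate itself comes from this file):

import quapy as qp
from quapy.method.aggregative import SLD
from quapy.protocol import UPP
from sklearn.linear_model import LogisticRegression

from quacc.evaluation import evaluate
from quacc.method.base import MultiClassAccuracyEstimator

# binary sentiment data; keep a validation split for fitting the accuracy estimator
dataset = qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=5)
train, validation = dataset.training.split_stratified(0.6, random_state=0)

classifier = LogisticRegression().fit(*train.Xy)
estimator = MultiClassAccuracyEstimator(classifier, SLD(LogisticRegression()))
estimator.fit(validation)

# score over a UPP protocol on the test set; error_metric may be a callable or a name
protocol = UPP(dataset.test, sample_size=500, repeats=100)
print(evaluate(estimator, protocol=protocol, error_metric=qp.error.mae))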

View File

@ -1,9 +1,9 @@
import inspect
from functools import wraps
from typing import Callable, Union
import numpy as np
from quapy.method.aggregative import SLD
from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol
from quapy.protocol import UPP, AbstractProtocol
from sklearn.linear_model import LogisticRegression
import quacc as qc
@ -25,38 +25,12 @@ def method(func):
return wrapper
def evaluate(
estimator: BaseAccuracyEstimator,
protocol: AbstractProtocol,
error_metric: Union[Callable | str],
) -> float:
if isinstance(error_metric, str):
error_metric = qc.error.from_name(error_metric)
collator_bck_ = protocol.collator
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
estim_prevs, true_prevs = [], []
for sample in protocol():
e_sample = estimator.extend(sample)
estim_prev = estimator.estimate(e_sample.X, ext=True)
estim_prevs.append(estim_prev)
true_prevs.append(e_sample.prevalence())
protocol.collator = collator_bck_
true_prevs = np.array(true_prevs)
estim_prevs = np.array(estim_prevs)
return error_metric(true_prevs, estim_prevs)
def evaluation_report(
estimator: BaseAccuracyEstimator,
protocol: AbstractProtocol,
method: str,
) -> EvaluationReport:
report = EvaluationReport(name=method)
method_name = inspect.stack()[1].function
report = EvaluationReport(name=method_name)
for sample in protocol():
e_sample = estimator.extend(sample)
estim_prev = estimator.estimate(e_sample.X, ext=True)
@ -80,7 +54,6 @@ def bin_sld(c_model, validation, protocol) -> EvaluationReport:
return evaluation_report(
estimator=est,
protocol=protocol,
method="bin_sld",
)
@ -90,8 +63,7 @@ def mul_sld(c_model, validation, protocol) -> EvaluationReport:
est.fit(validation)
return evaluation_report(
estimator=est,
protocor=protocol,
method="mul_sld",
protocol=protocol,
)
@ -102,7 +74,6 @@ def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
return evaluation_report(
estimator=est,
protocol=protocol,
method="bin_sld_bcts",
)
@ -113,14 +84,13 @@ def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
return evaluation_report(
estimator=est,
protocol=protocol,
method="mul_sld_bcts",
)
@method
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
v_train, v_val = validation.split_stratified(0.6, random_state=0)
model = SLD(LogisticRegression())
model = BQAE(c_model, SLD(LogisticRegression()))
est = GridSearchAE(
model=model,
param_grid={
@ -130,10 +100,30 @@ def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
},
refit=False,
protocol=UPP(v_val, repeats=100),
verbose=True,
verbose=False,
).fit(v_train)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
v_train, v_val = validation.split_stratified(0.6, random_state=0)
model = MCAE(c_model, SLD(LogisticRegression()))
est = GridSearchAE(
model=model,
param_grid={
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"q__recalib": [None, "bcts", "vs"],
},
refit=False,
protocol=UPP(v_val, repeats=100),
verbose=False,
).fit(v_train)
return evaluation_report(
estimator=est,
protocol=protocol,
method="mul_sld_gs",
)
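
The evaluation_report change above drops the explicit method= argument and infers the report name from the calling function via inspect.stack(). A toy sketch of that mechanism (the function names here are illustrative only):

import inspect

def make_report():
    # name of the function one frame up the call stack, e.g. "bin_sld_gs"
    return inspect.stack()[1].function

def bin_sld_gs_like():
    return make_report()

assert bin_sld_gs_like() == "bin_sld_gs_like"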

View File

@ -10,7 +10,7 @@ from sklearn.linear_model import LogisticRegression
from quacc.dataset import Dataset
from quacc.error import acc
from quacc.evaluation.report import CompReport, EvaluationReport
from quacc.method.base import MultiClassAccuracyEstimator
from quacc.method.base import BinaryQuantifierAccuracyEstimator
from quacc.method.model_selection import GridSearchAE
@ -21,8 +21,8 @@ def test_gs():
classifier.fit(*d.train.Xy)
quantifier = SLD(LogisticRegression())
estimator = MultiClassAccuracyEstimator(classifier, quantifier)
estimator.fit(d.validation)
# estimator = MultiClassAccuracyEstimator(classifier, quantifier)
estimator = BinaryQuantifierAccuracyEstimator(classifier, quantifier)
v_train, v_val = d.validation.split_stratified(0.6, random_state=0)
gs_protocol = UPP(v_val, sample_size=1000, repeats=100)
@ -31,13 +31,15 @@ def test_gs():
param_grid={
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"q__recalib": [None, "bcts", "vs"],
"q__recalib": [None, "bcts", "ts"],
},
refit=False,
protocol=gs_protocol,
verbose=True,
).fit(v_train)
estimator.fit(d.validation)
tstart = time()
erb, ergs = EvaluationReport("base"), EvaluationReport("gs")
protocol = APP(

View File

@ -1,15 +1,13 @@
import math
from abc import abstractmethod
from copy import deepcopy
from typing import List
import numpy as np
import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.aggregative import CC, SLD, BaseQuantifier
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP
from quapy.method.aggregative import BaseQuantifier
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from quacc.data import ExtendedCollection
@ -20,9 +18,7 @@ class BaseAccuracyEstimator(BaseQuantifier):
classifier: BaseEstimator,
quantifier: BaseQuantifier,
):
self.fit_score = None
self.__check_classifier(classifier)
self.classifier = classifier
self.quantifier = quantifier
def __check_classifier(self, classifier):
@ -30,21 +26,7 @@ class BaseAccuracyEstimator(BaseQuantifier):
raise ValueError(
f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities."
)
def _gs_params(self, t_val: LabelledCollection):
return {
"param_grid": {
"classifier__C": np.logspace(-3, 3, 7),
"classifier__class_weight": [None, "balanced"],
"recalib": [None, "bcts"],
},
"protocol": UPP(t_val, repeats=1000),
"error": qp.error.mae,
"refit": False,
"timeout": -1,
"n_jobs": None,
"verbose": True,
}
self.classifier = classifier
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
if not pred_proba:
@ -67,6 +49,7 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
quantifier: BaseQuantifier,
):
super().__init__(classifier, quantifier)
self.e_train = None
def fit(self, train: LabelledCollection):
pred_probs = self.classifier.predict_proba(train.X)
@ -95,84 +78,52 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
super().__init__()
self.c_model = c_model
self._q_model_name = q_model.upper()
self.q_models = []
self.gs = gs
self.recalib = recalib
self.e_train = None
def __init__(self, classifier: BaseEstimator, quantifier: BaseAccuracyEstimator):
super().__init__(classifier, quantifier)
self.quantifiers = []
self.e_trains = []
def fit(self, train: LabelledCollection | ExtendedCollection):
# check if model is fit
# self.model.fit(*train.Xy)
if isinstance(train, LabelledCollection):
pred_prob_train = cross_val_predict(
self.c_model, *train.Xy, method="predict_proba"
)
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
elif isinstance(train, ExtendedCollection):
self.e_train = train
pred_probs = self.classifier.predict_proba(train.X)
self.e_train = ExtendedCollection.extend_collection(train, pred_probs)
self.n_classes = self.e_train.n_classes
e_trains = self.e_train.split_by_pred()
self.e_trains = self.e_train.split_by_pred()
self.quantifiers = [deepcopy(self.quantifier) for _ in self.e_trains]
if self._q_model_name == "SLD":
fit_scores = []
for e_train in e_trains:
if self.gs:
t_train, t_val = e_train.split_stratified(0.6, random_state=0)
gs_params = self._gs_params(t_val)
q_model = GridSearchQ(
SLD(LogisticRegression()),
**gs_params,
)
q_model.fit(t_train)
fit_scores.append(q_model.best_score_)
self.q_models.append(q_model)
else:
q_model = SLD(LogisticRegression(), recalib=self.recalib)
q_model.fit(e_train)
self.q_models.append(q_model)
if self.gs:
self.fit_score = np.mean(fit_scores)
elif self._q_model_name == "CC":
for e_train in e_trains:
q_model = CC(LogisticRegression())
q_model.fit(e_train)
self.q_models.append(q_model)
self.quantifiers = []
for train in self.e_trains:
quant = deepcopy(self.quantifier)
quant.fit(train)
self.quantifiers.append(quant)
def estimate(self, instances, ext=False):
# TODO: test
e_inst = instances
if not ext:
pred_prob = self.c_model.predict_proba(instances)
pred_prob = self.classifier.predict_proba(instances)
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
else:
e_inst = instances
_ncl = int(math.sqrt(self.n_classes))
s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
estim_prevs = [
self._quantify_helper(inst, norm, q_model)
for (inst, norm, q_model) in zip(s_inst, norms, self.q_models)
]
estim_prevs = self._quantify_helper(s_inst, norms)
estim_prev = []
for prev_row in zip(*estim_prevs):
for prev in prev_row:
estim_prev.append(prev)
estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten()
return estim_prev
return np.asarray(estim_prev)
def _quantify_helper(
self,
s_inst: List[np.ndarray | csr_matrix],
norms: List[float],
):
estim_prevs = []
for quant, inst, norm in zip(self.quantifiers, s_inst, norms):
if inst.shape[0] > 0:
estim_prevs.append(quant.quantify(inst) * norm)
else:
estim_prevs.append(np.asarray([0.0, 0.0]))
def _quantify_helper(self, inst, norm, q_model):
if inst.shape[0] > 0:
return np.asarray(list(map(lambda p: p * norm, q_model.quantify(inst))))
else:
return np.asarray([0.0, 0.0])
return estim_prevs
BAE = BaseAccuracyEstimator
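
To illustrate the refactored binary path: one quantifier is fit per predicted-class split, each estimate is scaled by its split's norm, and the per-split prevalence vectors are interleaved into one extended prevalence. A toy sketch of that final combination step (values are made up):

import numpy as np

# per-split estimates, already multiplied by their norms (toy values)
prev_pred0 = np.array([0.30, 0.20])  # quantifier fit on instances predicted as class 0
prev_pred1 = np.array([0.10, 0.40])  # quantifier fit on instances predicted as class 1

# same combination as in estimate(): interleave column-wise, then flatten
estim_prev = np.array([row for row in zip(*[prev_pred0, prev_pred1])]).flatten()
print(estim_prev)  # -> [0.3 0.1 0.2 0.4]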

View File

@ -7,8 +7,9 @@ from quapy.data import LabelledCollection
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
import quacc as qc
import quacc.evaluation.method as evaluation
import quacc.error
from quacc.data import ExtendedCollection
from quacc.evaluation import evaluate
from quacc.method.base import BaseAccuracyEstimator
@ -138,8 +139,9 @@ class GridSearchAE(BaseAccuracyEstimator):
model = deepcopy(self.model)
# overrides default parameters with the parameters being explored at this iteration
model.set_params(**params)
# print({k: v for k, v in model.get_params().items() if k in params})
model.fit(training)
score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
score = evaluate(model, protocol=protocol, error_metric=error)
ttime = time() - tinit
self._sout(
@ -157,7 +159,6 @@ class GridSearchAE(BaseAccuracyEstimator):
except Exception as e:
self._sout(f"something went wrong for config {params}; skipping:")
self._sout(f"\tException: {e}")
# traceback(e)
score = None
return params, score, model
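
Condensed sketch of a single grid-search evaluation as suggested by the hunk above (timing, logging, and exception handling omitted; set_params routing is assumed to follow the sklearn double-underscore convention):

from copy import deepcopy

from quacc.evaluation import evaluate

def score_config(base_model, params, training, protocol, error):
    model = deepcopy(base_model)
    model.set_params(**params)  # overlay the hyperparameters explored at this iteration
    model.fit(training)
    score = evaluate(model, protocol=protocol, error_metric=error)
    return params, score, model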