From d1be2b72e8e97bca0bd013f029b630cc3af0b9d1 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Fri, 3 Nov 2023 23:28:40 +0100 Subject: [PATCH] bin adapted to grid search --- .vscode/vscode-kanban.json | 43 ++++--- conf.yaml | 29 +++-- quacc.log | 94 ++++++++++++++ quacc/evaluation/__init__.py | 34 +++++ quacc/evaluation/method.py | 68 +++++----- quacc/main_test.py | 10 +- quacc/method/__pycache__/base.cpython-311.pyc | Bin 10728 -> 8564 bytes .../model_selection.cpython-311.pyc | Bin 10646 -> 10631 bytes quacc/method/base.py | 121 ++++++------------ quacc/method/model_selection.py | 7 +- 10 files changed, 242 insertions(+), 164 deletions(-) diff --git a/.vscode/vscode-kanban.json b/.vscode/vscode-kanban.json index 88a249c..7b6f95c 100644 --- a/.vscode/vscode-kanban.json +++ b/.vscode/vscode-kanban.json @@ -1,14 +1,5 @@ { "todo": [ - { - "assignedTo": { - "name": "Lorenzo Volpi" - }, - "creation_time": "2023-10-28T14:34:46.226Z", - "id": "4", - "references": [], - "title": "Aggingere estimator basati su PACC (quantificatore)" - }, { "assignedTo": { "name": "Lorenzo Volpi" @@ -18,15 +9,6 @@ "references": [], "title": "Creare plot avg con training prevalence sull'asse x e media rispetto a test prevalence" }, - { - "assignedTo": { - "name": "Lorenzo Volpi" - }, - "creation_time": "2023-10-28T14:34:23.217Z", - "id": "3", - "references": [], - "title": "Relaizzare grid search per task specifico partedno da GridSearchQ" - }, { "assignedTo": { "name": "Lorenzo Volpi" @@ -38,6 +20,27 @@ } ], "in-progress": [ + { + "assignedTo": { + "name": "Lorenzo Volpi" + }, + "creation_time": "2023-10-28T14:34:23.217Z", + "id": "3", + "references": [], + "title": "Relaizzare grid search per task specifico partedno da GridSearchQ" + }, + { + "assignedTo": { + "name": "Lorenzo Volpi" + }, + "creation_time": "2023-10-28T14:34:46.226Z", + "id": "4", + "references": [], + "title": "Aggingere estimator basati su PACC (quantificatore)" + } + ], + "testing": [], + "done": [ { "assignedTo": { "name": "Lorenzo Volpi" @@ -47,7 +50,5 @@ "references": [], "title": "Rework rappresentazione dati di report" } - ], - "testing": [], - "done": [] + ] } \ No newline at end of file diff --git a/conf.yaml b/conf.yaml index 8cf0b81..a75a718 100644 --- a/conf.yaml +++ b/conf.yaml @@ -12,8 +12,9 @@ debug_conf: &debug_conf plot_confs: debug: PLOT_ESTIMATORS: - - mul_sld_gs + - mul_sld - ref + - atc_mc PLOT_STDEV: true test_conf: &test_conf @@ -21,24 +22,28 @@ test_conf: &test_conf METRICS: - acc - f1 - DATASET_N_PREVS: 2 - DATASET_PREVS: - - 0.5 - - 0.1 + DATASET_N_PREVS: 9 confs: - # - DATASET_NAME: rcv1 - # DATASET_TARGET: CCAT - - DATASET_NAME: imdb + - DATASET_NAME: rcv1 + DATASET_TARGET: CCAT + # - DATASET_NAME: imdb plot_confs: - best_vs_atc: + 2gs_vs_atc: + PLOT_ESTIMATORS: + - bin_sld_gs + - bin_sld_qgs + - mul_sld_gs + - mul_sld_qgs + - ref + - atc_mc + - atc_ne + sld_vs_pacc: PLOT_ESTIMATORS: - bin_sld - - bin_sld_bcts - bin_sld_gs - mul_sld - - mul_sld_bcts - mul_sld_gs - ref - atc_mc @@ -102,4 +107,4 @@ main_conf: &main_conf - atc_ne - doc_feat -exec: *debug_conf \ No newline at end of file +exec: *test_conf \ No newline at end of file diff --git a/quacc.log b/quacc.log index ffe98ee..df47722 100644 --- a/quacc.log +++ b/quacc.log @@ -1494,3 +1494,97 @@ 01/11/23 13:07:27| INFO Dataset sample 0.50 of dataset imdb_1prevs started 01/11/23 13:07:27| ERROR Evaluation over imdb_1prevs failed. Exception: 'Invalid estimator: estimator mul_sld_gs does not exist' 01/11/23 13:07:27| ERROR Failed while saving configuration imdb_debug of imdb_1prevs. Exception: cannot access local variable 'dr' where it is not associated with a value +---------------------------------------------------------------------------------------------------- +03/11/23 20:54:19| INFO dataset rcv1_CCAT_9prevs +03/11/23 20:54:24| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +03/11/23 20:54:28| WARNING Method mul_sld_gs failed. Exception: Invalid parameter 'quantifier' for estimator EMQ(classifier=LogisticRegression()). Valid parameters are: ['classifier', 'exact_train_prev', 'recalib']. +03/11/23 20:54:29| WARNING Method mul_sld failed. Exception: evaluation_report() got an unexpected keyword argument 'protocor' +03/11/23 20:54:30| WARNING Method bin_sld_gs failed. Exception: Invalid parameter 'quantifier' for estimator EMQ(classifier=LogisticRegression()). Valid parameters are: ['classifier', 'exact_train_prev', 'recalib']. +03/11/23 20:55:09| INFO ref finished [took 38.5179s] +---------------------------------------------------------------------------------------------------- +03/11/23 21:28:36| INFO dataset rcv1_CCAT_9prevs +03/11/23 21:28:41| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +03/11/23 21:28:45| WARNING Method mul_sld failed. Exception: evaluation_report() got an unexpected keyword argument 'protocor' +---------------------------------------------------------------------------------------------------- +03/11/23 21:31:03| INFO dataset rcv1_CCAT_9prevs +03/11/23 21:31:08| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +03/11/23 21:31:59| INFO ref finished [took 45.6616s] +03/11/23 21:32:03| INFO atc_mc finished [took 48.4360s] +03/11/23 21:32:07| INFO atc_ne finished [took 51.0515s] +03/11/23 21:32:23| INFO mul_sld finished [took 72.9229s] +03/11/23 21:34:43| INFO bin_sld finished [took 213.9538s] +03/11/23 21:36:27| INFO mul_sld_gs finished [took 314.9357s] +03/11/23 21:40:50| INFO bin_sld_gs finished [took 579.2530s] +03/11/23 21:40:50| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 582.5876s] +03/11/23 21:40:50| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started +03/11/23 21:41:39| INFO ref finished [took 43.7409s] +03/11/23 21:41:43| INFO atc_mc finished [took 46.4580s] +03/11/23 21:41:44| INFO atc_ne finished [took 46.4267s] +03/11/23 21:41:54| INFO mul_sld finished [took 61.3005s] +03/11/23 21:44:18| INFO bin_sld finished [took 206.3680s] +03/11/23 21:45:59| INFO mul_sld_gs finished [took 304.4726s] +03/11/23 21:50:33| INFO bin_sld_gs finished [took 579.3455s] +03/11/23 21:50:33| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 582.4808s] +03/11/23 21:50:33| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started +03/11/23 21:51:22| INFO ref finished [took 43.6853s] +03/11/23 21:51:26| INFO atc_mc finished [took 47.1366s] +03/11/23 21:51:30| INFO atc_ne finished [took 49.4868s] +03/11/23 21:51:34| INFO mul_sld finished [took 59.0964s] +03/11/23 21:53:59| INFO bin_sld finished [took 205.0248s] +03/11/23 21:55:50| INFO mul_sld_gs finished [took 312.5630s] +03/11/23 22:00:27| INFO bin_sld_gs finished [took 591.1460s] +03/11/23 22:00:27| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 594.3163s] +03/11/23 22:00:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started +03/11/23 22:01:15| INFO ref finished [took 43.3806s] +03/11/23 22:01:19| INFO atc_mc finished [took 46.6674s] +03/11/23 22:01:21| INFO atc_ne finished [took 47.1220s] +03/11/23 22:01:28| INFO mul_sld finished [took 58.6799s] +03/11/23 22:03:53| INFO bin_sld finished [took 204.7659s] +03/11/23 22:05:39| INFO mul_sld_gs finished [took 307.8811s] +03/11/23 22:10:32| INFO bin_sld_gs finished [took 601.9995s] +03/11/23 22:10:32| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 604.8406s] +03/11/23 22:10:32| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started +03/11/23 22:11:20| INFO ref finished [took 42.8256s] +03/11/23 22:11:25| INFO atc_mc finished [took 46.9203s] +03/11/23 22:11:28| INFO atc_ne finished [took 49.3042s] +03/11/23 22:11:34| INFO mul_sld finished [took 60.2744s] +03/11/23 22:13:59| INFO bin_sld finished [took 205.7078s] +03/11/23 22:15:45| INFO mul_sld_gs finished [took 309.0888s] +03/11/23 22:20:32| INFO bin_sld_gs finished [took 596.5102s] +03/11/23 22:20:32| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 599.5067s] +03/11/23 22:20:32| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started +03/11/23 22:21:20| INFO ref finished [took 43.1698s] +03/11/23 22:21:24| INFO atc_mc finished [took 46.5768s] +03/11/23 22:21:25| INFO atc_ne finished [took 46.3408s] +03/11/23 22:21:34| INFO mul_sld finished [took 60.8070s] +03/11/23 22:23:58| INFO bin_sld finished [took 205.3362s] +03/11/23 22:25:44| INFO mul_sld_gs finished [took 308.1859s] +03/11/23 22:30:44| INFO bin_sld_gs finished [took 609.5468s] +03/11/23 22:30:44| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 612.5803s] +03/11/23 22:30:44| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started +03/11/23 22:31:32| INFO ref finished [took 43.2949s] +03/11/23 22:31:37| INFO atc_mc finished [took 46.3686s] +03/11/23 22:31:40| INFO atc_ne finished [took 49.2242s] +03/11/23 22:31:47| INFO mul_sld finished [took 60.9437s] +03/11/23 22:34:11| INFO bin_sld finished [took 205.9299s] +03/11/23 22:35:56| INFO mul_sld_gs finished [took 308.2738s] +03/11/23 22:40:36| INFO bin_sld_gs finished [took 588.7918s] +03/11/23 22:40:36| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 591.8830s] +03/11/23 22:40:36| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started +03/11/23 22:41:24| INFO ref finished [took 43.3321s] +03/11/23 22:41:29| INFO atc_mc finished [took 46.8041s] +03/11/23 22:41:29| INFO atc_ne finished [took 46.5810s] +03/11/23 22:41:38| INFO mul_sld finished [took 60.2962s] +03/11/23 22:44:07| INFO bin_sld finished [took 209.6435s] +03/11/23 22:45:44| INFO mul_sld_gs finished [took 304.4809s] +03/11/23 22:50:39| INFO bin_sld_gs finished [took 599.5588s] +03/11/23 22:50:39| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 602.5720s] +03/11/23 22:50:39| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started +03/11/23 22:51:26| INFO ref finished [took 42.4313s] +03/11/23 22:51:30| INFO atc_mc finished [took 45.5261s] +03/11/23 22:51:34| INFO atc_ne finished [took 48.4488s] +03/11/23 22:51:47| INFO mul_sld finished [took 66.4801s] +03/11/23 22:54:08| INFO bin_sld finished [took 208.4272s] +03/11/23 22:55:49| INFO mul_sld_gs finished [took 306.4505s] +03/11/23 23:00:15| INFO bin_sld_gs finished [took 573.7761s] +03/11/23 23:00:15| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 576.7586s] diff --git a/quacc/evaluation/__init__.py b/quacc/evaluation/__init__.py index e69de29..1851c4b 100644 --- a/quacc/evaluation/__init__.py +++ b/quacc/evaluation/__init__.py @@ -0,0 +1,34 @@ +from typing import Callable, Union + +import numpy as np +from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol + +import quacc as qc + +from ..method.base import BaseAccuracyEstimator + + +def evaluate( + estimator: BaseAccuracyEstimator, + protocol: AbstractProtocol, + error_metric: Union[Callable | str], +) -> float: + if isinstance(error_metric, str): + error_metric = qc.error.from_name(error_metric) + + collator_bck_ = protocol.collator + protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") + + estim_prevs, true_prevs = [], [] + for sample in protocol(): + e_sample = estimator.extend(sample) + estim_prev = estimator.estimate(e_sample.X, ext=True) + estim_prevs.append(estim_prev) + true_prevs.append(e_sample.prevalence()) + + protocol.collator = collator_bck_ + + true_prevs = np.array(true_prevs) + estim_prevs = np.array(estim_prevs) + + return error_metric(true_prevs, estim_prevs) diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py index f08bd0b..d50ccab 100644 --- a/quacc/evaluation/method.py +++ b/quacc/evaluation/method.py @@ -1,9 +1,9 @@ +import inspect from functools import wraps -from typing import Callable, Union import numpy as np from quapy.method.aggregative import SLD -from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol +from quapy.protocol import UPP, AbstractProtocol from sklearn.linear_model import LogisticRegression import quacc as qc @@ -25,38 +25,12 @@ def method(func): return wrapper -def evaluate( - estimator: BaseAccuracyEstimator, - protocol: AbstractProtocol, - error_metric: Union[Callable | str], -) -> float: - if isinstance(error_metric, str): - error_metric = qc.error.from_name(error_metric) - - collator_bck_ = protocol.collator - protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - - estim_prevs, true_prevs = [], [] - for sample in protocol(): - e_sample = estimator.extend(sample) - estim_prev = estimator.estimate(e_sample.X, ext=True) - estim_prevs.append(estim_prev) - true_prevs.append(e_sample.prevalence()) - - protocol.collator = collator_bck_ - - true_prevs = np.array(true_prevs) - estim_prevs = np.array(estim_prevs) - - return error_metric(true_prevs, estim_prevs) - - def evaluation_report( estimator: BaseAccuracyEstimator, protocol: AbstractProtocol, - method: str, ) -> EvaluationReport: - report = EvaluationReport(name=method) + method_name = inspect.stack()[1].function + report = EvaluationReport(name=method_name) for sample in protocol(): e_sample = estimator.extend(sample) estim_prev = estimator.estimate(e_sample.X, ext=True) @@ -80,7 +54,6 @@ def bin_sld(c_model, validation, protocol) -> EvaluationReport: return evaluation_report( estimator=est, protocol=protocol, - method="bin_sld", ) @@ -90,8 +63,7 @@ def mul_sld(c_model, validation, protocol) -> EvaluationReport: est.fit(validation) return evaluation_report( estimator=est, - protocor=protocol, - method="mul_sld", + protocol=protocol, ) @@ -102,7 +74,6 @@ def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport: return evaluation_report( estimator=est, protocol=protocol, - method="bin_sld_bcts", ) @@ -113,14 +84,13 @@ def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport: return evaluation_report( estimator=est, protocol=protocol, - method="mul_sld_bcts", ) @method -def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport: +def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport: v_train, v_val = validation.split_stratified(0.6, random_state=0) - model = SLD(LogisticRegression()) + model = BQAE(c_model, SLD(LogisticRegression())) est = GridSearchAE( model=model, param_grid={ @@ -130,10 +100,30 @@ def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport: }, refit=False, protocol=UPP(v_val, repeats=100), - verbose=True, + verbose=False, + ).fit(v_train) + return evaluation_report( + estimator=est, + protocol=protocol, + ) + + +@method +def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport: + v_train, v_val = validation.split_stratified(0.6, random_state=0) + model = MCAE(c_model, SLD(LogisticRegression())) + est = GridSearchAE( + model=model, + param_grid={ + "q__classifier__C": np.logspace(-3, 3, 7), + "q__classifier__class_weight": [None, "balanced"], + "q__recalib": [None, "bcts", "vs"], + }, + refit=False, + protocol=UPP(v_val, repeats=100), + verbose=False, ).fit(v_train) return evaluation_report( estimator=est, protocol=protocol, - method="mul_sld_gs", ) diff --git a/quacc/main_test.py b/quacc/main_test.py index 7239908..ac8a9bd 100644 --- a/quacc/main_test.py +++ b/quacc/main_test.py @@ -10,7 +10,7 @@ from sklearn.linear_model import LogisticRegression from quacc.dataset import Dataset from quacc.error import acc from quacc.evaluation.report import CompReport, EvaluationReport -from quacc.method.base import MultiClassAccuracyEstimator +from quacc.method.base import BinaryQuantifierAccuracyEstimator from quacc.method.model_selection import GridSearchAE @@ -21,8 +21,8 @@ def test_gs(): classifier.fit(*d.train.Xy) quantifier = SLD(LogisticRegression()) - estimator = MultiClassAccuracyEstimator(classifier, quantifier) - estimator.fit(d.validation) + # estimator = MultiClassAccuracyEstimator(classifier, quantifier) + estimator = BinaryQuantifierAccuracyEstimator(classifier, quantifier) v_train, v_val = d.validation.split_stratified(0.6, random_state=0) gs_protocol = UPP(v_val, sample_size=1000, repeats=100) @@ -31,13 +31,15 @@ def test_gs(): param_grid={ "q__classifier__C": np.logspace(-3, 3, 7), "q__classifier__class_weight": [None, "balanced"], - "q__recalib": [None, "bcts", "vs"], + "q__recalib": [None, "bcts", "ts"], }, refit=False, protocol=gs_protocol, verbose=True, ).fit(v_train) + estimator.fit(d.validation) + tstart = time() erb, ergs = EvaluationReport("base"), EvaluationReport("gs") protocol = APP( diff --git a/quacc/method/__pycache__/base.cpython-311.pyc b/quacc/method/__pycache__/base.cpython-311.pyc index 220398abe6bacff0b9d6c5093f0feed9f3899f81..429ef2f960866527a5fc200ed2675818de5776ea 100644 GIT binary patch literal 8564 zcmcgxTWlLwdOkzWkVA?TX-Tw5U2H0{6VtX7*RkTr7um5*J9cdCdL5@JD=$NFMluzO zRL+cCizNYJA)s_2v~J-=1k?)<#nx*#Eue>>&_fI4sV{vYg=s|$AV5H|DEdajDX_>> z|Njg*X}#n0hzCeL?kXv zVtBWuxmjDx#$n!`w$Ji0p5^(pFzbjpSe{EeVa=3#kf+B@4G zYiD^e?VI(-{4DQE@0;z2b#TN+t`W(-M5H#^^UO}j*YGz|tdrGwpw26IZma8Jb?s2+ z+ff&gJ7fFhu2{Dmhy^9?4vF~x8;2Z^aC&=UQc-Clsm{vk{j3Bfq9n_6$?V($`Q!1koYAMb4XsvAv-0f zEJC@sB*a{@TXM4c)NHS6v#&GjcuNx|Tl$N#G_AYYL@DZzbb%5%>+m?Hd*#F4x$ zPvYjYMDCCm+<3&MABpcenRq;TUrv4zZ%&D}!%z{s&MR_yisC|9iGs5Uhr{2WPM(># ztH@NDxS6GLCYKF=noZB8CY0t@WFzX5ywmEFqeUvPO2IL(`VfzuIw zT=&jVSxP0<_#DknCUPA=1yRdVxXIse?m#%1$Yio=xZV`*&~!4DPN^wb8IB0L^L|1} zs4CUnpC;1ta+K06)mYp1;Pk^`+6gzkW2ba;NA}|V zc0CBf2+B1ei)7Qc|8cH#q2hZ_^SxIic31mmM_<)HS&C_gN43M3%Kk~?dTFU(x$+y z0xH;4u;WIWF$D!;vssnk4xy-W5y8;*ZpiE0XB1t$zE5T@CyukOJ1)Dy?S8CwqX-RA zHw?sJy~F62>)V~(M(sfuLOBIwkyL%58sS_Qs-ga;2RB0RuZP}WRlj*q35{u?vEq$Q z-vF?F^~&1OiZ819qGd74_^Auvd{Y^V0U)0kA03dv090S!3(`=OL-^iSMX^t%)Ldxy zsX`UDe}&9kBwJ!z(av6d)grvTUP|AFZV`iJx+|4Y)I=sJE4m#xVC9%78=hX?%RMye zkHCnO5!5Iqx)!NdqwbE!XS33L8jM5t#N*(J(sfgDVQX9k?Lq=^n^0sC6Eq#f z9I}nR14OrF=5%L9N>G|upvSQo7s3|K1WY~nh{Okf<==rUn!lROAspW78mI|y2O|g! z)|{9lR^k`1%p~o0;VK8#1+?5dRL@+#Yh-qzZsO>w**L*d4jn5GU#x^Kf(8kmCV6Rx z7E3>8ZD1Dl)~tTj_KmyQkmW3`&4G*+<2sY@Hxpmay_$r~pY0lJi)4y}cbbdER50~c&~yp zgIKy~s1mAU?$8gc&y0-2*!d`s-z}2=fza6g$hqv?5QFPtaHXddtB5Bw@kCiXQSI&p zx9cgw&k*1t=wopbCU5=;Y7k%i7>*V=*MMmFMl?Q_N-d7MQ3kG^%$TjMt{qTL!>L8Y zrqiZj=xkW?0OsCDf@p@BblnM2WJQbLRxDnKwagv-Az6Sw#DSF?K)`PYLQmowfs^Zj zldB`&oUR1UYJsyw_hxrQ>pr^{{q{_ydtB=tM`>({o@L*L*t;(Fu3UL~vm!<`F;W&I zTfWX>o8eOF0T^_59sXCS0XcC6I|Rme-$Y(ko)vi2vK2!0HYQE93ILZ{lVzOu?%$?;RV>Q_s6B`u( z{x=a(OSwI7ze4UEEbw_gZ+pPe(4Odsm$+$8vi*bo890fr;cuh@w;ayn=8qr$uSOX~ zM+h4;KJhT60DeMa(?y0-WhHJH;0Y{qfHY*PLKB8mFJhV35J_7PQk4m}KQ)`23fp2e@`>$2{uW9|)9=m@BsI4r93>P~DgGam!gi!R;^j##| zKnE%w^Dz|JLjM2>Dj&tA>E1fRPr_a~9)H0Z!2)G_(a_WPuqhIJnez8QAP_Wuh&uY} zsH4wB9es#8b|-#^&;{558+2_aey4DzMt~Sdq4C6u4AS8cdRFN-Iqcc<2E^bJPTqnf z8t&y4YW15S3E5fSpCg_Sr8ddA#Kl~aD7)ph83ZaU$zE&~YCICy*}LWT<}>X1VSgX< zF@)gh~Q@B`l1(7YBP0xirS*4kosfbtBM(S(rDSA+tV8206sL>R*? zyFzxfH&O63a3F3zS$o}!7C04mI`ss#ZH>G&v_jyvmrM(2w88t%VH{FC{iv(=_QRxGT0OEgLy>jTe`QY%-Bn#vs$^Isva@ zr%a&K<=$JywKnjpk4m5XI$r4;(fUTrVy4HD_7=zz#)%?sp#w+|5$di?1CZ!mWe!er z;*$%AHzfKKY%`9;T%fVdbR_2H;6yW0-zK1?&`}&qV8+bs_Hz@ayaP44!*8_WPob&u zZ$MymzYj)*hVKn z!!x+<87xIBo}-%QXqjC#XJZ&91?#Z`S}XG;zW`Y}GudoPei-qu>OX-SYOvJHWynd8f5UL2UIJO`jogOGIr3 zLXsE%f^7Fq)>-cEAFST2;L5vp`UVsY^9_jl>Yjlwz{x3CJsgo6XDW5m{~0vVMTF;Z zn*Ed3<0<+HROQBAulEg*zMHvcVGIS&G{M>r+%T@^AFVtzA9mU}Gw-uFEqf}X>O6dg zxDUtTUs9Fc#nKML$Z@05uyoyzCr&P3eX zb}bY)7xV%S&BrrII3-n#U5t>)(pg1sYuUlji_lEL88Up#NXuU44T5`vSrQJk3Iq=K z{x3V4-}*X!dFsh1nD@^soyWD#;}zeK<{K*Vo57=6@cf$e?bS-~V=egcBX801Sg7vr z+1P(?{&lDAzU)M@QB|nFP+{xG5?+)96e#&11VIP9JL&dhufkUNFw8JB- z57y=@15s@tS|hycEF%?9e=+*_)JhlZ-IjAJ1IvY#FV}tHvM*fq_ip&#TKB(I8u`uG z-;S*wstlgi22ZcSeXQcYp!qM9*|pVZ_jf0Medf8cdKvcPgQMEuXvKd<^PefRYs=SD z7JCe2%l3TxcHQ$K4l@p#F)A2P;T_De?H+=rmMtJUeqESKXW`T5=U8S=xV zcM}Vd{1hwJfIx6>{AxD4&|5l>=fu6`u{-R(H*tG}V=c$V@KhsK(&xclos~dY@0P#o z@z0*bN=GY!mDVU zg7(T>g2FKbbKfAoic6pQ{UIKl)X1{L695w!V7*GaJrP^9PF z;S4!?B<=$3)#%Kfd(XXd=bm%V`R=*=tha`;BHSj(_EAzbxg{Yc1PW5Pt=q4M!lqrP5IId(FT%t zr2OeXG(hss)V6eEw2|bwR8u+_4U)Vo6-sZ9ZYOznsyQ8whH1)1ouhcqZHo5_-urfn zdIHO!+M=C2eT9l_`x3hxkI-5}d{UCd zctTDK@{KGHC5;#2lR_#b@FQ8E2|1a~MC_VvWJI%HxiG3VoQX@qrP+8!PF_z6A~f)v z6O;TEAuc9vT!KgYwbx&Vn-BTUN^&wCm$S&$cp*ELgy+O%VM-Jv3HznBZ3!_eNwJ&p zRBT2R_+&yx*0b-+LWZv&LZUheGzm!PK!ld3Zzh84FQgQfZoHEV*o;iI9Q@YSO=JEeo6bNc% zo{B+EPkE|H8~OTsp1u#|56wE7%G1!!X7J=~@MOM)r>;<+(BlzCb6ro$F)5K11+71} z>6l}�??wZmc$PF#x@a+kk*ZFp`kEUMrFC9o^mC-@KSOHgQc7L}}szbe_p&yWhyB zW|9+9c2-OX6S5#BCtv~+i3wfhCMIFR24`+*Tr8H%z@TEe&P~TNXw)~to1_Ck=BbLC zYV9mA3k_w?yUKN~a9u^W!o8q!FO;|!%B{Un<*(if{(}D*G@=sb+`t2>5}EUNAaj?J z&3U9?9(npc@IEy8r4HGgNeXO5o}M=Arp=na0!#60yHzSb0R;v1V8%6HRn0-&vy<^$ z_`s-o?A2HxKW>DUw)S&ait zkh6(wN@E2gb5;~!9cxZ7vqE-O)*PAGbaqnGoHqq=GAjuYhi1#nXk02gCC$VWg6P8W z+TNQH_hH7Kjthp0^gw|{MclC|DMnr{<+`^}4U=OBww?qsPkrs6cJzMS_TcjJw&jlZ{3+F zhg!?cE#>V%+bYf8fV)EJ#N{RmGp;c(4Jpkb3i7O&vDi@*AyJ$RWbRUVo79!3rXjDV zt?X!REnC;_oJ~e4S5r+vz|PvNigbk%n}99C>SnPG@*4e?=)?Oqf%vo-j@t05`r15- z)oxttx=k=FslHGhdb?cv<`U3r(HKgHfy`57e@lg;U8l<}9rwCcTZUFzhL+{eZYnLO z)Rt3)SJ(VKuuzvrAMH{6XI1~%5_i^EfMPtEu`EDbWMm$3eV{+8e%~jQp)qSo9kNy{q&T$h`Siv2kqQ zlVDGUg?q(8g?3h)n4_#D$gsGm4C3{%?o#)`3I+G|q|?rh(O`Xnmf8lYnUQOHW)teB z9Kni>X1%4B7fOS}O3N_l80)Q(r*`1D{UgEyBdl_h`eoZ^@cym2%_@C1a3T`$Zzh~s zOwZ7|{0#JPBit)c*J*%!UW9yt1K=48V9yCg+bM9Y z2vhUv1dNHDP%7aN0DsLBvq2bhz2EU6CYlhwmPy9^<4-wUvh*oZMUHu)@fUM{_JGjVw&m%MIyv;H7$f^c_!7(yc#|Uqr zYL2sL$^iyoru;g{JjcS8$CRnb2AMo-9fz?>oOQAzJ~vSHxu&-uqSLtQ{_B?3Qr+T! z^9A%NfJ4zCv>-(s`XU>H_lYPR%^|?{M3Oa^>8Uk0@g%r$y(!^HVfjJQ7a$t>?_Pk0 z8+r{mrFVdU=(n{N9A$qaKxEfzbUEC9&%YWzvJyVBoK(WYYIwNt%35f*8v4PbmdCFs zp>Z`dUKo3#>&w7OXyCzhC3Hj$9VxiW{-#y`z>0t1!L;H(s``(XxT8cNEXx)b3Nc$b z$XJ(c4Pwv14zh-g>dNhT`zZBx*Bp~)^0u3_*s>)~Vt9HAfcmHQ`(P*@qOl-pj$Y`_ zV?2Tx|2Lydv_TGlFcW`2DFKo~f2eT;GYe8oybO=JJ9U6G1W^VSvcthWv*yzUQsvgX zvN&r>RAWH4l7zZfC#E;yF?Zy<7Uz~K@)k6et^fhKp`-x37JL9#@8s3yy#;SM*o|U* z8T>5L!YGJj>14^@57$~p*S#MWbIXSx?R*q^e5urNPU$$OcAQ&u{{iq@iEGi#>?_cD z#7E$Wh_*pIk7OPEz}c*!3=tiWjxg}kG+%X7n1GN@EcOYl`w9eE!@8c1BR>*^EYjZq z0Y6~;RctJKfFKXJSCI$isI5qYgAD>ks8DcUPXbQ%ScL+jBLd@z6zQbV!L}}opXIP+ z>vIr&mT>T!BL-Z&7R2BNd|3f(qQA zp0s}SWa4Qu1%8fy|ISr=fHsme9 zpRRht(DUX7^%%^v6$R#L8FiqVto~6vQtwH@+?ukSq9Mp+%Ir7vWe8RCwj})v8n&e# zqc^G!{cPNp*2^-}ZVhqCXichV-I$ytkZ5WFjLOVUo7w<<*X3&*MH6-#z2n`gjRX~# z@8|9BlKEMr#T`==I2Cj*|Ai=E+Wb!OqdZ&m*7s=Sd4^{{by&Y(0{WYC%sBx^ARO3` z@=hy`nB(M*8k3B>z-p2)I8sg4Z?N`oola7D?qigWEAJ|z@7j>^4r?1&$xIPLnj2Ey zHI3l}k~XxVrh`V^R;|~!S$Ob#-vl#`ds>n{bz0~2JHBfjDnnHLM~v|si@e+90b?Yf zJLnQtFv;mR4Vm!lV(=iBz$U=nH%_ds(NUP>fN(SoHV-vV)L}F`g zZ{4cxYKSS;Tp0uRYCdTO;>@whTjV@OoW^Eol(m3ffu}{}=!e(C9k}|*@I+jt1c+#k zIQW4KuX!c`kLiaDG4Wlj-h+fhLzuJ>&xokQA{nwnu?7Dg8?6fUg!nwQwmyACM6gQ^cgHngj!+}4eCPaurfuD0w0ztd1|YQA$~ zwP|3bX+UY(r#9^?a8Epb)zeW7t$21Ro?YdZw!*9Bww)EK0ebuK*Zs}UNyZ5*L-}skVrF%&29$I{5tuwL~Zu`7_ zP-)+y86z#{s4FpxTOa8EC%i`ITqs zoD%3&1HGj{FG}CH8t7dK^cLS#0t0GbphT{Ui)!eHUMn`EOMUlcI(E=3T-5md>|-!y zh*M!3zW;;g2%T7s-<;)uqXr$oecRU=Zzcn0ROUf2cL#uFeaGvO5BXFMANf|wpL_1hg2 zAk;a)XLI%sm?G3{Oo-7yIKWps-v6a-p%akx8Q_NPZvX?%c*cG!c0))3s*bRlZE(NF zkQfA$$%<)W;fU2VH?=xS{>FVG1t;x^Ym_C@;NPv4eS8_4cfIdy2gemp## zvk`xYJh)GX&3`5Vr$v$;!(nm$B+a9T3rPG%Bmo3CU=wb}#O#kXhLob$e5d9>>SH_4 zu?sOK_Cx4J`cEJ*iBB-P!??OIWZ+#5A6y9^Tt2FVPpaXQn+10;ywh12`@-M&;o&<+ zm#*IXk$74$Pyc$0Lfv@0LWWnmxPd$;@7m>XW-dMi& zX!IX0Dv?*!$g5z#o83=QoL?a+c;GO7dl&+Z;kKneDTN~#x9BIN@Nzi}!3^35k3bfg za-elJ(6bT%c`Je4YG8MXTxeYT;rK7mxR6_d1ExOJ-&f-LbjxRMVqhL{r4!5dFOd0} z4TF%=kkb=S*Ontszf0PQXX{ViYtg9sh*P~)E7;&fz;>I<+qVcduKU%5)lOR1z4hq| zwfRVRJ8z$|t>=Zv7C~zkvw_t^o;mZ;Ls1)s!N+Y zk&34$`S_{a==U$|^){q_f$P2eE%uL=5;PNzYp?!Yc)ofnH35Ziddx~U;$VGnorymO zqA_S1@Zfkxd;?jkhY>$P1{|%9yHIG|-S1S_%RD^g4*u`Q`w-emlR#i}WsiU1*kXL) zQg&WecV1RHuc)0@ zl%}g{)76sq>KE8@Nh*#kO)E_iwJB2aMs(|BCrEZ|TqKb<&!&+#+-T^Kn%J)qe0&+n zJdy+w9*BgFVpvC<*Rc>u94r1D2z-$7SFt(Sc8q_u7k%WmHEHZ;2kTDY;Hg4c$yGYY z?y68$vX^1^;KAV5gnT_2VWdm^GZ>#{k54AZ0j9=gX45mbbW6*EQ-b^ua^SC+W^N7g zaXC(=LksH#`nF{-J_Wz>nu4tY{1O2ng1$#UkEw;Kj2I(}!LNw)<8TqdgVvze8NYMT zf3M(`-c7;p%QAzL@KYfY|I~s;Q7V~%Dm}tPwjf%wQAL7qbuU0;PKKRG=))!gBqy&V z3W^fa>}Os&t1%Zxke<2p(pj7%Q6yC}DXZ@4JcGijQIE@=8ZMUc{{XuOg33VGcKeQ>>Z3(8xaqU%O#cMmW zYt#l%DyX1Bg|-=laww<*R1h3Em>+>aLM0a@4lb(HvJw&pPMmr~2npVsbx1PaeEa6T zo%jDH?`Gb}bbQ?2?iOIZG?P+;eK$Hnyvqp3|#! zV=*vJOT~D#3Z%*s6V!L>nEBX+jkCf^#>Cu(zCJWDonrd5}bzM~V$|rH)#j z<~7|E=hch!5w5kUmP?xH)a&PR$lL`?-O0_yU>&9rR+J8so$Nj7GHB{>gd2xkpObF% z-|R>pJ$gEwJo;QNl}?{b)6lXf89+lky#Kvt8Bc=jW3PH{f|LE;r|tU{_(o|nZN*Z- zFvZ1EO%dr%)F?s{O^@Xz`lIWZ>{#p)`0=6tG4c%i)t>}EdIJVTdOa{0$76Ia0ymrl zO1D{O;1KdB*!_S)W*F((3-SXBaxEWB58r83 z$Wx9U;__1^O@&gVD6LYfh9UN?d|+uSNEPq8A7*?~G;FDMBx=nTN3Z(}pCTF^!gH|oLhHptNTJz)}K z_j|r(E1?l0v3x*ge};Akv2d+5DA3p|u9t;dmh?z?dr)4J2?9t*Hb|GleJw4k;bYLQ zo=8GI4qq=iW6vAa>U@lzfE}G?CnEY1udj#tk%I;)R)k_Ht?1mkx|U%c2=$-{vrk6> zOjqhcL9H33YDIBd`Q#PG;%-oS7Pow^a(KvY7*H6KsJah_y_GI_LGR4=CLD{v)aThd zy-yAP?)6`rd~@>p++A<~J#T;0+y9#!T{->9%nzxiJat!|x+hODNA#LUbcG)Z7O-2< zL+J-{=t^q+_ZO^(PLSi^x*H_eLXS^xJ^+7QAe)91nVc|~LJtm2F;=M;k#XFz63Hzta!qJ2zayZXK&XV@Qoj~*PsDBSD)W;)ah*ak2hytn5-PRkgDH53pSaw_cj zmJqqk)Ga4$QDCrV7gkHX=_*xpBVQ>% zJJ#TX%|;f0vl!giZi%(?pEBRplhARqTklIaSV|iWvnL079NfDCJ3A1Ke2EINRF?$E zQ}56f+tto1@S{<-Jn&P*XdQzWTnegMTfaN%@f|(i;b|NLIPDdD^3|oLw1q|+U&J#6 zawiQ94)C~myx2er;Th8j(5)g301DXZkI+rW$kAV&Yt0SFPx{&CP!NEj?i+Xq3JV_As}) i4e+7JmUq0C2sDM(^FUCVo1X{5P;>Lc4l8Nit^Wa)0o%9$ delta 1926 zcmZuxYitx%6rQ`YkJ(pux1DYGQMTLda}bw;5(@Gv?b1|0aLbC8#AUiW(v^LbJF^BF zwFEIGgh=rkHPI0M_&|P{8Z}W91O9-Rki|5SOeEot@u&PU#Q4K=?kvYCp zzH`njOuj$a@|oZ75nv^5@78|nywnmQS1t|Y_f!G}8xbnOsQ?j*LeW(YRa&Q7`Cclw zRpcp|2sU9{D7w!IMUN(bDhk3i_*-i#TqKi%;=OeHouwJlZX+#hUhH!QKuUw`3vnXW zf?Mm1yATo?NtpTV1_`n=_J2s2%{x*XBEUAqihfMD!7iDkQDivsdaZ8EwT{zLaZ*!h zVP+(g=Da@ki(`Vs*^oQBw8^=NkU@6bwVp)TN3Jn4#O}L>K^T>iTexB8tV-32Zi>gX zQ*398mXhJv1O{a~^g@kl6w7bg4Xs+Fdtj^kfE+yv=voUf1_nJ0 zWG=X>!mpH!ol9?p$8C`G&mB>+mi^xGjqO#5y%)(4H#^}Bvp*y0Ft$>2^$IlkhWmBl ziX|P7_J`Xa$pitU+y5i|Alliqax0pFE)2xd@?N+-=Sm(iYPI7@x(`ltKYKo=&-1GJ zXb?HPfe~kl1+wz<=$etYkl?A99t{CejZ&>@O3l4SQ>d^Vi0RIpENHVRRy-D(ypX%_ z!itw8cmx2u1^O~-ISd2pViHZqL2e=lFKCg;c-nain0l2h#P<%|^@T1@oS#^nx#R0v z_H{LUUH9a~vOMs)<6G%wy&>oB$hl=X$Nq|6^xECg2ZBY|V&^m2`*P%5X61PZ>wz2O z`1T3G_Qm>#S|iZYkb3l9@OXJ-=kA1XJ>ePii`NI^qoR02^uzu}@ToCL{7G_*`PrH^ zUy~eLUNbbj4fF6!Yr2uI)L{Hfr*;yX1IB=NaS7zKf{9q|54j(3ytrno|Pr!xMZYXf( z!mHZwk^a?pyAMDo9_;>A!U0ljG{V-WI-ESbX|^xb8T$(Dq42cuz$3)_&e<;n&cQ?D z>`dxbtlm_E7tFt=&8~bVw4TAv?@$}31GE8V`NDidN>H9G;mQ8cvy}jf4|{u4JU^Z{ zc9TLlb~=H56R#@U(bq|Cup@n&=h1{ho>Rykg&b0%Gzn+ZlSlmPWi6jKefj)}x>~NP z6*$ph5c4KC9Tl}uC{iynMv+KJoJbr<@KZr!NO)tn;SL9m;wO~P0i6(pEds@Er8Oi7 zFr7w%&B{weakzM%R;m4XL=sqfwm(b!7(PEM?KI1-4!k&z5tl(uU R|LzJ(WA%@1P_lN+;$Oo@(Zc`$ diff --git a/quacc/method/base.py b/quacc/method/base.py index c36636b..8a51362 100644 --- a/quacc/method/base.py +++ b/quacc/method/base.py @@ -1,15 +1,13 @@ import math from abc import abstractmethod +from copy import deepcopy +from typing import List import numpy as np -import quapy as qp from quapy.data import LabelledCollection -from quapy.method.aggregative import CC, SLD, BaseQuantifier -from quapy.model_selection import GridSearchQ -from quapy.protocol import UPP +from quapy.method.aggregative import BaseQuantifier +from scipy.sparse import csr_matrix from sklearn.base import BaseEstimator -from sklearn.linear_model import LogisticRegression -from sklearn.model_selection import cross_val_predict from quacc.data import ExtendedCollection @@ -20,9 +18,7 @@ class BaseAccuracyEstimator(BaseQuantifier): classifier: BaseEstimator, quantifier: BaseQuantifier, ): - self.fit_score = None self.__check_classifier(classifier) - self.classifier = classifier self.quantifier = quantifier def __check_classifier(self, classifier): @@ -30,21 +26,7 @@ class BaseAccuracyEstimator(BaseQuantifier): raise ValueError( f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities." ) - - def _gs_params(self, t_val: LabelledCollection): - return { - "param_grid": { - "classifier__C": np.logspace(-3, 3, 7), - "classifier__class_weight": [None, "balanced"], - "recalib": [None, "bcts"], - }, - "protocol": UPP(t_val, repeats=1000), - "error": qp.error.mae, - "refit": False, - "timeout": -1, - "n_jobs": None, - "verbose": True, - } + self.classifier = classifier def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection: if not pred_proba: @@ -67,6 +49,7 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator): quantifier: BaseQuantifier, ): super().__init__(classifier, quantifier) + self.e_train = None def fit(self, train: LabelledCollection): pred_probs = self.classifier.predict_proba(train.X) @@ -95,84 +78,52 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator): class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator): - def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None): - super().__init__() - self.c_model = c_model - self._q_model_name = q_model.upper() - self.q_models = [] - self.gs = gs - self.recalib = recalib - self.e_train = None + def __init__(self, classifier: BaseEstimator, quantifier: BaseAccuracyEstimator): + super().__init__(classifier, quantifier) + self.quantifiers = [] + self.e_trains = [] def fit(self, train: LabelledCollection | ExtendedCollection): - # check if model is fit - # self.model.fit(*train.Xy) - if isinstance(train, LabelledCollection): - pred_prob_train = cross_val_predict( - self.c_model, *train.Xy, method="predict_proba" - ) - - self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train) - elif isinstance(train, ExtendedCollection): - self.e_train = train + pred_probs = self.classifier.predict_proba(train.X) + self.e_train = ExtendedCollection.extend_collection(train, pred_probs) self.n_classes = self.e_train.n_classes - e_trains = self.e_train.split_by_pred() + self.e_trains = self.e_train.split_by_pred() + self.quantifiers = [deepcopy(self.quantifier) for _ in self.e_trains] - if self._q_model_name == "SLD": - fit_scores = [] - for e_train in e_trains: - if self.gs: - t_train, t_val = e_train.split_stratified(0.6, random_state=0) - gs_params = self._gs_params(t_val) - q_model = GridSearchQ( - SLD(LogisticRegression()), - **gs_params, - ) - q_model.fit(t_train) - fit_scores.append(q_model.best_score_) - self.q_models.append(q_model) - else: - q_model = SLD(LogisticRegression(), recalib=self.recalib) - q_model.fit(e_train) - self.q_models.append(q_model) - - if self.gs: - self.fit_score = np.mean(fit_scores) - - elif self._q_model_name == "CC": - for e_train in e_trains: - q_model = CC(LogisticRegression()) - q_model.fit(e_train) - self.q_models.append(q_model) + self.quantifiers = [] + for train in self.e_trains: + quant = deepcopy(self.quantifier) + quant.fit(train) + self.quantifiers.append(quant) def estimate(self, instances, ext=False): # TODO: test + e_inst = instances if not ext: - pred_prob = self.c_model.predict_proba(instances) + pred_prob = self.classifier.predict_proba(instances) e_inst = ExtendedCollection.extend_instances(instances, pred_prob) - else: - e_inst = instances _ncl = int(math.sqrt(self.n_classes)) s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst) - estim_prevs = [ - self._quantify_helper(inst, norm, q_model) - for (inst, norm, q_model) in zip(s_inst, norms, self.q_models) - ] + estim_prevs = self._quantify_helper(s_inst, norms) - estim_prev = [] - for prev_row in zip(*estim_prevs): - for prev in prev_row: - estim_prev.append(prev) + estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten() + return estim_prev - return np.asarray(estim_prev) + def _quantify_helper( + self, + s_inst: List[np.ndarray | csr_matrix], + norms: List[float], + ): + estim_prevs = [] + for quant, inst, norm in zip(self.quantifiers, s_inst, norms): + if inst.shape[0] > 0: + estim_prevs.append(quant.quantify(inst) * norm) + else: + estim_prevs.append(np.asarray([0.0, 0.0])) - def _quantify_helper(self, inst, norm, q_model): - if inst.shape[0] > 0: - return np.asarray(list(map(lambda p: p * norm, q_model.quantify(inst)))) - else: - return np.asarray([0.0, 0.0]) + return estim_prevs BAE = BaseAccuracyEstimator diff --git a/quacc/method/model_selection.py b/quacc/method/model_selection.py index a80d5d9..ba866f6 100644 --- a/quacc/method/model_selection.py +++ b/quacc/method/model_selection.py @@ -7,8 +7,9 @@ from quapy.data import LabelledCollection from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol import quacc as qc -import quacc.evaluation.method as evaluation +import quacc.error from quacc.data import ExtendedCollection +from quacc.evaluation import evaluate from quacc.method.base import BaseAccuracyEstimator @@ -138,8 +139,9 @@ class GridSearchAE(BaseAccuracyEstimator): model = deepcopy(self.model) # overrides default parameters with the parameters being explored at this iteration model.set_params(**params) + # print({k: v for k, v in model.get_params().items() if k in params}) model.fit(training) - score = evaluation.evaluate(model, protocol=protocol, error_metric=error) + score = evaluate(model, protocol=protocol, error_metric=error) ttime = time() - tinit self._sout( @@ -157,7 +159,6 @@ class GridSearchAE(BaseAccuracyEstimator): except Exception as e: self._sout(f"something went wrong for config {params}; skipping:") self._sout(f"\tException: {e}") - # traceback(e) score = None return params, score, model