diff --git a/.vscode/vscode-kanban.json b/.vscode/vscode-kanban.json
index 88a249c..7b6f95c 100644
--- a/.vscode/vscode-kanban.json
+++ b/.vscode/vscode-kanban.json
@@ -1,14 +1,5 @@
 {
   "todo": [
-    {
-      "assignedTo": {
-        "name": "Lorenzo Volpi"
-      },
-      "creation_time": "2023-10-28T14:34:46.226Z",
-      "id": "4",
-      "references": [],
-      "title": "Aggingere estimator basati su PACC (quantificatore)"
-    },
     {
       "assignedTo": {
         "name": "Lorenzo Volpi"
@@ -18,15 +9,6 @@
       "references": [],
       "title": "Creare plot avg con training prevalence sull'asse x e media rispetto a test prevalence"
     },
-    {
-      "assignedTo": {
-        "name": "Lorenzo Volpi"
-      },
-      "creation_time": "2023-10-28T14:34:23.217Z",
-      "id": "3",
-      "references": [],
-      "title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
-    },
     {
       "assignedTo": {
         "name": "Lorenzo Volpi"
@@ -38,6 +20,27 @@
     }
   ],
   "in-progress": [
+    {
+      "assignedTo": {
+        "name": "Lorenzo Volpi"
+      },
+      "creation_time": "2023-10-28T14:34:23.217Z",
+      "id": "3",
+      "references": [],
+      "title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
+    },
+    {
+      "assignedTo": {
+        "name": "Lorenzo Volpi"
+      },
+      "creation_time": "2023-10-28T14:34:46.226Z",
+      "id": "4",
+      "references": [],
+      "title": "Aggingere estimator basati su PACC (quantificatore)"
+    }
+  ],
+  "testing": [],
+  "done": [
     {
       "assignedTo": {
         "name": "Lorenzo Volpi"
@@ -47,7 +50,5 @@
       "references": [],
       "title": "Rework rappresentazione dati di report"
     }
-  ],
-  "testing": [],
-  "done": []
+  ]
 }
\ No newline at end of file
diff --git a/conf.yaml b/conf.yaml
index 8cf0b81..a75a718 100644
--- a/conf.yaml
+++ b/conf.yaml
@@ -12,8 +12,9 @@ debug_conf: &debug_conf
   plot_confs:
     debug:
       PLOT_ESTIMATORS:
-        - mul_sld_gs
+        - mul_sld
         - ref
+        - atc_mc
       PLOT_STDEV: true

 test_conf: &test_conf
@@ -21,24 +22,28 @@ test_conf: &test_conf
   METRICS:
     - acc
     - f1
-  DATASET_N_PREVS: 2
-  DATASET_PREVS:
-    - 0.5
-    - 0.1
+  DATASET_N_PREVS: 9

   confs:
-    # - DATASET_NAME: rcv1
-    #   DATASET_TARGET: CCAT
-    - DATASET_NAME: imdb
+    - DATASET_NAME: rcv1
+      DATASET_TARGET: CCAT
+    # - DATASET_NAME: imdb

   plot_confs:
-    best_vs_atc:
+    2gs_vs_atc:
+      PLOT_ESTIMATORS:
+        - bin_sld_gs
+        - bin_sld_qgs
+        - mul_sld_gs
+        - mul_sld_qgs
+        - ref
+        - atc_mc
+        - atc_ne
+    sld_vs_pacc:
       PLOT_ESTIMATORS:
         - bin_sld
-        - bin_sld_bcts
         - bin_sld_gs
         - mul_sld
-        - mul_sld_bcts
         - mul_sld_gs
         - ref
         - atc_mc
@@ -102,4 +107,4 @@ main_conf: &main_conf
     - atc_ne
     - doc_feat

-exec: *debug_conf
\ No newline at end of file
+exec: *test_conf
\ No newline at end of file
diff --git a/quacc.log b/quacc.log
index ffe98ee..df47722 100644
--- a/quacc.log
+++ b/quacc.log
@@ -1494,3 +1494,97 @@
 01/11/23 13:07:27| INFO Dataset sample 0.50 of dataset imdb_1prevs started
 01/11/23 13:07:27| ERROR Evaluation over imdb_1prevs failed. Exception: 'Invalid estimator: estimator mul_sld_gs does not exist'
 01/11/23 13:07:27| ERROR Failed while saving configuration imdb_debug of imdb_1prevs. Exception: cannot access local variable 'dr' where it is not associated with a value
+----------------------------------------------------------------------------------------------------
+03/11/23 20:54:19| INFO dataset rcv1_CCAT_9prevs
+03/11/23 20:54:24| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
+03/11/23 20:54:28| WARNING Method mul_sld_gs failed. Exception: Invalid parameter 'quantifier' for estimator EMQ(classifier=LogisticRegression()). Valid parameters are: ['classifier', 'exact_train_prev', 'recalib'].
+03/11/23 20:54:29| WARNING Method mul_sld failed. Exception: evaluation_report() got an unexpected keyword argument 'protocor'
+03/11/23 20:54:30| WARNING Method bin_sld_gs failed. Exception: Invalid parameter 'quantifier' for estimator EMQ(classifier=LogisticRegression()). Valid parameters are: ['classifier', 'exact_train_prev', 'recalib'].
+03/11/23 20:55:09| INFO ref finished [took 38.5179s]
+----------------------------------------------------------------------------------------------------
+03/11/23 21:28:36| INFO dataset rcv1_CCAT_9prevs
+03/11/23 21:28:41| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
+03/11/23 21:28:45| WARNING Method mul_sld failed. Exception: evaluation_report() got an unexpected keyword argument 'protocor'
+----------------------------------------------------------------------------------------------------
+03/11/23 21:31:03| INFO dataset rcv1_CCAT_9prevs
+03/11/23 21:31:08| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
+03/11/23 21:31:59| INFO ref finished [took 45.6616s]
+03/11/23 21:32:03| INFO atc_mc finished [took 48.4360s]
+03/11/23 21:32:07| INFO atc_ne finished [took 51.0515s]
+03/11/23 21:32:23| INFO mul_sld finished [took 72.9229s]
+03/11/23 21:34:43| INFO bin_sld finished [took 213.9538s]
+03/11/23 21:36:27| INFO mul_sld_gs finished [took 314.9357s]
+03/11/23 21:40:50| INFO bin_sld_gs finished [took 579.2530s]
+03/11/23 21:40:50| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 582.5876s]
+03/11/23 21:40:50| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started
+03/11/23 21:41:39| INFO ref finished [took 43.7409s]
+03/11/23 21:41:43| INFO atc_mc finished [took 46.4580s]
+03/11/23 21:41:44| INFO atc_ne finished [took 46.4267s]
+03/11/23 21:41:54| INFO mul_sld finished [took 61.3005s]
+03/11/23 21:44:18| INFO bin_sld finished [took 206.3680s]
+03/11/23 21:45:59| INFO mul_sld_gs finished [took 304.4726s]
+03/11/23 21:50:33| INFO bin_sld_gs finished [took 579.3455s]
+03/11/23 21:50:33| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 582.4808s]
+03/11/23 21:50:33| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started
+03/11/23 21:51:22| INFO ref finished [took 43.6853s]
+03/11/23 21:51:26| INFO atc_mc finished [took 47.1366s]
+03/11/23 21:51:30| INFO atc_ne finished [took 49.4868s]
+03/11/23 21:51:34| INFO mul_sld finished [took 59.0964s]
+03/11/23 21:53:59| INFO bin_sld finished [took 205.0248s]
+03/11/23 21:55:50| INFO mul_sld_gs finished [took 312.5630s]
+03/11/23 22:00:27| INFO bin_sld_gs finished [took 591.1460s]
+03/11/23 22:00:27| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 594.3163s]
+03/11/23 22:00:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
+03/11/23 22:01:15| INFO ref finished [took 43.3806s]
+03/11/23 22:01:19| INFO atc_mc finished [took 46.6674s]
+03/11/23 22:01:21| INFO atc_ne finished [took 47.1220s]
+03/11/23 22:01:28| INFO mul_sld finished [took 58.6799s]
+03/11/23 22:03:53| INFO bin_sld finished [took 204.7659s]
+03/11/23 22:05:39| INFO mul_sld_gs finished [took 307.8811s]
+03/11/23 22:10:32| INFO bin_sld_gs finished [took 601.9995s]
+03/11/23 22:10:32| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 604.8406s]
+03/11/23 22:10:32| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
+03/11/23 22:11:20| INFO ref finished [took 42.8256s]
+03/11/23 22:11:25| INFO atc_mc finished [took 46.9203s]
+03/11/23 22:11:28| INFO atc_ne finished [took 49.3042s]
+03/11/23 22:11:34| INFO mul_sld finished [took 60.2744s]
+03/11/23 22:13:59| INFO bin_sld finished [took 205.7078s]
+03/11/23 22:15:45| INFO mul_sld_gs finished [took 309.0888s]
+03/11/23 22:20:32| INFO bin_sld_gs finished [took 596.5102s]
+03/11/23 22:20:32| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 599.5067s]
+03/11/23 22:20:32| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
+03/11/23 22:21:20| INFO ref finished [took 43.1698s]
+03/11/23 22:21:24| INFO atc_mc finished [took 46.5768s]
+03/11/23 22:21:25| INFO atc_ne finished [took 46.3408s]
+03/11/23 22:21:34| INFO mul_sld finished [took 60.8070s]
+03/11/23 22:23:58| INFO bin_sld finished [took 205.3362s]
+03/11/23 22:25:44| INFO mul_sld_gs finished [took 308.1859s]
+03/11/23 22:30:44| INFO bin_sld_gs finished [took 609.5468s]
+03/11/23 22:30:44| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 612.5803s]
+03/11/23 22:30:44| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
+03/11/23 22:31:32| INFO ref finished [took 43.2949s]
+03/11/23 22:31:37| INFO atc_mc finished [took 46.3686s]
+03/11/23 22:31:40| INFO atc_ne finished [took 49.2242s]
+03/11/23 22:31:47| INFO mul_sld finished [took 60.9437s]
+03/11/23 22:34:11| INFO bin_sld finished [took 205.9299s]
+03/11/23 22:35:56| INFO mul_sld_gs finished [took 308.2738s]
+03/11/23 22:40:36| INFO bin_sld_gs finished [took 588.7918s]
+03/11/23 22:40:36| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 591.8830s]
+03/11/23 22:40:36| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
+03/11/23 22:41:24| INFO ref finished [took 43.3321s]
+03/11/23 22:41:29| INFO atc_mc finished [took 46.8041s]
+03/11/23 22:41:29| INFO atc_ne finished [took 46.5810s]
+03/11/23 22:41:38| INFO mul_sld finished [took 60.2962s]
+03/11/23 22:44:07| INFO bin_sld finished [took 209.6435s]
+03/11/23 22:45:44| INFO mul_sld_gs finished [took 304.4809s]
+03/11/23 22:50:39| INFO bin_sld_gs finished [took 599.5588s]
+03/11/23 22:50:39| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 602.5720s]
+03/11/23 22:50:39| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
+03/11/23 22:51:26| INFO ref finished [took 42.4313s]
+03/11/23 22:51:30| INFO atc_mc finished [took 45.5261s]
+03/11/23 22:51:34| INFO atc_ne finished [took 48.4488s]
+03/11/23 22:51:47| INFO mul_sld finished [took 66.4801s]
+03/11/23 22:54:08| INFO bin_sld finished [took 208.4272s]
+03/11/23 22:55:49| INFO mul_sld_gs finished [took 306.4505s]
+03/11/23 23:00:15| INFO bin_sld_gs finished [took 573.7761s]
+03/11/23 23:00:15| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 576.7586s]
diff --git a/quacc/evaluation/__init__.py b/quacc/evaluation/__init__.py
index e69de29..1851c4b 100644
--- a/quacc/evaluation/__init__.py
+++ b/quacc/evaluation/__init__.py
@@ -0,0 +1,34 @@
+from typing import Callable, Union
+
+import numpy as np
+from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
+
+import quacc as qc
+
+from ..method.base import BaseAccuracyEstimator
+
+
+def evaluate(
+    estimator: BaseAccuracyEstimator,
+    protocol: AbstractProtocol,
+    error_metric: Union[Callable | str],
+) -> float:
+    if isinstance(error_metric, str):
+        error_metric = qc.error.from_name(error_metric)
+
+    collator_bck_ = protocol.collator
+    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
+
+    estim_prevs, true_prevs = [], []
+    for sample in protocol():
+        e_sample = estimator.extend(sample)
+        estim_prev = estimator.estimate(e_sample.X, ext=True)
+        estim_prevs.append(estim_prev)
+        true_prevs.append(e_sample.prevalence())
+
+    protocol.collator = collator_bck_
+
+    true_prevs = np.array(true_prevs)
+    estim_prevs = np.array(estim_prevs)
+
+    return error_metric(true_prevs, estim_prevs)
diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py
index f08bd0b..d50ccab 100644
--- a/quacc/evaluation/method.py
+++ b/quacc/evaluation/method.py
@@ -1,9 +1,9 @@
+import inspect
 from functools import wraps
-from typing import Callable, Union

 import numpy as np
 from quapy.method.aggregative import SLD
-from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol
+from quapy.protocol import UPP, AbstractProtocol
 from sklearn.linear_model import LogisticRegression

 import quacc as qc
@@ -25,38 +25,12 @@ def method(func):
     return wrapper


-def evaluate(
-    estimator: BaseAccuracyEstimator,
-    protocol: AbstractProtocol,
-    error_metric: Union[Callable | str],
-) -> float:
-    if isinstance(error_metric, str):
-        error_metric = qc.error.from_name(error_metric)
-
-    collator_bck_ = protocol.collator
-    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
-
-    estim_prevs, true_prevs = [], []
-    for sample in protocol():
-        e_sample = estimator.extend(sample)
-        estim_prev = estimator.estimate(e_sample.X, ext=True)
-        estim_prevs.append(estim_prev)
-        true_prevs.append(e_sample.prevalence())
-
-    protocol.collator = collator_bck_
-
-    true_prevs = np.array(true_prevs)
-    estim_prevs = np.array(estim_prevs)
-
-    return error_metric(true_prevs, estim_prevs)
-
-
 def evaluation_report(
     estimator: BaseAccuracyEstimator,
     protocol: AbstractProtocol,
-    method: str,
 ) -> EvaluationReport:
-    report = EvaluationReport(name=method)
+    method_name = inspect.stack()[1].function
+    report = EvaluationReport(name=method_name)
     for sample in protocol():
         e_sample = estimator.extend(sample)
         estim_prev = estimator.estimate(e_sample.X, ext=True)
@@ -80,7 +54,6 @@ def bin_sld(c_model, validation, protocol) -> EvaluationReport:
     return evaluation_report(
         estimator=est,
         protocol=protocol,
-        method="bin_sld",
     )


@@ -90,8 +63,7 @@ def mul_sld(c_model, validation, protocol) -> EvaluationReport:
     est.fit(validation)
     return evaluation_report(
         estimator=est,
-        protocor=protocol,
-        method="mul_sld",
+        protocol=protocol,
     )


@@ -102,7 +74,6 @@ def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
     return evaluation_report(
         estimator=est,
         protocol=protocol,
-        method="bin_sld_bcts",
     )


@@ -113,14 +84,13 @@ def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
     return evaluation_report(
         estimator=est,
         protocol=protocol,
-        method="mul_sld_bcts",
     )


 @method
-def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
+def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
     v_train, v_val = validation.split_stratified(0.6, random_state=0)
-    model = SLD(LogisticRegression())
+    model = BQAE(c_model, SLD(LogisticRegression()))
     est = GridSearchAE(
         model=model,
         param_grid={
@@ -130,10 +100,30 @@ def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
         },
         refit=False,
         protocol=UPP(v_val, repeats=100),
-        verbose=True,
+        verbose=False,
+    ).fit(v_train)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
+    v_train, v_val = validation.split_stratified(0.6, random_state=0)
+    model = MCAE(c_model, SLD(LogisticRegression()))
+    est = GridSearchAE(
+        model=model,
+        param_grid={
+            "q__classifier__C": np.logspace(-3, 3, 7),
+            "q__classifier__class_weight": [None, "balanced"],
+            "q__recalib": [None, "bcts", "vs"],
+        },
+        refit=False,
+        protocol=UPP(v_val, repeats=100),
+        verbose=False,
     ).fit(v_train)
     return evaluation_report(
         estimator=est,
         protocol=protocol,
-        method="mul_sld_gs",
     )
diff --git a/quacc/main_test.py b/quacc/main_test.py
index 7239908..ac8a9bd 100644
--- a/quacc/main_test.py
+++ b/quacc/main_test.py
@@ -10,7 +10,7 @@ from sklearn.linear_model import LogisticRegression
 from quacc.dataset import Dataset
 from quacc.error import acc
 from quacc.evaluation.report import CompReport, EvaluationReport
-from quacc.method.base import MultiClassAccuracyEstimator
+from quacc.method.base import BinaryQuantifierAccuracyEstimator
 from quacc.method.model_selection import GridSearchAE


@@ -21,8 +21,8 @@ def test_gs():
     classifier.fit(*d.train.Xy)

     quantifier = SLD(LogisticRegression())
-    estimator = MultiClassAccuracyEstimator(classifier, quantifier)
-    estimator.fit(d.validation)
+    # estimator = MultiClassAccuracyEstimator(classifier, quantifier)
+    estimator = BinaryQuantifierAccuracyEstimator(classifier, quantifier)

     v_train, v_val = d.validation.split_stratified(0.6, random_state=0)
     gs_protocol = UPP(v_val, sample_size=1000, repeats=100)
@@ -31,13 +31,15 @@ def test_gs():
         param_grid={
             "q__classifier__C": np.logspace(-3, 3, 7),
             "q__classifier__class_weight": [None, "balanced"],
-            "q__recalib": [None, "bcts", "vs"],
+            "q__recalib": [None, "bcts", "ts"],
         },
         refit=False,
         protocol=gs_protocol,
         verbose=True,
     ).fit(v_train)

+    estimator.fit(d.validation)
+
     tstart = time()
     erb, ergs = EvaluationReport("base"), EvaluationReport("gs")
     protocol = APP(
diff --git a/quacc/method/__pycache__/base.cpython-311.pyc b/quacc/method/__pycache__/base.cpython-311.pyc
index 220398a..429ef2f 100644
Binary files a/quacc/method/__pycache__/base.cpython-311.pyc and b/quacc/method/__pycache__/base.cpython-311.pyc differ
diff --git a/quacc/method/__pycache__/model_selection.cpython-311.pyc b/quacc/method/__pycache__/model_selection.cpython-311.pyc
index 03b8910..2fdba77 100644
Binary files a/quacc/method/__pycache__/model_selection.cpython-311.pyc and b/quacc/method/__pycache__/model_selection.cpython-311.pyc differ
diff --git a/quacc/method/base.py b/quacc/method/base.py
index c36636b..8a51362 100644
--- a/quacc/method/base.py
+++ b/quacc/method/base.py
@@ -1,15 +1,13 @@
 import math
 from abc import abstractmethod
+from copy import deepcopy
+from typing import List

 import numpy as np
-import quapy as qp
 from quapy.data import LabelledCollection
-from quapy.method.aggregative import CC, SLD, BaseQuantifier
-from quapy.model_selection import GridSearchQ
-from quapy.protocol import UPP
+from quapy.method.aggregative import BaseQuantifier
+from scipy.sparse import csr_matrix
 from sklearn.base import BaseEstimator
-from sklearn.linear_model import LogisticRegression
-from sklearn.model_selection import cross_val_predict

 from quacc.data import ExtendedCollection

@@ -20,9 +18,7 @@ class BaseAccuracyEstimator(BaseQuantifier):
         classifier: BaseEstimator,
         quantifier: BaseQuantifier,
     ):
-        self.fit_score = None
         self.__check_classifier(classifier)
-        self.classifier = classifier
         self.quantifier = quantifier

     def __check_classifier(self, classifier):
@@ -30,21 +26,7 @@ class BaseAccuracyEstimator(BaseQuantifier):
             raise ValueError(
                 f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities."
             )
-
-    def _gs_params(self, t_val: LabelledCollection):
-        return {
-            "param_grid": {
-                "classifier__C": np.logspace(-3, 3, 7),
-                "classifier__class_weight": [None, "balanced"],
-                "recalib": [None, "bcts"],
-            },
-            "protocol": UPP(t_val, repeats=1000),
-            "error": qp.error.mae,
-            "refit": False,
-            "timeout": -1,
-            "n_jobs": None,
-            "verbose": True,
-        }
+        self.classifier = classifier

     def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
         if not pred_proba:
@@ -67,6 +49,7 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
         quantifier: BaseQuantifier,
     ):
         super().__init__(classifier, quantifier)
+        self.e_train = None

     def fit(self, train: LabelledCollection):
         pred_probs = self.classifier.predict_proba(train.X)
@@ -95,84 +78,52 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):


 class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
-    def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
-        super().__init__()
-        self.c_model = c_model
-        self._q_model_name = q_model.upper()
-        self.q_models = []
-        self.gs = gs
-        self.recalib = recalib
-        self.e_train = None
+    def __init__(self, classifier: BaseEstimator, quantifier: BaseAccuracyEstimator):
+        super().__init__(classifier, quantifier)
+        self.quantifiers = []
+        self.e_trains = []

     def fit(self, train: LabelledCollection | ExtendedCollection):
-        # check if model is fit
-        # self.model.fit(*train.Xy)
-        if isinstance(train, LabelledCollection):
-            pred_prob_train = cross_val_predict(
-                self.c_model, *train.Xy, method="predict_proba"
-            )
-
-            self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
-        elif isinstance(train, ExtendedCollection):
-            self.e_train = train
+        pred_probs = self.classifier.predict_proba(train.X)
+        self.e_train = ExtendedCollection.extend_collection(train, pred_probs)

         self.n_classes = self.e_train.n_classes
-        e_trains = self.e_train.split_by_pred()
+        self.e_trains = self.e_train.split_by_pred()
+        self.quantifiers = [deepcopy(self.quantifier) for _ in self.e_trains]
-        if self._q_model_name == "SLD":
-            fit_scores = []
-            for e_train in e_trains:
-                if self.gs:
-                    t_train, t_val = e_train.split_stratified(0.6, random_state=0)
-                    gs_params = self._gs_params(t_val)
-                    q_model = GridSearchQ(
-                        SLD(LogisticRegression()),
-                        **gs_params,
-                    )
-                    q_model.fit(t_train)
-                    fit_scores.append(q_model.best_score_)
-                    self.q_models.append(q_model)
-                else:
-                    q_model = SLD(LogisticRegression(), recalib=self.recalib)
-                    q_model.fit(e_train)
-                    self.q_models.append(q_model)
-
-            if self.gs:
-                self.fit_score = np.mean(fit_scores)
-
-        elif self._q_model_name == "CC":
-            for e_train in e_trains:
-                q_model = CC(LogisticRegression())
-                q_model.fit(e_train)
-                self.q_models.append(q_model)
+        self.quantifiers = []
+        for train in self.e_trains:
+            quant = deepcopy(self.quantifier)
+            quant.fit(train)
+            self.quantifiers.append(quant)

     def estimate(self, instances, ext=False):
         # TODO: test
+        e_inst = instances
         if not ext:
-            pred_prob = self.c_model.predict_proba(instances)
+            pred_prob = self.classifier.predict_proba(instances)
             e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
-        else:
-            e_inst = instances

         _ncl = int(math.sqrt(self.n_classes))
         s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)

-        estim_prevs = [
-            self._quantify_helper(inst, norm, q_model)
-            for (inst, norm, q_model) in zip(s_inst, norms, self.q_models)
-        ]
+        estim_prevs = self._quantify_helper(s_inst, norms)

-        estim_prev = []
-        for prev_row in zip(*estim_prevs):
-            for prev in prev_row:
-                estim_prev.append(prev)
+        estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten()
+        return estim_prev

-        return np.asarray(estim_prev)
+    def _quantify_helper(
+        self,
+        s_inst: List[np.ndarray | csr_matrix],
+        norms: List[float],
+    ):
+        estim_prevs = []
+        for quant, inst, norm in zip(self.quantifiers, s_inst, norms):
+            if inst.shape[0] > 0:
+                estim_prevs.append(quant.quantify(inst) * norm)
+            else:
+                estim_prevs.append(np.asarray([0.0, 0.0]))

-    def _quantify_helper(self, inst, norm, q_model):
-        if inst.shape[0] > 0:
-            return np.asarray(list(map(lambda p: p * norm, q_model.quantify(inst))))
-        else:
-            return np.asarray([0.0, 0.0])
+        return estim_prevs


 BAE = BaseAccuracyEstimator
diff --git a/quacc/method/model_selection.py b/quacc/method/model_selection.py
index a80d5d9..ba866f6 100644
--- a/quacc/method/model_selection.py
+++ b/quacc/method/model_selection.py
@@ -7,8 +7,9 @@ from quapy.data import LabelledCollection
 from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol

 import quacc as qc
-import quacc.evaluation.method as evaluation
+import quacc.error
 from quacc.data import ExtendedCollection
+from quacc.evaluation import evaluate
 from quacc.method.base import BaseAccuracyEstimator


@@ -138,8 +139,9 @@ class GridSearchAE(BaseAccuracyEstimator):
            model = deepcopy(self.model)
            # overrides default parameters with the parameters being explored at this iteration
            model.set_params(**params)
+           # print({k: v for k, v in model.get_params().items() if k in params})
            model.fit(training)
-           score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
+           score = evaluate(model, protocol=protocol, error_metric=error)

            ttime = time() - tinit
            self._sout(
@@ -157,7 +159,6 @@ class GridSearchAE(BaseAccuracyEstimator):
         except Exception as e:
             self._sout(f"something went wrong for config {params}; skipping:")
             self._sout(f"\tException: {e}")
-            # traceback(e)
             score = None

         return params, score, model