From eccd818719d7a37a962b372dbd35b684642d1f1e Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 2 Nov 2023 00:28:13 +0100 Subject: [PATCH] grid search base implementation, MCAE adapted --- .vscode/launch.json | 9 +- conf.yaml | 11 +- quacc.log | 21 ++ quacc/dataset.py | 10 +- quacc/error.py | 33 ++- quacc/evaluation/method.py | 168 ++++++++------- quacc/main_test.py | 76 +++++++ quacc/method/__pycache__/base.cpython-311.pyc | Bin 0 -> 10728 bytes .../model_selection.cpython-311.pyc | Bin 0 -> 10646 bytes quacc/{estimator.py => method/base.py} | 104 ++++----- quacc/method/model_selection.py | 204 ++++++++++++++++++ quacc/old_main.py | 138 ------------ tests/test_baseline.py | 20 -- ...test_baseline.cpython-311-pytest-7.4.2.pyc | Bin 0 -> 2756 bytes tests/test_evaluation/test_baseline.py | 12 ++ .../test_base.cpython-311-pytest-7.4.2.pyc | Bin 0 -> 5919 bytes ...del_selection.cpython-311-pytest-7.4.2.pyc | Bin 0 -> 554 bytes .../test_BQAE.cpython-311-pytest-7.4.2.pyc | Bin 0 -> 5871 bytes .../test_MCAE.cpython-311-pytest-7.4.2.pyc | Bin 0 -> 543 bytes .../test_base/test_BQAE.py} | 6 +- tests/test_method/test_base/test_MCAE.py | 2 + 21 files changed, 483 insertions(+), 331 deletions(-) create mode 100644 quacc/main_test.py create mode 100644 quacc/method/__pycache__/base.cpython-311.pyc create mode 100644 quacc/method/__pycache__/model_selection.cpython-311.pyc rename quacc/{estimator.py => method/base.py} (63%) create mode 100644 quacc/method/model_selection.py delete mode 100644 quacc/old_main.py delete mode 100644 tests/test_baseline.py create mode 100644 tests/test_evaluation/__pycache__/test_baseline.cpython-311-pytest-7.4.2.pyc create mode 100644 tests/test_evaluation/test_baseline.py create mode 100644 tests/test_method/__pycache__/test_base.cpython-311-pytest-7.4.2.pyc create mode 100644 tests/test_method/__pycache__/test_model_selection.cpython-311-pytest-7.4.2.pyc create mode 100644 tests/test_method/test_base/__pycache__/test_BQAE.cpython-311-pytest-7.4.2.pyc create mode 100644 tests/test_method/test_base/__pycache__/test_MCAE.cpython-311-pytest-7.4.2.pyc rename tests/{test_estimator.py => test_method/test_base/test_BQAE.py} (95%) create mode 100644 tests/test_method/test_base/test_MCAE.py diff --git a/.vscode/launch.json b/.vscode/launch.json index f6c8bea..429433a 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -5,22 +5,21 @@ "version": "0.2.0", "configurations": [ - { "name": "main", "type": "python", "request": "launch", "program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\quacc\\main.py", "console": "integratedTerminal", - "justMyCode": true + "justMyCode": false }, { - "name": "models", + "name": "main_test", "type": "python", "request": "launch", - "program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\baselines\\models.py", + "program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\quacc\\main_test.py", "console": "integratedTerminal", "justMyCode": true - } + }, ] } \ No newline at end of file diff --git a/conf.yaml b/conf.yaml index e4e6294..8cf0b81 100644 --- a/conf.yaml +++ b/conf.yaml @@ -5,7 +5,6 @@ debug_conf: &debug_conf DATASET_N_PREVS: 5 DATASET_PREVS: - 0.5 - - 0.1 confs: - DATASET_NAME: imdb @@ -13,17 +12,9 @@ debug_conf: &debug_conf plot_confs: debug: PLOT_ESTIMATORS: + - mul_sld_gs - ref - - atc_mc - - atc_ne PLOT_STDEV: true - debug_plus: - PLOT_ESTIMATORS: - - mul_sld_bcts - - mul_sld - - ref - - atc_mc - - atc_ne test_conf: &test_conf global: diff --git a/quacc.log b/quacc.log index 8fc15af..ffe98ee 100644 --- a/quacc.log +++ b/quacc.log @@ 
-1473,3 +1473,24 @@ 31/10/23 17:05:50| INFO mul_sld finished [took 29.3523s] 31/10/23 17:06:00| INFO mul_sld_bcts finished [took 39.8376s] 31/10/23 17:06:00| INFO Dataset sample 0.50 of dataset imdb_2prevs finished [took 41.2888s] +---------------------------------------------------------------------------------------------------- +31/10/23 20:19:37| INFO dataset imdb_1prevs +31/10/23 20:19:48| INFO Dataset sample 0.50 of dataset imdb_1prevs started +31/10/23 20:20:07| INFO ref finished [took 17.4125s] +---------------------------------------------------------------------------------------------------- +31/10/23 20:20:50| INFO dataset imdb_1prevs +31/10/23 20:21:01| INFO Dataset sample 0.50 of dataset imdb_1prevs started +31/10/23 20:21:19| INFO ref finished [took 17.0717s] +---------------------------------------------------------------------------------------------------- +31/10/23 20:22:05| INFO dataset imdb_1prevs +31/10/23 20:22:15| INFO Dataset sample 0.50 of dataset imdb_1prevs started +31/10/23 20:22:35| INFO ref finished [took 18.4752s] +---------------------------------------------------------------------------------------------------- +31/10/23 20:23:38| INFO dataset imdb_1prevs +31/10/23 20:23:48| INFO Dataset sample 0.50 of dataset imdb_1prevs started +31/10/23 20:24:08| INFO ref finished [took 18.3216s] +---------------------------------------------------------------------------------------------------- +01/11/23 13:07:19| INFO dataset imdb_1prevs +01/11/23 13:07:27| INFO Dataset sample 0.50 of dataset imdb_1prevs started +01/11/23 13:07:27| ERROR Evaluation over imdb_1prevs failed. Exception: 'Invalid estimator: estimator mul_sld_gs does not exist' +01/11/23 13:07:27| ERROR Failed while saving configuration imdb_debug of imdb_1prevs. 
Exception: cannot access local variable 'dr' where it is not associated with a value diff --git a/quacc/dataset.py b/quacc/dataset.py index c11b0c9..aff0fbb 100644 --- a/quacc/dataset.py +++ b/quacc/dataset.py @@ -71,18 +71,16 @@ class Dataset: return all_train, test - def get_raw(self, validation=True) -> DatasetSample: + def get_raw(self) -> DatasetSample: all_train, test = { "spambase": self.__spambase, "imdb": self.__imdb, "rcv1": self.__rcv1, }[self._name]() - train, val = all_train, None - if validation: - train, val = all_train.split_stratified( - train_prop=TRAIN_VAL_PROP, random_state=0 - ) + train, val = all_train.split_stratified( + train_prop=TRAIN_VAL_PROP, random_state=0 + ) return DatasetSample(train, val, test) diff --git a/quacc/error.py b/quacc/error.py index 6ed7dd4..4393d72 100644 --- a/quacc/error.py +++ b/quacc/error.py @@ -1,13 +1,10 @@ -import quapy as qp +import numpy as np def from_name(err_name): - if err_name == "f1e": - return f1e - elif err_name == "f1": - return f1 - else: - return qp.error.from_name(err_name) + assert err_name in ERROR_NAMES, f"unknown error {err_name}" + callable_error = globals()[err_name] + return callable_error # def f1(prev): @@ -36,5 +33,23 @@ def f1e(prev): return 1 - f1(prev) -def acc(prev): - return (prev[0] + prev[3]) / sum(prev) +def acc(prev: np.ndarray) -> float: + return (prev[0] + prev[3]) / np.sum(prev) + + +def accd(true_prevs: np.ndarray, estim_prevs: np.ndarray) -> np.ndarray: + vacc = np.vectorize(acc, signature="(m)->()") + a_tp = vacc(true_prevs) + a_ep = vacc(estim_prevs) + return np.abs(a_tp - a_ep) + + +def maccd(true_prevs: np.ndarray, estim_prevs: np.ndarray) -> float: + return accd(true_prevs, estim_prevs).mean() + + +ACCURACY_ERROR = {maccd} +ACCURACY_ERROR_SINGLE = {accd} +ACCURACY_ERROR_NAMES = {func.__name__ for func in ACCURACY_ERROR} +ACCURACY_ERROR_SINGLE_NAMES = {func.__name__ for func in ACCURACY_ERROR_SINGLE} +ERROR_NAMES = ACCURACY_ERROR_NAMES | ACCURACY_ERROR_SINGLE_NAMES diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py index b15990a..f08bd0b 100644 --- a/quacc/evaluation/method.py +++ b/quacc/evaluation/method.py @@ -1,19 +1,16 @@ from functools import wraps +from typing import Callable, Union import numpy as np -import sklearn.metrics as metrics -from quapy.data import LabelledCollection -from quapy.protocol import AbstractStochasticSeededProtocol -from sklearn.base import BaseEstimator +from quapy.method.aggregative import SLD +from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol +from sklearn.linear_model import LogisticRegression -import quacc.error as error +import quacc as qc from quacc.evaluation.report import EvaluationReport +from quacc.method.model_selection import GridSearchAE -from ..estimator import ( - AccuracyEstimator, - BinaryQuantifierAccuracyEstimator, - MulticlassAccuracyEstimator, -) +from ..method.base import BQAE, MCAE, BaseAccuracyEstimator _methods = {} @@ -28,108 +25,115 @@ def method(func): return wrapper -def estimate( - estimator: AccuracyEstimator, - protocol: AbstractStochasticSeededProtocol, -): - base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], [] - for sample in protocol(): - e_sample, pred_proba = estimator.extend(sample) - estim_prev = estimator.estimate(e_sample.X, ext=True) - base_prevs.append(sample.prevalence()) - true_prevs.append(e_sample.prevalence()) - estim_prevs.append(estim_prev) - pred_probas.append(pred_proba) - labels.append(sample.y) +def evaluate( + estimator: 
BaseAccuracyEstimator, + protocol: AbstractProtocol, + error_metric: Union[Callable, str], +) -> float: + if isinstance(error_metric, str): + error_metric = qc.error.from_name(error_metric) - return base_prevs, true_prevs, estim_prevs, pred_probas, labels + collator_bck_ = protocol.collator + protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") + + estim_prevs, true_prevs = [], [] + for sample in protocol(): + e_sample = estimator.extend(sample) + estim_prev = estimator.estimate(e_sample.X, ext=True) + estim_prevs.append(estim_prev) + true_prevs.append(e_sample.prevalence()) + + protocol.collator = collator_bck_ + + true_prevs = np.array(true_prevs) + estim_prevs = np.array(estim_prevs) + + return error_metric(true_prevs, estim_prevs) def evaluation_report( - estimator: AccuracyEstimator, - protocol: AbstractStochasticSeededProtocol, + estimator: BaseAccuracyEstimator, + protocol: AbstractProtocol, method: str, ) -> EvaluationReport: - base_prevs, true_prevs, estim_prevs, pred_probas, labels = estimate( - estimator, protocol - ) report = EvaluationReport(name=method) - - for base_prev, true_prev, estim_prev, pred_proba, label in zip( - base_prevs, true_prevs, estim_prevs, pred_probas, labels - ): - pred = np.argmax(pred_proba, axis=-1) - acc_score = error.acc(estim_prev) - f1_score = error.f1(estim_prev) + for sample in protocol(): + e_sample = estimator.extend(sample) + estim_prev = estimator.estimate(e_sample.X, ext=True) + acc_score = qc.error.acc(estim_prev) + f1_score = qc.error.f1(estim_prev) report.append_row( - base_prev, + sample.prevalence(), acc_score=acc_score, - acc=abs(metrics.accuracy_score(label, pred) - acc_score), + acc=abs(qc.error.acc(e_sample.prevalence()) - acc_score), f1_score=f1_score, - f1=abs(error.f1(true_prev) - f1_score), + f1=abs(qc.error.f1(e_sample.prevalence()) - f1_score), ) - report.fit_score = estimator.fit_score - return report -def evaluate( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, - method: str, - q_model: str, - **kwargs, -): - estimator: AccuracyEstimator = { - "bin": BinaryQuantifierAccuracyEstimator, - "mul": MulticlassAccuracyEstimator, - }[method](c_model, q_model=q_model.upper(), **kwargs) - estimator.fit(validation) - _method = f"{method}_{q_model}" - if "recalib" in kwargs: - _method += f"_{kwargs['recalib']}" - if ("gs", True) in kwargs.items(): - _method += "_gs" - return evaluation_report(estimator, protocol, _method) - - @method def bin_sld(c_model, validation, protocol) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin", "sld") + est = BQAE(c_model, SLD(LogisticRegression())) + est.fit(validation) + return evaluation_report( + estimator=est, + protocol=protocol, + method="bin_sld", + ) @method def mul_sld(c_model, validation, protocol) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "mul", "sld") + est = MCAE(c_model, SLD(LogisticRegression())) + est.fit(validation) + return evaluation_report( + estimator=est, + protocol=protocol, + method="mul_sld", + ) @method def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin", "sld", recalib="bcts") + est = BQAE(c_model, SLD(LogisticRegression(), recalib="bcts")) + est.fit(validation) + return evaluation_report( + estimator=est, + protocol=protocol, + method="bin_sld_bcts", + ) @method def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport: - return evaluate(c_model,
validation, protocol, "mul", "sld", recalib="bcts") - - -@method -def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin", "sld", gs=True) + est = MCAE(c_model, SLD(LogisticRegression(), recalib="bcts")) + est.fit(validation) + return evaluation_report( + estimator=est, + protocol=protocol, + method="mul_sld_bcts", + ) @method def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "mul", "sld", gs=True) - - -@method -def bin_cc(c_model, validation, protocol) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin", "cc") - - -@method -def mul_cc(c_model, validation, protocol) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "mul", "cc") + v_train, v_val = validation.split_stratified(0.6, random_state=0) + model = SLD(LogisticRegression()) + est = GridSearchAE( + model=model, + param_grid={ + "q__classifier__C": np.logspace(-3, 3, 7), + "q__classifier__class_weight": [None, "balanced"], + "q__recalib": [None, "bcts", "vs"], + }, + refit=False, + protocol=UPP(v_val, repeats=100), + verbose=True, + ).fit(v_train) + return evaluation_report( + estimator=est, + protocol=protocol, + method="mul_sld_gs", + ) diff --git a/quacc/main_test.py b/quacc/main_test.py new file mode 100644 index 0000000..7239908 --- /dev/null +++ b/quacc/main_test.py @@ -0,0 +1,76 @@ +from copy import deepcopy +from time import time + +import numpy as np +import win11toast +from quapy.method.aggregative import SLD +from quapy.protocol import APP, UPP +from sklearn.linear_model import LogisticRegression + +from quacc.dataset import Dataset +from quacc.error import acc +from quacc.evaluation.report import CompReport, EvaluationReport +from quacc.method.base import MultiClassAccuracyEstimator +from quacc.method.model_selection import GridSearchAE + + +def test_gs(): + d = Dataset(name="rcv1", target="CCAT", n_prevalences=1).get_raw() + + classifier = LogisticRegression() + classifier.fit(*d.train.Xy) + + quantifier = SLD(LogisticRegression()) + estimator = MultiClassAccuracyEstimator(classifier, quantifier) + estimator.fit(d.validation) + + v_train, v_val = d.validation.split_stratified(0.6, random_state=0) + gs_protocol = UPP(v_val, sample_size=1000, repeats=100) + gs_estimator = GridSearchAE( + model=deepcopy(estimator), + param_grid={ + "q__classifier__C": np.logspace(-3, 3, 7), + "q__classifier__class_weight": [None, "balanced"], + "q__recalib": [None, "bcts", "vs"], + }, + refit=False, + protocol=gs_protocol, + verbose=True, + ).fit(v_train) + + tstart = time() + erb, ergs = EvaluationReport("base"), EvaluationReport("gs") + protocol = APP( + d.test, + sample_size=1000, + n_prevalences=21, + repeats=100, + return_type="labelled_collection", + ) + for sample in protocol(): + e_sample = gs_estimator.extend(sample) + estim_prev_b = estimator.estimate(e_sample.X, ext=True) + estim_prev_gs = gs_estimator.estimate(e_sample.X, ext=True) + erb.append_row( + sample.prevalence(), + acc=abs(acc(e_sample.prevalence()) - acc(estim_prev_b)), + ) + ergs.append_row( + sample.prevalence(), + acc=abs(acc(e_sample.prevalence()) - acc(estim_prev_gs)), + ) + + cr = CompReport( + [erb, ergs], + "test", + train_prev=d.train_prev, + valid_prev=d.validation_prev, + ) + + print(cr.table()) + print(f"[took {time() - tstart:.3f}s]") + win11toast.notify("Test", "completed") + + +if __name__ == "__main__": + test_gs() diff --git a/quacc/method/__pycache__/base.cpython-311.pyc 
b/quacc/method/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..220398abe6bacff0b9d6c5093f0feed9f3899f81 GIT binary patch literal 10728 (compiled .pyc payload omitted)
zJ5Bf1U8O#>JhBq~`I_N2GZc9uVIoH+F`CV8W_8~<{Qk}4KFSnZp+%ihp7BRa20ltY zgbF-F3JVZ|A1qsDLR#%DGo70KEi?OkHRTJ0?}JGAG# iWji)NUS{@c_P64mb+E93t3$12>kqW*$LBPlB>EplVI&m* literal 0 HcmV?d00001 diff --git a/quacc/estimator.py b/quacc/method/base.py similarity index 63% rename from quacc/estimator.py rename to quacc/method/base.py index 216b8a1..c36636b 100644 --- a/quacc/estimator.py +++ b/quacc/method/base.py @@ -4,7 +4,7 @@ from abc import abstractmethod import numpy as np import quapy as qp from quapy.data import LabelledCollection -from quapy.method.aggregative import CC, SLD +from quapy.method.aggregative import CC, SLD, BaseQuantifier from quapy.model_selection import GridSearchQ from quapy.protocol import UPP from sklearn.base import BaseEstimator @@ -14,9 +14,22 @@ from sklearn.model_selection import cross_val_predict from quacc.data import ExtendedCollection -class AccuracyEstimator: - def __init__(self): +class BaseAccuracyEstimator(BaseQuantifier): + def __init__( + self, + classifier: BaseEstimator, + quantifier: BaseQuantifier, + ): self.fit_score = None + self.__check_classifier(classifier) + self.classifier = classifier + self.quantifier = quantifier + + def __check_classifier(self, classifier): + if not hasattr(classifier, "predict_proba"): + raise ValueError( + f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities." + ) def _gs_params(self, t_val: LabelledCollection): return { @@ -33,85 +46,55 @@ class AccuracyEstimator: "verbose": True, } - def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection: + def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection: if not pred_proba: - pred_proba = self.c_model.predict_proba(base.X) - return ExtendedCollection.extend_collection(base, pred_proba), pred_proba + pred_proba = self.classifier.predict_proba(coll.X) + return ExtendedCollection.extend_collection(coll, pred_proba) @abstractmethod def fit(self, train: LabelledCollection | ExtendedCollection): ... @abstractmethod - def estimate(self, instances, ext=False): + def estimate(self, instances, ext=False) -> np.ndarray: ... 
-AE = AccuracyEstimator +class MultiClassAccuracyEstimator(BaseAccuracyEstimator): + def __init__( + self, + classifier: BaseEstimator, + quantifier: BaseQuantifier, + ): + super().__init__(classifier, quantifier) + def fit(self, train: LabelledCollection): + pred_probs = self.classifier.predict_proba(train.X) + self.e_train = ExtendedCollection.extend_collection(train, pred_probs) -class MulticlassAccuracyEstimator(AccuracyEstimator): - def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None): - super().__init__() - self.c_model = c_model - self._q_model_name = q_model.upper() - self.e_train = None - self.gs = gs - self.recalib = recalib + self.quantifier.fit(self.e_train) - def fit(self, train: LabelledCollection | ExtendedCollection): - # check if model is fit - # self.model.fit(*train.Xy) - if isinstance(train, LabelledCollection): - pred_prob_train = cross_val_predict( - self.c_model, *train.Xy, method="predict_proba" - ) - self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train) - else: - self.e_train = train + return self - if self._q_model_name == "SLD": - if self.gs: - t_train, t_val = self.e_train.split_stratified(0.6, random_state=0) - gs_params = self._gs_params(t_val) - self.q_model = GridSearchQ( - SLD(LogisticRegression()), - **gs_params, - ) - self.q_model.fit(t_train) - self.fit_score = self.q_model.best_score_ - else: - self.q_model = SLD(LogisticRegression(), recalib=self.recalib) - self.q_model.fit(self.e_train) - elif self._q_model_name == "CC": - self.q_model = CC(LogisticRegression()) - self.q_model.fit(self.e_train) - - def estimate(self, instances, ext=False): + def estimate(self, instances, ext=False) -> np.ndarray: + e_inst = instances if not ext: - pred_prob = self.c_model.predict_proba(instances) + pred_prob = self.classifier.predict_proba(instances) e_inst = ExtendedCollection.extend_instances(instances, pred_prob) - else: - e_inst = instances - estim_prev = self.q_model.quantify(e_inst) + estim_prev = self.quantifier.quantify(e_inst) + return self._check_prevalence_classes(estim_prev) - return self._check_prevalence_classes( - self.e_train.classes_, self.q_model, estim_prev - ) - - def _check_prevalence_classes(self, true_classes, q_model, estim_prev): - if isinstance(q_model, GridSearchQ): - estim_classes = q_model.best_model().classes_ - else: - estim_classes = q_model.classes_ + def _check_prevalence_classes(self, estim_prev) -> np.ndarray: + estim_classes = self.quantifier.classes_ + true_classes = self.e_train.classes_ for _cls in true_classes: if _cls not in estim_classes: estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0) return estim_prev -class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): +class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator): def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None): super().__init__() self.c_model = c_model @@ -190,3 +173,8 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): return np.asarray(list(map(lambda p: p * norm, q_model.quantify(inst)))) else: return np.asarray([0.0, 0.0]) + + +BAE = BaseAccuracyEstimator +MCAE = MultiClassAccuracyEstimator +BQAE = BinaryQuantifierAccuracyEstimator diff --git a/quacc/method/model_selection.py b/quacc/method/model_selection.py new file mode 100644 index 0000000..a80d5d9 --- /dev/null +++ b/quacc/method/model_selection.py @@ -0,0 +1,204 @@ +import itertools +from copy import deepcopy +from time import time +from typing import Callable, Union + +from quapy.data import 
LabelledCollection +from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol + +import quacc as qc +import quacc.evaluation.method as evaluation +from quacc.data import ExtendedCollection +from quacc.method.base import BaseAccuracyEstimator + + +class GridSearchAE(BaseAccuracyEstimator): + def __init__( + self, + model: BaseAccuracyEstimator, + param_grid: dict, + protocol: AbstractProtocol, + error: Union[Callable, str] = qc.error.maccd, + refit=True, + # timeout=-1, + # n_jobs=None, + verbose=False, + ): + self.model = model + self.param_grid = self.__normalize_params(param_grid) + self.protocol = protocol + self.refit = refit + # self.timeout = timeout + # self.n_jobs = qp._get_njobs(n_jobs) + self.verbose = verbose + self.__check_error(error) + assert isinstance(protocol, AbstractProtocol), "unknown protocol" + + def _sout(self, msg): + if self.verbose: + print(f"[{self.__class__.__name__}]: {msg}") + + def __normalize_params(self, params): + __remap = {} + for key in params.keys(): + k, delim, sub_key = key.partition("__") + if delim and k == "q": + __remap[key] = f"quantifier__{sub_key}" + + return {(__remap[k] if k in __remap else k): v for k, v in params.items()} + + def __check_error(self, error): + if error in qc.error.ACCURACY_ERROR: + self.error = error + elif isinstance(error, str): + self.error = qc.error.from_name(error) + elif hasattr(error, "__call__"): + self.error = error + else: + raise ValueError( + f"unexpected error type; must either be a callable function or a str representing\n" + f"the name of an error function in {qc.error.ACCURACY_ERROR_NAMES}" + ) + + def fit(self, training: LabelledCollection): + """Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing + the error metric. 
+ + :param training: the training set on which to optimize the hyperparameters + :return: self + """ + params_keys = list(self.param_grid.keys()) + params_values = list(self.param_grid.values()) + + protocol = self.protocol + + self.param_scores_ = {} + self.best_score_ = None + + tinit = time() + + hyper = [ + dict(zip(params_keys, val)) for val in itertools.product(*params_values) + ] + + # self._sout(f"starting model selection with {self.n_jobs =}") + self._sout("starting model selection") + + scores = [self.__params_eval(params, training) for params in hyper] + + for params, score, model in scores: + if score is not None: + if self.best_score_ is None or score < self.best_score_: + self.best_score_ = score + self.best_params_ = params + self.best_model_ = model + self.param_scores_[str(params)] = score + else: + self.param_scores_[str(params)] = "timeout" + + tend = time() - tinit + + if self.best_score_ is None: + raise TimeoutError("no combination of hyperparameters seem to work") + + self._sout( + f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) " + f"[took {tend:.4f}s]" + ) + + if self.refit: + if isinstance(protocol, OnLabelledCollectionProtocol): + self._sout("refitting on the whole development set") + self.best_model_.fit(training + protocol.get_labelled_collection()) + else: + raise RuntimeWarning( + f'"refit" was requested, but the protocol does not ' + f"implement the {OnLabelledCollectionProtocol.__name__} interface" + ) + + return self + + def __params_eval(self, params, training): + protocol = self.protocol + error = self.error + + # if self.timeout > 0: + + # def handler(signum, frame): + # raise TimeoutError() + + # signal.signal(signal.SIGALRM, handler) + + tinit = time() + + # if self.timeout > 0: + # signal.alarm(self.timeout) + + try: + model = deepcopy(self.model) + # overrides default parameters with the parameters being explored at this iteration + model.set_params(**params) + model.fit(training) + score = evaluation.evaluate(model, protocol=protocol, error_metric=error) + + ttime = time() - tinit + self._sout( + f"hyperparams={params}\t got score {score:.5f} [took {ttime:.4f}s]" + ) + + # if self.timeout > 0: + # signal.alarm(0) + # except TimeoutError: + # self._sout(f"timeout ({self.timeout}s) reached for config {params}") + # score = None + except ValueError as e: + self._sout(f"the combination of hyperparameters {params} is invalid") + raise e + except Exception as e: + self._sout(f"something went wrong for config {params}; skipping:") + self._sout(f"\tException: {e}") + # traceback(e) + score = None + + return params, score, model + + def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection: + assert hasattr(self, "best_model_"), "quantify called before fit" + return self.best_model().extend(coll, pred_proba=pred_proba) + + def estimate(self, instances, ext=False): + """Estimate class prevalence values using the best model found after calling the :meth:`fit` method. + + :param instances: sample contanining the instances + :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found + by the model selection process. + """ + + assert hasattr(self, "best_model_"), "quantify called before fit" + return self.best_model().estimate(instances, ext=ext) + + def set_params(self, **parameters): + """Sets the hyper-parameters to explore. 
+ + :param parameters: a dictionary with keys the parameter names and values the list of values to explore + """ + self.param_grid = parameters + + def get_params(self, deep=True): + """Returns the dictionary of hyper-parameters to explore (`param_grid`) + + :param deep: Unused + :return: the dictionary `param_grid` + """ + return self.param_grid + + def best_model(self): + """ + Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination + of hyper-parameters that minimized the error function. + + :return: a trained quantifier + """ + if hasattr(self, "best_model_"): + return self.best_model_ + raise ValueError("best_model called before fit") diff --git a/quacc/old_main.py b/quacc/old_main.py deleted file mode 100644 index 00ea263..0000000 --- a/quacc/old_main.py +++ /dev/null @@ -1,138 +0,0 @@ -import numpy as np -import scipy as sp -import quapy as qp -from quapy.data import LabelledCollection -from quapy.method.aggregative import SLD -from quapy.protocol import APP, AbstractStochasticSeededProtocol -from sklearn.linear_model import LogisticRegression -from sklearn.model_selection import cross_val_predict - -from .data import get_dataset - -# Extended classes -# -# 0 ~ True 0 -# 1 ~ False 1 -# 2 ~ False 0 -# 3 ~ True 1 -# _____________________ -# | | | -# | True 0 | False 1 | -# |__________|__________| -# | | | -# | False 0 | True 1 | -# |__________|__________| -# -def get_ex_class(classes, true_class, pred_class): - return true_class * classes + pred_class - - -def extend_collection(coll, pred_prob): - n_classes = coll.n_classes - - # n_X = [ X | predicted probs. ] - if isinstance(coll.X, sp.csr_matrix): - pred_prob_csr = sp.csr_matrix(pred_prob) - n_x = sp.hstack([coll.X, pred_prob_csr]) - elif isinstance(coll.X, np.ndarray): - n_x = np.concatenate((coll.X, pred_prob), axis=1) - else: - raise ValueError("Unsupported matrix format") - - # n_y = (exptected y, predicted y) - n_y = [] - for i, true_class in enumerate(coll.y): - pred_class = pred_prob[i].argmax(axis=0) - n_y.append(get_ex_class(n_classes, true_class, pred_class)) - - return LabelledCollection(n_x, np.asarray(n_y), [*range(0, n_classes * n_classes)]) - - -def qf1e_binary(prev): - recall = prev[0] / (prev[0] + prev[1]) - precision = prev[0] / (prev[0] + prev[2]) - - return 1 - 2 * (precision * recall) / (precision + recall) - - -def compute_errors(true_prev, estim_prev, n_instances): - errors = {} - _eps = 1 / (2 * n_instances) - errors = { - "mae": qp.error.mae(true_prev, estim_prev), - "rae": qp.error.rae(true_prev, estim_prev, eps=_eps), - "mrae": qp.error.mrae(true_prev, estim_prev, eps=_eps), - "kld": qp.error.kld(true_prev, estim_prev, eps=_eps), - "nkld": qp.error.nkld(true_prev, estim_prev, eps=_eps), - "true_f1e": qf1e_binary(true_prev), - "estim_f1e": qf1e_binary(estim_prev), - } - - return errors - - -def extend_and_quantify( - model, - q_model, - train, - test: LabelledCollection | AbstractStochasticSeededProtocol, -): - model.fit(*train.Xy) - - pred_prob_train = cross_val_predict(model, *train.Xy, method="predict_proba") - _train = extend_collection(train, pred_prob_train) - - q_model.fit(_train) - - def quantify_extended(test): - pred_prob_test = model.predict_proba(test.X) - _test = extend_collection(test, pred_prob_test) - _estim_prev = q_model.quantify(_test.instances) - # check that _estim_prev has all the classes and eventually fill the missing - # ones with 0 - for _cls in _test.classes_: - if _cls not in q_model.classes_: - _estim_prev = 
np.insert(_estim_prev, _cls, [0.0], axis=0) - print(_estim_prev) - return _test.prevalence(), _estim_prev - - if isinstance(test, LabelledCollection): - _true_prev, _estim_prev = quantify_extended(test) - _errors = compute_errors(_true_prev, _estim_prev, test.X.shape[0]) - return ([test.prevalence()], [_true_prev], [_estim_prev], [_errors]) - - elif isinstance(test, AbstractStochasticSeededProtocol): - orig_prevs, true_prevs, estim_prevs, errors = [], [], [], [] - for index in test.samples_parameters(): - sample = test.sample(index) - _true_prev, _estim_prev = quantify_extended(sample) - - orig_prevs.append(sample.prevalence()) - true_prevs.append(_true_prev) - estim_prevs.append(_estim_prev) - errors.append(compute_errors(_true_prev, _estim_prev, sample.X.shape[0])) - - return orig_prevs, true_prevs, estim_prevs, errors - - - - -def test_1(dataset_name): - train, test = get_dataset(dataset_name) - - orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify( - LogisticRegression(), - SLD(LogisticRegression()), - train, - APP(test, n_prevalences=11, repeats=1), - ) - - for orig_prev, true_prev, estim_prev, _errors in zip( - orig_prevs, true_prevs, estim_prevs, errors - ): - print(f"original prevalence:\t{orig_prev}") - print(f"true prevalence:\t{true_prev}") - print(f"estimated prevalence:\t{estim_prev}") - for name, err in _errors.items(): - print(f"{name}={err:.3f}") - print() diff --git a/tests/test_baseline.py b/tests/test_baseline.py deleted file mode 100644 index c7a8027..0000000 --- a/tests/test_baseline.py +++ /dev/null @@ -1,20 +0,0 @@ - -from sklearn.linear_model import LogisticRegression -from quacc.evaluation.baseline import kfcv, trust_score -from quacc.dataset import get_spambase - - -class TestBaseline: - - def test_kfcv(self): - train, validation, _ = get_spambase() - c_model = LogisticRegression() - c_model.fit(train.X, train.y) - assert "f1_score" in kfcv(c_model, validation) - - def test_trust_score(self): - train, validation, test = get_spambase() - c_model = LogisticRegression() - c_model.fit(train.X, train.y) - trustscore = trust_score(c_model, train, test) - assert len(trustscore) == len(test.y) \ No newline at end of file diff --git a/tests/test_evaluation/__pycache__/test_baseline.cpython-311-pytest-7.4.2.pyc b/tests/test_evaluation/__pycache__/test_baseline.cpython-311-pytest-7.4.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aaf325b61ccd5a0856d905c06d905e6c64445c8 GIT binary patch literal 2756 zcmbsrU2oe|^j_PEout`lQz43wDAK5E7}2tB3!PwH2MCa+O29x6Rama=bnYBG-0OC0 zaij_DVS)#y38_3F)CW|A$^$PZCr+{<+4v4>$hHLTWj1onq2U#9;kmwJ zB^KPQ0JjYH!F_<6C>)U4*p_~mvk7kv$ho;#+wj;)2-Lx0E+x>AOp2r}9`A`2;gzK1 zXeG7JaPx^#HYX^x(vd+^P2?RV98GY9A}La`7M)xqN4YiB$w&Enn7$KM`r0e;k>IqQ zi}Fou=ecfE+J@|Q(hM^erNlatx1l~eX^m_(oJeE{t&|fyTaz_JCo=v{Bzoa<35HPC zGs#+{860onV_UVgr_5*k=zJn+D2uFKn#Hgl#MW&5jIr@k#vamx*f_JBewIqs<1t#f ztd?E)DZc?Hv7_2WyA-J|jwvmCx&N**p3-+1X<1KM>Khm>w{S;>I}f+g$=L{ccVvNS zv~9MQpm`thg9S67>>F+3teKwUTMlJfvu02$AP$Hm7v{yr7q3pY*K0bRbp1(=&g!)I z^&v9(cFda5=``n@)9LA$QBQOLIn+6?(<(%9Z8L;9RP5-*4&j~2RLf`uw&Uq}CPN?& zW;i&!@1S=Og}DwUSJAL&?ezC zlkg^GB_2lX@ysA`VUT!pkT@$+6W(iQGSpehJx?a<@rTyM#ZRGrbnz;@U4J8({4DU= z?jj9ViRCOpC0q_d8bgDhh}Ae#LGvTS=Y6%k&I)na__t%@`D+8;f@;dwAT)gdC+s$Q zSb6csi+_C7t<-lb^czE_#Qt34b$wL3P|8=JbT?rA5#FMV6O_rVVf zz4239*MISQQx|r%dQYo&wfdem*Nv}U3?C%qt~Mt)LYhM0gU*p6&=FRtLh3|tQW7); z2Vly4k|Tm1XPbsgy$y1n(N+Xa9S)JgIb<%76bcR~59z0SOf`>04&d#=9}I^Bgbq@H4H3l4^(*SO#zkNx`UB>(9XegX; zGri2Tq#+5Pm2AQg-CPw71ie=!&Q0Qm(8(!_sfRrB61N$rH$2BAew80=lZY`S=0_QJ 
zU|CfdR{XD$m@pH9*YAvK(iHJP5gV(4&*w0&-qG;inE&4~(m3b;0Y37O;{X5v literal 0 HcmV?d00001 diff --git a/tests/test_evaluation/test_baseline.py b/tests/test_evaluation/test_baseline.py new file mode 100644 index 0000000..20fac98 --- /dev/null +++ b/tests/test_evaluation/test_baseline.py @@ -0,0 +1,12 @@ +from sklearn.linear_model import LogisticRegression + +from quacc.dataset import Dataset +from quacc.evaluation.baseline import kfcv + + +class TestBaseline: + def test_kfcv(self): + spambase = Dataset("spambase", n_prevalences=1).get_raw() + c_model = LogisticRegression() + c_model.fit(spambase.train.X, spambase.train.y) + assert "f1_score" in kfcv(c_model, spambase.validation) diff --git a/tests/test_method/__pycache__/test_base.cpython-311-pytest-7.4.2.pyc b/tests/test_method/__pycache__/test_base.cpython-311-pytest-7.4.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..526d081a6d2f0586496ca4bb5fd2595c25573c42 GIT binary patch literal 5919 zcmd5ATWlN0agRqH?~V^sFQc*&MQ$2bECg1REX8(gv5I3jN>Uq1Tmxog9L{$qn)31C zy<;q@Bx)2c;E#UzqiEEk1%$M49V7_++0O*+kNzl^so4VxGzd_^s6QxJE{uG2c8^E$ zR8m$EH}US>?Ck8!?Cjj`?2vy5g+v5p=G?6IO97!TNXIGO2J#>a$a_ddDx;$$Lt9UV zRXhyJc#}TH9^sOHI>IM;+6oyl8K7exJ(vk4LlkHAaAr?(4}(1DBvQTCkm|eXLFhwx z?UIbB%-bl+-6P#oQO4?hBbQ1WW?FemOJU70(zz^9_rILZ%DDJuLC%`#g|voe6{Ub> zrFa5ZGqRb(N<$%S_B1*VES4-qqsvHoGg3plvSc#y^|`kj6Xsh;;4fGmYx z3Vi^*HyKz4ymoO*shsLhd3NOBRbgiim{WY8aqC#{!aV#Ud^o#2h1PZQ9r4=E-gq+p z3hdfn->gM$AK5#L?5VHAN5NhA=o_~v#kqro8oGvV!j}Kg+4v|as$mEld#o4EXomUN z;IKTL&KjniRWxHPk2Td8x10DFgqVVEMp;ZE(#9k$e@SdC(DKhjl)XpJc|^_lsKswuS?8gQ*yN#OdU<6Epdy@(5iaeA94yYXjdAVtLUhK zZJ(nRtR>qx-91w?xE|?k|6Vo>4Vxn`jOL5+$*3Xz*uh4QQR0L{1EfJ3 zj~X!~AtIvq8}M!=e5M9-n95f#g0 zwncH;u{_zlB{YO`*bbihyYX4q0AU-Z+JE3MZp|n79O! zKyC?HNzsWX49l07P35A+nntQR!gM@WG- zw^+QhN-Ow-d{MekNb4pTZgF}}k#)lofocMeBxCDal5sIt&{YX*c`T*$+yz;evT{bV z1j&#WG(xt*v-DfRkAV{ygB$v!g&YPqO4?Fhm$R}-ekO!0uc7G+7MIB>mo$tAh+Uk0 z-#6Mg-U~$BPc0^<*?S;v1zMeh?VT6fJBRJ%G-M|1o^43{ph4pMNPGEmRyi{NBiLic z{2Mu}WtVd!XLEW!J#XX+SkdOe18D+GgF-2znHO`ajbD%rEtW4@1Jq59JEg2j-MswL z86Z72?TK|1CbymOUDyFv(ZhKL4NQMF{p$MkYt`x3*85+t_P>5zs0}}NU8wRywNT&B z=kJC_?u15G2dn!J-d#)y@wt)6yD+x8?O5St0TM1wRI`2E7+tG1pH4dRIDZO zurV(m-Xvyhx%JV}``kk^v~4~|&=>ABN+)Qs&?RAAzmua|8)-LKz&S$nkZh6@wbeXUZ3o9qd-Uj0*Rj^XZnaT~$Eq}5gn;oW zyMHdme)^r02j=$r*2Exm*7kbm_6ckISb{&x5PX>8G4I@je{G5Zcq$sOc!`E{NwP$V z+$aTI13V~6eo(TQYCPco!d`p-Z(>*bv%8aj|FLvj0Ak4#h} z6ZMcEf>ST*83z8)*E%CzM$%N8>EBF)k83?T#!C7mzE3^-$kW^s0lUnWES#wA^oaWv(VbZZ4ls&^QocG lRa$Fku;PAdXtLsd{(%w|_fzK%GvMj}2AhPo)nR~|{~x*17OVgO literal 0 HcmV?d00001 diff --git a/tests/test_method/__pycache__/test_model_selection.cpython-311-pytest-7.4.2.pyc b/tests/test_method/__pycache__/test_model_selection.cpython-311-pytest-7.4.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9d759899a3c454034ebc854d97ca5e833c6a860 GIT binary patch literal 554 zcmZuuy-ve05I!e?rd11qjg5t&1ATxBA%rTiKnDs8Q6wwGmWY(Z!A^zB)PZ;C8-Nfm zRApsiD=Jebd`?mrdRFef`?0?-_FJdZ0&T~GbAF`bCnmKy6se;rD)w)<%C7|D z@rcJf%p;KwyLRDHng$7{v}jS92sVxpw<&!b2XWbXhb5l8x(xgNt>RMomqPONSsdPp zIE#EG#xmr7&Q*jiSH{zX=MREa7r&CEYc0JQ$7vu-0WNYFB7C!zds1G( isoc}4QFg7t!X|{w|112{if93ZG&i|#Z2WBK8qQyZ_m0T` literal 0 HcmV?d00001 diff --git a/tests/test_method/test_base/__pycache__/test_BQAE.cpython-311-pytest-7.4.2.pyc b/tests/test_method/test_base/__pycache__/test_BQAE.cpython-311-pytest-7.4.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1ad6b881cc43ecd4037e2059cc05b86f9622c01 GIT binary patch literal 5871 zcmd5ATWlN0agUGV-I0{3mr+@XA~&ro76L0umSVfMS=lih#iJJH)3nO2h+2fHs 
zl~h&4O}u+IJ3BiwJ3F^KJLGS|VTnL_|Ktg+Bogu&cAVm?Bad@{yh~J~aym(Jto3BQ zDGx`ozNDYCNBCrbjR;ABwPIFE2HBWL4`sv2FvGoiS9VWw4@W%YG*NvwiR!=OA>>1N z?UIbB+?yoIKg90YC};J&kxyp~Gn4v(mZq9vWb!$n?mv;qDYSIHsN~GdTt=hQsZ^0F zsnRK6%_?S|rs@h=v!~v9Y z{Pd~Cp-j#&m0U_QMhjF^jWN54k3tkF>SokSG3cyMu=1zG`W!3&NJPC4@ffFH`*gAI zqtl-z($|hj4}Mkn8xEjE`qBB{5B>C?Yl$?Hk4ft|0Aa;Z^;u&5{r9l^H9;|MPys<0 zK?p$9m$s+;Ln3YStt8U++<&nJqP}(DW5pKh0=zD3Yh(nb7O@m>Gc$+;H0JTYAhNwn z%tk~a#6)_%xV`6&DivPk%DmhN$z^W7A-UX$myRZE%lv#Jda54xx0)mz*_DRID%xuB zw$9Ov$1-|>m9h^w8amo>m^}_JA@l6o!+s?IazSULBB?w=C^SRejkzgcH zmde{MtxB*Qyv-T6RKLob1CB131;08gZhyZ+i~;jSXAJn9`-!yL=>96yEfH~~J4@s^ zu%>mEhDx8zYrJU992T*%g!Ey*wMO=8t14S za?84VKY>(nfvW`I72p-&l`26<4xzfw6-RQxmr0x$zU!U^?_1lYoUTwi#{Iy~j8a0Q%IZ`OaC!&V*BL^Ek#)#t%4Ujr% zENa9QN|lnVfj(WTj4&A>*QEigvralRKF(bSi3$4udtvP|u;L5tS;V zwncH)u{^ngCDw)V)DjyPnZ?)dF46K9N@EDd5lkSML@&gNB3m&clDoOYv7h-zk$=Hz!2y{OS%Bn4aC zVhPSFtusFw1tRTZ7NcqQ9*A4PX6ImQ=atsZVS71snQ^;k8xlWgV0<5LEuWuB9hv@j2Z zjXc$Ii}~S;dA*RCHS$H8(q_Q}83d-mpq$mrD|yw%FDnLPFlEeNVug~`&m87BR?ex+ z;fq5TfW&y(j^ov4d<#DJ4f14`BmGlh>uY^ys(oi}iM64ZZi!W4uomw9 z+3bVx@cr=c@<4U}!F%&-;pu94dV>gJ7Xu$p{qg97lNaxwyjTnMS3@r?@7wfxUge}M zg0RWA!IQ&@TQk5>4UJSoleI&~TY67C-f-6zPrO|l0l+qh-Q`=ljMfosG6@9#6AKe- zN<690D|Bt58C!0BZ1fTT1c$cG=Lq_{+KsXank;OI9Te<1^GbJi9b$9gw`b#?H&YmuYX$kEQ$w_VV8RtND}i;PtxV;iC$ z7SvT=&{75K-8|OSq2p_jL^YClChBeL&n-c{dkO3R!}=dDLA@8o`V6f<%|%6N(?81d zTLi#Hk_$qdCIB|y=Gsu%5IyB=;ApocIB60c0jf11+s7nkryG#q#Herd11KL@cab`U5J2fPxr+4txz!BrC+0h?K;^c7)2*f&b8d07Cqs z$}18QTTz)hVLNSM=tX(=-rdFblDxKB4bbt~JEA)}e;Ba{Z%(FNO-4X~pcDdRvy-{8 zgCO&QH3SZv0&$*!xGxUCyS}Rk)(N_Tw)bUZKb}`@LUS?Q&}0PEo*OuJ-ckNF74o8Y zc+zjX%Evg5GK#TkV4N{hriM2$ekh}~>b$|!oHsj*yWx$XT!iP0)BK6;-m>~4{X8vCF1@`%t gme$#&CKDGSH2&}K&#Pe*K&ZY7zD?t2Mc3~91<4$V;Q#;t literal 0 HcmV?d00001 diff --git a/tests/test_estimator.py b/tests/test_method/test_base/test_BQAE.py similarity index 95% rename from tests/test_estimator.py rename to tests/test_method/test_base/test_BQAE.py index d13afe2..f28c71b 100644 --- a/tests/test_estimator.py +++ b/tests/test_method/test_base/test_BQAE.py @@ -1,12 +1,12 @@ -import pytest import numpy as np +import pytest import scipy.sparse as sp from sklearn.linear_model import LogisticRegression -from quacc.estimator import BinaryQuantifierAccuracyEstimator +from quacc.method.base import BinaryQuantifierAccuracyEstimator -class TestBinaryQuantifierAccuracyEstimator: +class TestBQAE: @pytest.mark.parametrize( "instances,preds0,preds1,result", [ diff --git a/tests/test_method/test_base/test_MCAE.py b/tests/test_method/test_base/test_MCAE.py new file mode 100644 index 0000000..b0784a2 --- /dev/null +++ b/tests/test_method/test_base/test_MCAE.py @@ -0,0 +1,2 @@ +class TestMCAE: + pass
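
Usage sketch: the snippet below shows how the GridSearchAE estimator added by this patch is meant to be driven, condensed from quacc/main_test.py and the new mul_sld_gs method in quacc/evaluation/method.py. It is an illustrative sketch only; the dataset choice ("imdb") and the sample_size/repeats settings are assumptions, not values fixed by the patch.

from copy import deepcopy

import numpy as np
from quapy.method.aggregative import SLD
from quapy.protocol import UPP
from sklearn.linear_model import LogisticRegression

from quacc.dataset import Dataset
from quacc.error import acc
from quacc.method.base import MCAE
from quacc.method.model_selection import GridSearchAE

# Train the classifier whose accuracy we want to estimate.
d = Dataset(name="imdb", n_prevalences=1).get_raw()
classifier = LogisticRegression().fit(*d.train.Xy)

# Base estimator: the multiclass accuracy estimator (MCAE) wrapping an SLD quantifier.
estimator = MCAE(classifier, SLD(LogisticRegression()))

# Grid search over the quantifier's hyperparameters; the "q__" prefix is remapped to
# "quantifier__" by GridSearchAE. Each configuration is scored on a UPP protocol built
# over a held-out part of the validation set.
v_train, v_val = d.validation.split_stratified(0.6, random_state=0)
gs_estimator = GridSearchAE(
    model=deepcopy(estimator),
    param_grid={
        "q__classifier__C": np.logspace(-3, 3, 7),
        "q__classifier__class_weight": [None, "balanced"],
        "q__recalib": [None, "bcts", "vs"],
    },
    refit=False,
    protocol=UPP(v_val, sample_size=1000, repeats=100),
    verbose=True,
).fit(v_train)

# The fitted object exposes the same extend/estimate interface as the base estimators:
# estimate() returns the prevalence of the extended classes, from which accuracy follows.
estim_prev = gs_estimator.estimate(d.test.X)
print(acc(estim_prev))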