From d22fce90500afde69cbdad08631aad96f5303cf0 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 11 May 2023 21:43:59 +0200 Subject: [PATCH] first test on quantification for accuracy --- .gitignore | 3 + requirements.txt | 36 ++++++++++++ test.py | 146 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 185 insertions(+) create mode 100644 .gitignore create mode 100644 requirements.txt create mode 100644 test.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fc5bb69 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*.code-workspace +quavenv/* +*.pdf diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f490e69 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,36 @@ +abstention==0.1.3.1 +astroid==2.15.4 +contourpy==1.0.7 +cycler==0.11.0 +dill==0.3.6 +docstring-to-markdown==0.12 +fonttools==4.39.3 +joblib==1.2.0 +kiwisolver==1.4.4 +lazy-object-proxy==1.9.0 +matplotlib==3.7.1 +numpy==1.24.3 +packaging==23.1 +pandas==2.0.1 +parso==0.8.3 +Pillow==9.5.0 +platformdirs==3.5.0 +pluggy==1.0.0 +pyparsing==3.0.9 +python-dateutil==2.8.2 +pytoolconfig==1.2.5 +pytz==2023.3 +QuaPy==0.1.7 +scikit-learn==1.2.2 +scipy==1.10.1 +six==1.16.0 +snowballstemmer==2.2.0 +threadpoolctl==3.1.0 +toml==0.10.2 +tomlkit==0.11.8 +tqdm==4.65.0 +tzdata==2023.3 +ujson==5.7.0 +whatthepatch==1.0.5 +wrapt==1.15.0 +xlrd==2.0.1 diff --git a/test.py b/test.py new file mode 100644 index 0000000..5f301ac --- /dev/null +++ b/test.py @@ -0,0 +1,146 @@ +import numpy as np +import quapy as qp +import scipy.sparse as sp +from quapy.data import LabelledCollection +from quapy.protocol import APP, AbstractStochasticSeededProtocol +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import cross_val_predict + + +# Extended classes +# +# 0 ~ True 0 +# 1 ~ False 1 +# 2 ~ False 0 +# 3 ~ True 1 +# _____________________ +# | | | +# | True 0 | False 1 | +# |__________|__________| +# | | | +# | False 0 | True 1 | +# |__________|__________| +# +def get_ex_class(classes, true_class, pred_class): + return true_class * classes + pred_class + + +def extend_collection(coll, pred_prob): + n_classes = coll.n_classes + + # n_X = [ X | predicted probs. ] + if isinstance(coll.X, sp.csr_matrix): + pred_prob_csr = sp.csr_matrix(pred_prob) + n_x = sp.hstack([coll.X, pred_prob_csr]) + elif isinstance(coll.X, np.ndarray): + n_x = np.concatenate((coll.X, pred_prob), axis=1) + else: + raise ValueError("Unsupported matrix format") + + # n_y = (exptected y, predicted y) + n_y = [] + for i, true_class in enumerate(coll.y): + pred_class = pred_prob[i].argmax(axis=0) + n_y.append(get_ex_class(n_classes, true_class, pred_class)) + + return LabelledCollection(n_x, np.asarray(n_y), [*range(0, n_classes * n_classes)]) + + +def qf1e_binary(prev): + recall = prev[0] / (prev[0] + prev[1]) + precision = prev[0] / (prev[0] + prev[2]) + + return 1 - 2 * (precision * recall) / (precision + recall) + + +def compute_errors(true_prev, estim_prev, n_instances): + errors = {} + _eps = 1 / (2 * n_instances) + errors = { + "mae": qp.error.mae(true_prev, estim_prev), + "rae": qp.error.rae(true_prev, estim_prev, eps=_eps), + "mrae": qp.error.mrae(true_prev, estim_prev, eps=_eps), + "kld": qp.error.kld(true_prev, estim_prev, eps=_eps), + "nkld": qp.error.nkld(true_prev, estim_prev, eps=_eps), + "true_f1e": qf1e_binary(true_prev), + "estim_f1e": qf1e_binary(estim_prev), + } + + return errors + + +def extend_and_quantify( + model, + q_model, + train, + test: LabelledCollection | AbstractStochasticSeededProtocol, +): + model.fit(*train.Xy) + + pred_prob_train = cross_val_predict(model, *train.Xy, method="predict_proba") + _train = extend_collection(train, pred_prob_train) + + q_model.fit(_train) + + def quantify_extended(test): + pred_prob_test = model.predict_proba(test.X) + _test = extend_collection(test, pred_prob_test) + return _test.prevalence(), q_model.quantify(_test.instances) + + if isinstance(test, LabelledCollection): + _orig_prev, _true_prev, _estim_prev = quantify_extended(test) + _errors = compute_errors(_true_prev, _estim_prev, test.X.shape[0]) + return ([_orig_prev], [_true_prev], [_estim_prev], [_errors]) + + elif isinstance(test, AbstractStochasticSeededProtocol): + orig_prevs, true_prevs, estim_prevs, errors = [], [], [], [] + for index in test.samples_parameters(): + sample = test.sample(index) + _true_prev, _estim_prev = quantify_extended(sample) + + orig_prevs.append(sample.prevalence()) + true_prevs.append(_true_prev) + estim_prevs.append(_estim_prev) + errors.append(compute_errors(_true_prev, _estim_prev, sample.X.shape[0])) + + return orig_prevs, true_prevs, estim_prevs, errors + + +def get_dataset(name): + datasets = { + "spambase": lambda: qp.datasets.fetch_UCIDataset( + "spambase", verbose=False + ).train_test, + "hp": lambda: qp.datasets.fetch_reviews("hp", tfidf=True).train_test, + "imdb": lambda: qp.datasets.fetch_reviews("imdb", tfidf=True).train_test, + } + + try: + return datasets[name]() + except KeyError: + raise KeyError(f"{name} is not available as a dataset") + + +def test_1(): + train, test = get_dataset("spambase") + + orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify( + LogisticRegression(), + qp.method.aggregative.SLD(LogisticRegression()), + train, + APP(test, sample_size=100, n_prevalences=11, repeats=1), + ) + + for orig_prev, true_prev, estim_prev, _errors in zip( + orig_prevs, true_prevs, estim_prevs, errors + ): + print(f"original prevalence:\t{orig_prev}") + print(f"true prevalence:\t{true_prev}") + print(f"estimated prevalence:\t{estim_prev}") + for name, err in _errors.items(): + print(f"{name}={err:.3f}") + print() + + +if __name__ == "__main__": + test_1()