From d557c6a7d339f4814edbfb178c2d1560547ab824 Mon Sep 17 00:00:00 2001
From: Lorenzo Volpi
Date: Mon, 5 Jun 2023 21:54:22 +0200
Subject: [PATCH] added aggregation to evaluation report

---
 quacc/error.py      | 14 +++++++---
 quacc/evaluation.py | 64 +++++++++++++++++++++++++++++++++++++--------
 quacc/main.py       | 32 +++++++++++++++++++----
 3 files changed, 91 insertions(+), 19 deletions(-)

diff --git a/quacc/error.py b/quacc/error.py
index 2ab688e..e8d315a 100644
--- a/quacc/error.py
+++ b/quacc/error.py
@@ -10,6 +10,14 @@ def f1e(prev):
     return 1 - f1_score(prev)
 
 def f1_score(prev):
-    recall = prev[0] / (prev[0] + prev[1])
-    precision = prev[0] / (prev[0] + prev[2])
-    return 2 * (precision * recall) / (precision + recall)
+    # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
+    if prev[0] == 0 and prev[1] == 0 and prev[2] == 0:
+        return 1.0
+    elif prev[0] == 0 and prev[1] > 0:
+        return 0.0
+    elif prev[0] == 0 and prev[2] > 0:
+        return float("nan")
+    else:
+        recall = prev[0] / (prev[0] + prev[1])
+        precision = prev[0] / (prev[0] + prev[2])
+        return 2 * (precision * recall) / (precision + recall)
diff --git a/quacc/evaluation.py b/quacc/evaluation.py
index 029d476..cbce5af 100644
--- a/quacc/evaluation.py
+++ b/quacc/evaluation.py
@@ -9,6 +9,7 @@ from .estimator import AccuracyEstimator
 import pandas as pd
 import numpy as np
 import quacc.error as error
+import statistics as stats
 
 
 def estimate(
@@ -50,24 +51,55 @@ def _report_columns(err_names):
     return pd.MultiIndex.from_tuples(cols)
 
 
+def _report_avg_groupby_distribution(lst, error_names):
+    def _bprev(s):
+        return (s[("base", "0")], s[("base", "1")])
 
-def _dict_prev(base_prev, true_prev, estim_prev):
-    prev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + list(
-        itertools.product(_prev_col_0, _prev_col_1)
-    )
+    def _normalize_prev(r, prev_name):
+        raw_prev = [v for ((k0, k1), v) in r.items() if k0 == prev_name]
+        norm_prev = [v / sum(raw_prev) for v in raw_prev]
+        for n, v in zip(itertools.product([prev_name], _prev_col_1), norm_prev):
+            r[n] = v
 
-    return {
-        k: v
-        for (k, v) in zip(
-            prev_cols, np.concatenate((base_prev, true_prev, estim_prev), axis=0)
-        )
-    }
+        return r
+
+    current_bprev = _bprev(lst[0])
+    bprev_cnt = 0
+    g_lst = [[]]
+    for s in lst:
+        if _bprev(s) == current_bprev:
+            g_lst[bprev_cnt].append(s)
+        else:
+            g_lst.append([])
+            bprev_cnt += 1
+            current_bprev = _bprev(s)
+            g_lst[bprev_cnt].append(s)
+
+    r_lst = []
+    for gs in g_lst:
+        assert len(gs) > 0
+        r = {}
+        r[("base", "0")], r[("base", "1")] = _bprev(gs[0])
+
+        for pn in itertools.product(_prev_col_0, _prev_col_1):
+            r[pn] = stats.mean(map(lambda s: s[pn], gs))
+
+        r = _normalize_prev(r, "true")
+        r = _normalize_prev(r, "estim")
+
+        for en in itertools.product(_err_col_0, error_names):
+            r[en] = stats.mean(map(lambda s: s[en], gs))
+
+        r_lst.append(r)
+
+    return r_lst
 
 
 def evaluation_report(
     estimator: AccuracyEstimator,
     protocol: AbstractStochasticSeededProtocol,
     error_metrics: Iterable[Union[str, Callable]] = "all",
+    aggregate: bool = True,
 ):
 
     base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)
@@ -89,7 +121,16 @@
 
     lst = []
     for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
-        series = _dict_prev(base_prev, true_prev, estim_prev)
+        prev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + list(
+            itertools.product(_prev_col_0, _prev_col_1)
+        )
+
+        series = {
+            k: v
+            for (k, v) in zip(
+                prev_cols, np.concatenate((base_prev, true_prev, estim_prev), axis=0)
+            )
+        }
         for error_name, error_metric in zip(error_names, error_funcs):
             if error_name == "f1e":
                 series[("errors", "f1e_true")] = error_metric(true_prev)
@@ -101,5 +142,6 @@
 
         lst.append(series)
 
+    lst = _report_avg_groupby_distribution(lst, error_cols) if aggregate else lst
     df = pd.DataFrame(lst, columns=df_cols)
     return df
diff --git a/quacc/main.py b/quacc/main.py
index 14fa7d0..16caeb7 100644
--- a/quacc/main.py
+++ b/quacc/main.py
@@ -16,19 +16,41 @@ pd.set_option("display.float_format", "{:.4f}".format)
 
 def test_2(dataset_name):
     train, test = get_dataset(dataset_name)
+
     model = LogisticRegression()
+
+    print(f"fitting model {model.__class__.__name__}...", end=" ")
     model.fit(*train.Xy)
-    estimator = AccuracyEstimator(model, SLD(LogisticRegression()))
+    print("fit")
+
+    qmodel = SLD(LogisticRegression())
+    estimator = AccuracyEstimator(model, qmodel)
+
+    print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ")
     estimator.fit(train)
-    df = eval.evaluation_report(estimator, APP(test, n_prevalences=11, repeats=100))
-    # print(df.to_string())
+    print("fit")
+
+    n_prevalences = 21
+    repeats = 1000
+    protocol = APP(test, n_prevalences=n_prevalences, repeats=repeats)
+    print(f"Tests:\n"
+          f"protocol={protocol.__class__.__name__}\n"
+          f"n_prevalences={n_prevalences}\n"
+          f"repeats={repeats}\n"
+          "executing...\n"
+    )
+    df = eval.evaluation_report(
+        estimator,
+        protocol,
+        aggregate=True,
+    )
    print(df.to_string())
 
 
 def main():
     for dataset_name in [
-        # "hp",
-        # "imdb",
+        "hp",
+        "imdb",
         "spambase",
     ]:
         print(dataset_name)
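
---

For reference, the zero-denominator convention that f1_score now follows (the GERBIL wiki linked in the code) plays out as below. This is an illustrative sketch, not part of the applied diff; it reads prev as (TP, FN, FP) rates, which is what the recall and precision formulas in the function imply:

    from quacc.error import f1_score

    print(f1_score((0.0, 0.0, 0.0)))  # 1.0: nothing to find, nothing predicted
    print(f1_score((0.0, 0.3, 0.0)))  # 0.0: positives exist, none were found
    print(f1_score((0.0, 0.0, 0.3)))  # nan: only false positives, F1 undefined
    print(f1_score((0.2, 0.1, 0.1)))  # ~0.6667: recall = precision = 2/3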
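
_report_avg_groupby_distribution groups consecutive rows that share a base prevalence, so it relies on the protocol yielding samples already ordered by base prevalence (as APP does). For comparison, an order-independent sketch of the same averaging step written against pandas directly; it assumes the report's MultiIndex columns and skips the true/estim re-normalization that the patch performs:

    import pandas as pd

    def aggregate_report(df: pd.DataFrame) -> pd.DataFrame:
        # mean of every column within each (("base", "0"), ("base", "1")) group;
        # the base prevalences end up in the resulting index
        return df.groupby([("base", "0"), ("base", "1")], sort=False).mean()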
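
Finally, a minimal usage sketch for the new aggregate flag, mirroring test_2 above. Import paths not visible in these hunks (get_dataset, SLD, APP) are assumptions:

    from quapy.method.aggregative import SLD
    from quapy.protocol import APP
    from sklearn.linear_model import LogisticRegression

    import quacc.evaluation as eval
    from quacc.dataset import get_dataset  # assumed module path
    from quacc.estimator import AccuracyEstimator

    train, test = get_dataset("spambase")

    model = LogisticRegression()
    model.fit(*train.Xy)

    estimator = AccuracyEstimator(model, SLD(LogisticRegression()))
    estimator.fit(train)

    # aggregate=True (the new default) collapses the per-sample rows into
    # one averaged row per base prevalence
    df = eval.evaluation_report(
        estimator,
        APP(test, n_prevalences=21, repeats=100),
        aggregate=True,
    )
    print(df.to_string())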