added aggregation on evaluation report

Lorenzo Volpi 2023-06-05 21:54:22 +02:00
parent 5234ce1387
commit d557c6a7d3
3 changed files with 91 additions and 19 deletions

View File

@@ -10,6 +10,14 @@ def f1e(prev):
     return 1 - f1_score(prev)


 def f1_score(prev):
-    recall = prev[0] / (prev[0] + prev[1])
-    precision = prev[0] / (prev[0] + prev[2])
-    return 2 * (precision * recall) / (precision + recall)
+    # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
+    if prev[0] == 0 and prev[1] == 0 and prev[2] == 0:
+        return 1.0
+    elif prev[0] == 0 and prev[1] > 0 and prev[2] == 0:
+        return 0.0
+    elif prev[0] == 0 and prev[1] == 0 and prev[2] > 0:
+        return float('NaN')
+    else:
+        recall = prev[0] / (prev[0] + prev[1])
+        precision = prev[0] / (prev[0] + prev[2])
+        return 2 * (precision * recall) / (precision + recall)
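The patch special-cases degenerate confusion-matrix counts following the convention linked in the comment. Below is a minimal standalone sketch of that convention, not the library function; the layout of prev is an assumption inferred from the old formulas (recall = TP / (TP + FN), precision = TP / (TP + FP), so prev[0] ~ TP, prev[1] ~ FN, prev[2] ~ FP). Note, hedged: the case TP == 0 with both FN > 0 and FP > 0 still falls through to the general formula, where precision and recall are both 0 and the denominator is zero; the sketch adds an extra guard for it.

# Standalone sketch of the F1 edge-case convention (not quacc.error.f1_score).
def f1_sketch(tp: float, fn: float, fp: float) -> float:
    if tp == 0 and fn == 0 and fp == 0:
        return 1.0           # nothing to find, nothing predicted: perfect by convention
    if tp == 0 and fn > 0 and fp == 0:
        return 0.0           # missed everything, predicted nothing
    if tp == 0 and fn == 0 and fp > 0:
        return float("nan")  # only spurious predictions: undefined here
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    if precision + recall == 0:
        return 0.0           # extra guard: tp == 0 with fn > 0 and fp > 0
    return 2 * (precision * recall) / (precision + recall)

if __name__ == "__main__":
    print(f1_sketch(0, 0, 0))  # 1.0
    print(f1_sketch(0, 3, 0))  # 0.0
    print(f1_sketch(0, 0, 3))  # nan
    print(f1_sketch(0, 2, 2))  # 0.0 (the patched branches above would hit a zero denominator here)
    print(f1_sketch(5, 1, 2))  # ~0.769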

View File

@@ -9,6 +9,7 @@ from .estimator import AccuracyEstimator
 import pandas as pd
 import numpy as np
 import quacc.error as error
+import statistics as stats


 def estimate(
@@ -50,24 +51,55 @@ def _report_columns(err_names):
     return pd.MultiIndex.from_tuples(cols)


-def _dict_prev(base_prev, true_prev, estim_prev):
-    prev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + list(
-        itertools.product(_prev_col_0, _prev_col_1)
-    )
-
-    return {
-        k: v
-        for (k, v) in zip(
-            prev_cols, np.concatenate((base_prev, true_prev, estim_prev), axis=0)
-        )
-    }
+def _report_avg_groupby_distribution(lst, error_names):
+    def _bprev(s):
+        return (s[("base", "0")], s[("base", "1")])
+
+    def _normalize_prev(r, prev_name):
+        raw_prev = [v for ((k0, k1), v) in r.items() if k0 == prev_name]
+        norm_prev = [v/sum(raw_prev) for v in raw_prev]
+        for n, v in zip(itertools.product([prev_name], _prev_col_1), norm_prev):
+            r[n] = v
+
+        return r
+
+    current_bprev = _bprev(lst[0])
+    bprev_cnt = 0
+    g_lst = [[]]
+    for s in lst:
+        if _bprev(s) == current_bprev:
+            g_lst[bprev_cnt].append(s)
+        else:
+            g_lst.append([])
+            bprev_cnt += 1
+            current_bprev = _bprev(s)
+            g_lst[bprev_cnt].append(s)
+
+    r_lst = []
+    for gs in g_lst:
+        assert len(gs) > 0
+        r = {}
+        r[("base", "0")], r[("base", "1")] = _bprev(gs[0])
+        for pn in itertools.product(_prev_col_0, _prev_col_1):
+            r[pn] = stats.mean(map(lambda s: s[pn], gs))
+        r = _normalize_prev(r, "true")
+        r = _normalize_prev(r, "estim")
+        for en in itertools.product(_err_col_0, error_names):
+            r[en] = stats.mean(map(lambda s: s[en], gs))
+        r_lst.append(r)
+
+    return r_lst


 def evaluation_report(
     estimator: AccuracyEstimator,
     protocol: AbstractStochasticSeededProtocol,
     error_metrics: Iterable[Union[str, Callable]] = "all",
+    aggregate: bool = True,
 ):
     base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)
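The new _report_avg_groupby_distribution averages consecutive report rows that share the same base prevalence, re-normalizes the averaged "true" and "estim" prevalence columns so each sums to 1, and averages the error columns. The sketch below shows the same idea in a self-contained form, using itertools.groupby instead of the explicit loop; the values of the module constants (_prev_col_0, _prev_col_1, _err_col_0) are assumptions, since only their names appear in the diff, and rows with equal base prevalence are assumed to arrive adjacent, as the committed code also requires.

# Self-contained sketch of the aggregation idea (constants below are assumed).
import itertools
import statistics as stats

_prev_col_0 = ("true", "estim")
_prev_col_1 = ("0", "1")
_err_col_0 = ("errors",)

rows = [
    {("base", "0"): 0.5, ("base", "1"): 0.5,
     ("true", "0"): 0.4, ("true", "1"): 0.6,
     ("estim", "0"): 0.35, ("estim", "1"): 0.65,
     ("errors", "mae"): 0.05},
    {("base", "0"): 0.5, ("base", "1"): 0.5,
     ("true", "0"): 0.6, ("true", "1"): 0.4,
     ("estim", "0"): 0.55, ("estim", "1"): 0.45,
     ("errors", "mae"): 0.07},
]

def aggregate_rows(lst, error_names=("mae",)):
    out = []
    # rows sharing a base prevalence are grouped as consecutive runs
    for bprev, group in itertools.groupby(
        lst, key=lambda s: (s[("base", "0")], s[("base", "1")])
    ):
        gs = list(group)
        r = {("base", "0"): bprev[0], ("base", "1"): bprev[1]}
        for pn in itertools.product(_prev_col_0, _prev_col_1):
            r[pn] = stats.mean(s[pn] for s in gs)
        for prev_name in _prev_col_0:  # re-normalize averaged prevalences
            total = sum(r[(prev_name, c)] for c in _prev_col_1)
            for c in _prev_col_1:
                r[(prev_name, c)] /= total
        for en in itertools.product(_err_col_0, error_names):
            r[en] = stats.mean(s[en] for s in gs)
        out.append(r)
    return out

print(aggregate_rows(rows))  # one averaged row for base prevalence (0.5, 0.5)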
@@ -89,7 +121,16 @@ def evaluation_report(
     lst = []

     for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
-        series = _dict_prev(base_prev, true_prev, estim_prev)
+        prev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + list(
+            itertools.product(_prev_col_0, _prev_col_1)
+        )
+        series = {
+            k: v
+            for (k, v) in zip(
+                prev_cols, np.concatenate((base_prev, true_prev, estim_prev), axis=0)
+            )
+        }
         for error_name, error_metric in zip(error_names, error_funcs):
             if error_name == "f1e":
                 series[("errors", "f1e_true")] = error_metric(true_prev)
@@ -101,5 +142,6 @@ def evaluation_report(
         lst.append(series)

+    lst = _report_avg_groupby_distribution(lst, error_cols) if aggregate else lst
     df = pd.DataFrame(lst, columns=df_cols)
     return df
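With the new aggregate flag defaulting to True, evaluation_report returns one averaged row per base prevalence rather than one row per protocol sample. A hypothetical call sketch (names taken from the diff; estimator and protocol are assumed to be built as in the test below):

df_avg = evaluation_report(estimator, protocol)                   # aggregate=True by default
df_raw = evaluation_report(estimator, protocol, aggregate=False)  # one row per generated sample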

View File

@@ -16,19 +16,41 @@ pd.set_option("display.float_format", "{:.4f}".format)
 def test_2(dataset_name):
     train, test = get_dataset(dataset_name)

     model = LogisticRegression()
+
+    print(f"fitting model {model.__class__.__name__}...", end=" ")
     model.fit(*train.Xy)
-    estimator = AccuracyEstimator(model, SLD(LogisticRegression()))
+    print("fit")
+
+    qmodel = SLD(LogisticRegression())
+    estimator = AccuracyEstimator(model, qmodel)
+
+    print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ")
     estimator.fit(train)
-    df = eval.evaluation_report(estimator, APP(test, n_prevalences=11, repeats=100))
-    # print(df.to_string())
+    print("fit")
+
+    n_prevalences = 21
+    repreats = 1000
+    protocol = APP(test, n_prevalences=n_prevalences, repeats=repreats)
+    print( f"Tests:\n\
+        protocol={protocol.__class__.__name__}\n\
+        n_prevalences={n_prevalences}\n\
+        repreats={repreats}\n\
+        executing...\n"
+    )
+    df = eval.evaluation_report(
+        estimator,
+        protocol,
+        aggregate=True,
+    )
     print(df.to_string())


 def main():
     for dataset_name in [
-        # "hp",
-        # "imdb",
+        "hp",
+        "imdb",
         "spambase",
     ]:
         print(dataset_name)
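The test now runs a noticeably larger protocol. Assuming quapy's APP enumerates n_prevalences grid points per repeat for a binary dataset, the back-of-the-envelope size of the run configured above is:

n_prevalences, repeats = 21, 1000
total_samples = n_prevalences * repeats  # 21,000 evaluated samples per dataset
aggregated_rows = n_prevalences          # one averaged row per base prevalence with aggregate=True
print(total_samples, aggregated_rows)    # 21000 21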