added aggregation to the evaluation report
commit d557c6a7d3
parent 5234ce1387
@@ -10,6 +10,14 @@ def f1e(prev):
     return 1 - f1_score(prev)

 def f1_score(prev):
-    recall = prev[0] / (prev[0] + prev[1])
-    precision = prev[0] / (prev[0] + prev[2])
-    return 2 * (precision * recall) / (precision + recall)
+    # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
+    if prev[0] == 0 and prev[1] == 0 and prev[2] == 0:
+        return 1.0
+    elif prev[0] == 0 and prev[1] > 0 and prev[2] == 0:
+        return 0.0
+    elif prev[0] == 0 and prev[1] == 0 and prev[2] > 0:
+        return float('NaN')
+    else:
+        recall = prev[0] / (prev[0] + prev[1])
+        precision = prev[0] / (prev[0] + prev[2])
+        return 2 * (precision * recall) / (precision + recall)
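For reference, the prev triple reads as (TP, FN, FP): that is what the recall = TP / (TP + FN) and precision = TP / (TP + FP) formulas above imply. A quick sanity check of the new corner cases, assuming that reading and with the new f1_score in scope:

import math

assert f1_score((0, 0, 0)) == 1.0        # empty gold standard, empty prediction
assert f1_score((0, 3, 0)) == 0.0        # only false negatives
assert math.isnan(f1_score((0, 0, 3)))   # only false positives
assert f1_score((1, 1, 1)) == 0.5        # P = R = 0.5, hence F1 = 0.5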
@@ -9,6 +9,7 @@ from .estimator import AccuracyEstimator
 import pandas as pd
 import numpy as np
 import quacc.error as error
+import statistics as stats


 def estimate(
@@ -50,24 +51,55 @@ def _report_columns(err_names):
     return pd.MultiIndex.from_tuples(cols)


-def _dict_prev(base_prev, true_prev, estim_prev):
-    prev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + list(
-        itertools.product(_prev_col_0, _prev_col_1)
-    )
-    return {
-        k: v
-        for (k, v) in zip(
-            prev_cols, np.concatenate((base_prev, true_prev, estim_prev), axis=0)
-        )
-    }
+def _report_avg_groupby_distribution(lst, error_names):
+    def _bprev(s):
+        return (s[("base", "0")], s[("base", "1")])
+
+    def _normalize_prev(r, prev_name):
+        raw_prev = [v for ((k0, k1), v) in r.items() if k0 == prev_name]
+        norm_prev = [v / sum(raw_prev) for v in raw_prev]
+        for n, v in zip(itertools.product([prev_name], _prev_col_1), norm_prev):
+            r[n] = v
+        return r
+
+    current_bprev = _bprev(lst[0])
+    bprev_cnt = 0
+    g_lst = [[]]
+    for s in lst:
+        if _bprev(s) == current_bprev:
+            g_lst[bprev_cnt].append(s)
+        else:
+            g_lst.append([])
+            bprev_cnt += 1
+            current_bprev = _bprev(s)
+            g_lst[bprev_cnt].append(s)
+
+    r_lst = []
+    for gs in g_lst:
+        assert len(gs) > 0
+        r = {}
+        r[("base", "0")], r[("base", "1")] = _bprev(gs[0])
+
+        for pn in itertools.product(_prev_col_0, _prev_col_1):
+            r[pn] = stats.mean(map(lambda s: s[pn], gs))
+
+        r = _normalize_prev(r, "true")
+        r = _normalize_prev(r, "estim")
+
+        for en in itertools.product(_err_col_0, error_names):
+            r[en] = stats.mean(map(lambda s: s[en], gs))
+
+        r_lst.append(r)
+
+    return r_lst
+
+
 def evaluation_report(
     estimator: AccuracyEstimator,
     protocol: AbstractStochasticSeededProtocol,
     error_metrics: Iterable[Union[str, Callable]] = "all",
+    aggregate: bool = True,
 ):
     base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)
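The new _report_avg_groupby_distribution walks the report rows in order, groups consecutive rows that share the same base prevalence (the ("base", "0")/("base", "1") pair), and replaces each run with a single row of column-wise means; the averaged true/estim prevalences are then renormalized by _normalize_prev so each pair sums to 1 again. A minimal sketch of the same idea using a pandas groupby, with flat stand-in column names rather than the report's real MultiIndex:

import pandas as pd

rows = pd.DataFrame({
    "base_0": [0.2, 0.2, 0.8],   # first two rows share a base prevalence
    "base_1": [0.8, 0.8, 0.2],
    "acc_err": [0.10, 0.20, 0.40],
})

# Column-wise mean per base prevalence. One difference to keep in mind:
# groupby merges all rows with equal keys, while the hand-rolled loop above
# only merges consecutive runs (the two coincide when the protocol emits
# prevalences in order).
agg = rows.groupby(["base_0", "base_1"], sort=False).mean().reset_index()
print(agg)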
@@ -89,7 +121,16 @@ def evaluation_report(
     lst = []
     for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
-        series = _dict_prev(base_prev, true_prev, estim_prev)
+        prev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + list(
+            itertools.product(_prev_col_0, _prev_col_1)
+        )
+
+        series = {
+            k: v
+            for (k, v) in zip(
+                prev_cols, np.concatenate((base_prev, true_prev, estim_prev), axis=0)
+            )
+        }
         for error_name, error_metric in zip(error_names, error_funcs):
             if error_name == "f1e":
                 series[("errors", "f1e_true")] = error_metric(true_prev)
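The inlined dict comprehension builds its keys with itertools.product, pairing a first-level group name with a class label so the values line up with the report's two-level columns. A standalone illustration (the *_col_0 / *_col_1 lists here are stand-ins for the module-level constants, whose exact values the diff does not show):

import itertools

_bprev_col_0, _bprev_col_1 = ["base"], ["0", "1"]
_prev_col_0, _prev_col_1 = ["true", "estim"], ["0", "1"]

prev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + list(
    itertools.product(_prev_col_0, _prev_col_1)
)
print(prev_cols)
# [('base', '0'), ('base', '1'), ('true', '0'), ('true', '1'),
#  ('estim', '0'), ('estim', '1')]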
@@ -101,5 +142,6 @@ def evaluation_report(
         lst.append(series)

+    lst = _report_avg_groupby_distribution(lst, error_cols) if aggregate else lst
     df = pd.DataFrame(lst, columns=df_cols)
     return df
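With the default aggregate=True the returned DataFrame has one row per base prevalence; passing aggregate=False keeps the raw per-sample rows. A hypothetical call site (imports omitted):

df_avg = evaluation_report(estimator, protocol)                   # one row per base prevalence
df_raw = evaluation_report(estimator, protocol, aggregate=False)  # one row per generated sample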
@@ -16,19 +16,41 @@ pd.set_option("display.float_format", "{:.4f}".format)
 def test_2(dataset_name):
     train, test = get_dataset(dataset_name)

     model = LogisticRegression()

     print(f"fitting model {model.__class__.__name__}...", end=" ")
     model.fit(*train.Xy)
-    estimator = AccuracyEstimator(model, SLD(LogisticRegression()))
     print("fit")
+
+    qmodel = SLD(LogisticRegression())
+    estimator = AccuracyEstimator(model, qmodel)
+
+    print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ")
     estimator.fit(train)
-    df = eval.evaluation_report(estimator, APP(test, n_prevalences=11, repeats=100))
-    # print(df.to_string())
+    print("fit")
+
+    n_prevalences = 21
+    repeats = 1000
+    protocol = APP(test, n_prevalences=n_prevalences, repeats=repeats)
+    print(
+        f"Tests:\n"
+        f"  protocol={protocol.__class__.__name__}\n"
+        f"  n_prevalences={n_prevalences}\n"
+        f"  repeats={repeats}\n"
+        f"executing...\n"
+    )
+    df = eval.evaluation_report(
+        estimator,
+        protocol,
+        aggregate=True,
+    )
     print(df.to_string())


 def main():
     for dataset_name in [
-        # "hp",
-        # "imdb",
+        "hp",
+        "imdb",
         "spambase",
     ]:
         print(dataset_name)
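On the protocol sizing in test_2, assuming quapy's APP semantics on a binary task: n_prevalences points span [0, 1] in equal steps, and each point is sampled repeats times, so the run above generates 21 * 1000 samples on a 0.05-spaced prevalence grid:

n_prevalences, repeats = 21, 1000
grid_step = 1 / (n_prevalences - 1)   # 0.05 between consecutive prevalence values
n_samples = n_prevalences * repeats   # 21000 generated samples
print(grid_step, n_samples)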