"""Evaluation reports for quacc: collect per-prevalence results and render them
as Markdown tables and plots."""

from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd

from quacc import plot
from quacc.environ import env
from quacc.utils import fmt_line_md


class EvaluationReport:
    """Collects evaluation results, one row per base prevalence, keyed by
    (column, report name)."""

    def __init__(self, name=None):
        self._prevs = []
        self._dict = {}
        self._g_prevs = None
        self._g_dict = None
        self.name = name if name is not None else "default"
        self.times = {}
        self.train_prev = None
        self.valid_prev = None
        self.target = "default"

    def append_row(self, base: np.ndarray | Tuple, **row):
        if isinstance(base, np.ndarray):
            base = tuple(base.tolist())
        self._prevs.append(base)
        for k, v in row.items():
            if (k, self.name) in self._dict:
                self._dict[(k, self.name)].append(v)
            else:
                self._dict[(k, self.name)] = [v]
        # invalidate the cached grouping so group_by_prevs rebuilds it
        self._g_prevs = None
        self._g_dict = None

    @property
    def columns(self):
        return self._dict.keys()

    def group_by_prevs(self, metric: str = None):
        """Group the collected values by base prevalence, sorting groups by the
        second prevalence component; optionally filter columns by metric."""
        if self._g_dict is None:
            self._g_prevs = []
            self._g_dict = {k: [] for k in self._dict.keys()}

            for col, vals in self._dict.items():
                col_grouped = {}
                for bp, v in zip(self._prevs, vals):
                    if bp not in col_grouped:
                        col_grouped[bp] = []
                    col_grouped[bp].append(v)

                self._g_dict[col] = [
                    vs
                    for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1])
                ]

            self._g_prevs = sorted(
                [(p0, p1) for [p0, p1] in np.unique(self._prevs, axis=0).tolist()],
                key=lambda bp: bp[1],
            )

            # last_end = 0
            # for ind, bp in enumerate(self._prevs):
            #     if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]:
            #         continue

            #     self._g_prevs.append(bp)
            #     for col in self._dict.keys():
            #         self._g_dict[col].append(
            #             stats.mean(self._dict[col][last_end : ind + 1])
            #         )

            #     last_end = ind + 1

        filtered_g_dict = self._g_dict
        if metric is not None:
            filtered_g_dict = {
                c1: ls for ((c0, c1), ls) in self._g_dict.items() if c0 == metric
            }

        return self._g_prevs, filtered_g_dict

    def avg_by_prevs(self, metric: str = None):
        g_prevs, g_dict = self.group_by_prevs(metric=metric)

        a_dict = {}
        for col, vals in g_dict.items():
            a_dict[col] = [np.mean(vs) for vs in vals]

        return g_prevs, a_dict

    def avg_all(self, metric: str = None):
        f_dict = self._dict
        if metric is not None:
            f_dict = {c1: ls for ((c0, c1), ls) in self._dict.items() if c0 == metric}

        a_dict = {}
        for col, vals in f_dict.items():
            a_dict[col] = [np.mean(vals)]

        return a_dict

    def get_dataframe(self, metric="acc"):
        g_prevs, g_dict = self.avg_by_prevs(metric=metric)
        a_dict = self.avg_all(metric=metric)
        # append the overall average as a final "tot" row
        for col in g_dict.keys():
            g_dict[col].extend(a_dict[col])
        return pd.DataFrame(
            g_dict,
            index=g_prevs + ["tot"],
            columns=g_dict.keys(),
        )

    def get_plot(self, mode="delta", metric="acc") -> Path:
        if mode == "delta":
            g_prevs, g_dict = self.group_by_prevs(metric=metric)
            return plot.plot_delta(
                g_prevs,
                g_dict,
                metric=metric,
                name=self.name,
                train_prev=self.train_prev,
            )
        elif mode == "diagonal":
            _, g_dict = self.avg_by_prevs(metric=metric + "_score")
            f_dict = {k: v for k, v in g_dict.items() if k != "ref"}
            reference = g_dict["ref"]
            return plot.plot_diagonal(
                reference,
                f_dict,
                metric=metric,
                name=self.name,
                train_prev=self.train_prev,
            )
        elif mode == "shift":
            g_prevs, g_dict = self.avg_by_prevs(metric=metric)
            return plot.plot_shift(
                g_prevs,
                g_dict,
                metric=metric,
                name=self.name,
                train_prev=self.train_prev,
            )

    def to_md(self, *metrics):
        res = ""
        res += fmt_line_md(f"train: {str(self.train_prev)}")
        res += fmt_line_md(f"validation: {str(self.valid_prev)}")
        for k, v in self.times.items():
            res += fmt_line_md(f"{k}: {v:.3f}s")
        res += "\n"
        for m in metrics:
            res += self.get_dataframe(metric=m).to_html() + "\n\n"
            op_delta = self.get_plot(mode="delta", metric=m)
            res += f"![plot_delta]({str(op_delta.relative_to(env.OUT_DIR))})\n"
            op_diag = self.get_plot(mode="diagonal", metric=m)
            res += f"![plot_diagonal]({str(op_diag.relative_to(env.OUT_DIR))})\n"
            op_shift = self.get_plot(mode="shift", metric=m)
            res += f"![plot_shift]({str(op_shift.relative_to(env.OUT_DIR))})\n"

        return res

    def merge(self, other):
        if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)):
            raise ValueError("other does not have the same base prevalences as self")

        inters_keys = set(self._dict.keys()).intersection(set(other._dict.keys()))
        if len(inters_keys) > 0:
            raise ValueError(f"self and other have overlapping keys {str(inters_keys)}.")

        report = EvaluationReport()
        report._prevs = self._prevs
        report._dict = self._dict | other._dict
        return report

    @staticmethod
    def combine_reports(*args, name="default", train_prev=None, valid_prev=None):
        er = args[0]
        for r in args[1:]:
            er = er.merge(r)

        er.name = name
        er.train_prev = train_prev
        er.valid_prev = valid_prev
        return er
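

# A minimal usage sketch (the report name and the numbers below are hypothetical;
# only append_row and get_dataframe come from this module):
#
#     er = EvaluationReport(name="sld")
#     er.append_row((0.8, 0.2), acc=0.91)
#     er.append_row((0.5, 0.5), acc=0.87)
#     df = er.get_dataframe(metric="acc")  # index: (0.8, 0.2), (0.5, 0.5), "tot"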


class DatasetReport:
    """Groups the EvaluationReports produced for a single dataset."""

    def __init__(self, name):
        self.name = name
        self.ers: List[EvaluationReport] = []

    def add(self, er: EvaluationReport):
        self.ers.append(er)

    def to_md(self, *metrics):
        res = f"{self.name}\n\n"
        for er in self.ers:
            res += f"{er.to_md(*metrics)}\n\n"

        return res

    def __iter__(self):
        return (er for er in self.ers)
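

# An end-to-end sketch (er_sld, er_cc and "imdb" are hypothetical names; to_md
# links the plots returned by get_plot relative to env.OUT_DIR):
#
#     er = EvaluationReport.combine_reports(
#         er_sld, er_cc, name="imdb", train_prev=(0.5, 0.5), valid_prev=(0.5, 0.5)
#     )
#     dr = DatasetReport("imdb")
#     dr.add(er)
#     md = dr.to_md("acc")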