QuAcc/qcpanel/util.py

152 lines
4.2 KiB
Python

import os
from collections import defaultdict
from pathlib import Path
import panel as pn
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import wilcoxon
_plot_sizing_mode = "stretch_both"
valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes
def create_plot(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
plot_view=None,
):
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
estimators = CE.name[estimators]
if mode is None:
mode = valid_plot_modes[plot_view][0]
match (plot_view, mode):
case ("avg", _ as plot_mode):
_plot = dr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="panel",
save_fig=False,
)
case (_, _ as plot_mode):
cr = dr.crs[_prevs.index(int(plot_view))]
_plot = cr.get_plots(
mode=plot_mode,
metric=metric,
estimators=estimators,
conf="panel",
save_fig=False,
)
if _plot is None:
return None
return pn.pane.Matplotlib(
_plot,
tight=True,
format="png",
# sizing_mode="scale_height",
sizing_mode=_plot_sizing_mode,
styles=dict(margin="0"),
# sizing_mode="scale_both",
)
def create_table(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
plot_view=None,
):
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
estimators = CE.name[estimators]
if mode is None:
mode = valid_plot_modes[plot_view][0]
match (plot_view, mode):
case ("avg", "train_table"):
_data = (
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
)
case ("avg", "test_table"):
_data = (
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
)
case ("avg", "shift_table"):
_data = (
dr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
case ("avg", "stats_table"):
_data = wilcoxon(dr, metric=metric, estimators=estimators)
case (_, "train_table"):
cr = dr.crs[_prevs.index(int(plot_view))]
_data = (
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
)
case (_, "shift_table"):
cr = dr.crs[_prevs.index(int(plot_view))]
_data = (
cr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
case (_, "stats_table"):
cr = dr.crs[_prevs.index(int(plot_view))]
_data = wilcoxon(cr, metric=metric, estimators=estimators)
return (
pn.Column(
pn.pane.DataFrame(
_data,
align="center",
float_format=lambda v: f"{v:6e}",
styles={"font-size-adjust": "0.62"},
),
sizing_mode="stretch_both",
# scroll=True,
)
if not _data.empty
else None
)
def create_result(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
plot_view=None,
):
match mode:
case m if m.endswith("table"):
return create_table(dr, mode, metric, estimators, plot_view)
case _:
return create_plot(dr, mode, metric, estimators, plot_view)
def explore_datasets(root: Path | str):
if isinstance(root, str):
root = Path(root)
if root.name == "plot":
return []
if not root.exists():
return []
drs = []
for f in os.listdir(root):
if (root / f).is_dir():
drs += explore_datasets(root / f)
elif f == f"{root.name}.pickle":
drs.append(root / f)
# drs.append((str(root),))
return drs