249 lines
6.6 KiB
Python
249 lines
6.6 KiB
Python
import os
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from typing import Dict, List
|
|
|
|
import panel as pn
|
|
|
|
from quacc.evaluation.comp import CE
|
|
from quacc.evaluation.report import DatasetReport
|
|
|
|
# Panel sizing mode applied to every Matplotlib pane built in this module.
_plot_sizing_mode = "stretch_both"

# Plot modes selectable for each view. The "avg" view has its own list; any
# other key (the per-prevalence views) falls back to a freshly built default
# list produced by the factory.
valid_plot_modes = defaultdict(
    lambda: ["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"],
    avg=[
        "delta_train",
        "stdev_train",
        "delta_test",
        "stdev_test",
        "shift",
        "train_table",
        "test_table",
        "shift_table",
    ],
)
|
|
|
|
|
|
def create_cr_plots(
    dr: DatasetReport,
    mode="delta",
    metric="acc",
    estimators=None,
    prev=None,
):
    """Build the Panel pane for one report of ``dr``, selected by ``prev``.

    Args:
        dr: dataset report whose ``crs`` entries are searched.
        mode: plot mode; ``"table"`` and ``"shift_table"`` produce a
            DataFrame pane, any other value is forwarded to ``cr.get_plots``.
            ``None`` falls back to the first valid mode for this view.
        metric: metric name forwarded to the report accessors.
        estimators: estimator selection, translated through ``CE.name``
            before being forwarded.
        prev: integer percentage identifying the report, i.e.
            ``round(cr.train_prev[1] * 100)`` of the wanted entry
            (presumably the training prevalence of the positive class —
            TODO confirm).

    Returns:
        A ``pn.pane.DataFrame`` for the table modes, otherwise a
        ``pn.pane.Matplotlib`` wrapping the figure returned by the report.

    Raises:
        ValueError: if no entry of ``dr.crs`` matches ``prev``.
    """
    available = [round(cr.train_prev[1] * 100) for cr in dr.crs]
    try:
        idx = available.index(prev)
    except ValueError:
        # Same exception type as the bare list.index failure, but the message
        # now says which views exist instead of "None is not in list".
        raise ValueError(
            f"no report with prevalence {prev!r}; available: {available}"
        ) from None
    cr = dr.crs[idx]

    estimators = CE.name[estimators]
    if mode is None:
        mode = valid_plot_modes[str(prev)][0]

    if mode == "table":
        return pn.pane.DataFrame(
            cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
            align="center",
        )

    if mode == "shift_table":
        return pn.pane.DataFrame(
            cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
            align="center",
        )

    return pn.pane.Matplotlib(
        cr.get_plots(
            mode=mode,
            metric=metric,
            estimators=estimators,
            conf="panel",
            return_fig=True,
        ),
        tight=True,
        format="png",
        sizing_mode=_plot_sizing_mode,
    )
|
|
|
|
|
|
def create_avg_plots(
    dr: DatasetReport,
    mode="delta",
    metric="acc",
    estimators=None,
):
    """Build the Panel pane for the averaged ("avg") view of ``dr``.

    Args:
        dr: dataset report providing ``data``, ``shift_data`` and
            ``get_plots``.
        mode: one of the table modes (``"train_table"``, ``"test_table"``,
            ``"shift_table"``) or a plot mode forwarded to ``dr.get_plots``.
            ``None`` falls back to the first entry of
            ``valid_plot_modes["avg"]``.
        metric: metric name forwarded to the report accessors.
        estimators: estimator selection, translated through ``CE.name``
            before being forwarded.

    Returns:
        A ``pn.pane.DataFrame`` for the table modes, otherwise a
        ``pn.pane.Matplotlib`` wrapping the figure returned by the report.
    """
    estimators = CE.name[estimators]
    if mode is None:
        mode = valid_plot_modes["avg"][0]

    # Table mode -> (report accessor name, index level the rows are averaged on).
    table_specs = {
        "train_table": ("data", 1),
        "test_table": ("data", 0),
        "shift_table": ("shift_data", 0),
    }
    if mode in table_specs:
        accessor, level = table_specs[mode]
        frame = getattr(dr, accessor)(metric=metric, estimators=estimators)
        return pn.pane.DataFrame(
            frame.groupby(level=level).mean(),
            align="center",
        )

    figure = dr.get_plots(
        mode=mode,
        metric=metric,
        estimators=estimators,
        conf="panel",
        return_fig=True,
    )
    return pn.pane.Matplotlib(
        figure,
        tight=True,
        format="png",
        sizing_mode=_plot_sizing_mode,
    )
|
|
|
|
|
|
def _widget_options(dr: DatasetReport):
    """Return the ``(metrics, estimators, views)`` option lists for ``dr``."""
    data = dr.data()
    # Column level 0 holds the metric names, level 1 the estimator names.
    metrics = [m for m in data.columns.unique(0) if not m.endswith("_score")]
    estimators = [e for e in data.columns.unique(1) if e != "ref"]
    # One view per report in dr.crs (as a percentage string), plus "avg".
    views = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs] + ["avg"]
    return metrics, estimators, views


def build_widgets(datasets: Dict[str, DatasetReport]):
    """Create the control widgets for browsing ``datasets``.

    The metric, estimator and view widgets are initialised from the dataset
    that the dataset selector shows first, and are refreshed whenever the
    selected dataset changes. The plot-mode widget follows the selected view.

    Args:
        datasets: mapping from dataset name to its report.

    Returns:
        A ``(pane, widgets)`` tuple: the ``pn.Column`` laying out all
        widgets, and a dict exposing each widget under the keys
        ``"dataset"``, ``"metric"``, ``"view"``, ``"plot_mode"`` and
        ``"estimators"``.
    """
    dataset_widget = pn.widgets.Select(
        name="dataset",
        options=list(datasets.keys()),
        align="center",
    )

    metrics, estimators, views = _widget_options(datasets[dataset_widget.value])

    # NOTE(review): the initial metric is hard-coded to "acc" while the
    # refresh callback below picks the first available metric — confirm that
    # "acc" is always among the options.
    metric_widget = pn.widgets.Select(
        name="metric",
        value="acc",
        options=metrics,
        align="center",
    )

    estimators_widget = pn.widgets.CheckButtonGroup(
        name="estimators",
        options=estimators,
        value=estimators,
        button_style="outline",
        button_type="primary",
        align="center",
        orientation="vertical",
        sizing_mode="scale_width",
    )

    view_widget = pn.widgets.RadioButtonGroup(
        name="view",
        options=views,
        value="avg",
        button_style="outline",
        button_type="primary",
        align="center",
        orientation="vertical",
    )

    @pn.depends(dataset_widget.param.value, watch=True)
    def _update_from_dataset(_dataset):
        # Re-derive every dependent option list from the newly selected report.
        new_metrics, new_estimators, new_views = _widget_options(
            datasets[dataset_widget.value]
        )

        metric_widget.options = new_metrics
        metric_widget.value = new_metrics[0]

        estimators_widget.options = new_estimators
        estimators_widget.value = new_estimators

        view_widget.options = new_views
        view_widget.value = "avg"

    plot_mode_widget = pn.widgets.RadioButtonGroup(
        name="mode",
        value=valid_plot_modes["avg"][0],
        options=valid_plot_modes["avg"],
        button_style="outline",
        button_type="primary",
        align="center",
        orientation="vertical",
        sizing_mode="scale_width",
    )

    @pn.depends(view_widget.param.value, watch=True)
    def _update_from_view(_view):
        # Each view has its own set of plot modes; reset to the first one.
        _modes = valid_plot_modes[_view]
        plot_mode_widget.options = _modes
        plot_mode_widget.value = _modes[0]

    widget_pane = pn.Column(
        dataset_widget,
        metric_widget,
        pn.Row(
            view_widget,
            plot_mode_widget,
        ),
        estimators_widget,
    )

    return (
        widget_pane,
        {
            "dataset": dataset_widget,
            "metric": metric_widget,
            "view": view_widget,
            "plot_mode": plot_mode_widget,
            "estimators": estimators_widget,
        },
    )
|
|
|
|
|
|
def build_plot(
    datasets: Dict[str, DatasetReport],
    dst: str,
    metric: str,
    estimators: List[str],
    view: str,
    mode: str,
):
    """Build the pane for the given dataset, view and plot mode.

    Args:
        datasets: mapping from dataset name to its report.
        dst: key of the dataset to plot.
        metric: metric name forwarded to the plot builders.
        estimators: estimator selection forwarded to the plot builders.
        view: ``"avg"`` for the averaged view, otherwise a string that is
            converted with ``int`` and passed on as ``prev``.
        mode: plot mode forwarded to the plot builders.

    Returns:
        Whatever ``create_avg_plots`` or ``create_cr_plots`` returns.
    """
    report = datasets[dst]
    shared = {"mode": mode, "metric": metric, "estimators": estimators}

    if view != "avg":
        return create_cr_plots(report, prev=int(view), **shared)
    return create_avg_plots(report, **shared)
|
|
|
|
|
|
def explore_datasets(root: Path | str):
|
|
if isinstance(root, str):
|
|
root = Path(root)
|
|
|
|
if root.name == "plot":
|
|
return []
|
|
|
|
if not root.exists():
|
|
return []
|
|
|
|
drs = []
|
|
for f in os.listdir(root):
|
|
if (root / f).is_dir():
|
|
drs += explore_datasets(root / f)
|
|
elif f == f"{root.name}.pickle":
|
|
drs.append(root / f)
|
|
# drs.append((str(root),))
|
|
|
|
return drs
|