qcpanel updated and refactored

This commit is contained in:
Lorenzo Volpi 2023-11-16 17:10:19 +01:00
parent f8ae408a4f
commit a9b926717e
3 changed files with 583 additions and 558 deletions

View File

@ -1,563 +1,13 @@
import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List
import panel as pn
import param
from quacc import utils
from quacc.evaluation.comp import CE
from quacc.evaluation.report import DatasetReport
from qcpanel.viewer import QuaccTestViewer
pn.config.design = pn.theme.Bootstrap
pn.config.theme = "dark"
# pn.config.design = pn.theme.Bootstrap
# pn.config.theme = "dark"
pn.config.notifications = True
valid_plot_modes = defaultdict(
lambda: ["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"]
)
valid_plot_modes["avg"] = [
"delta_train",
"stdev_train",
"delta_test",
"stdev_test",
"shift",
"train_table",
"test_table",
"shift_table",
]
def create_cr_plots(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
prev=None,
):
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
idx = _prevs.index(prev)
cr = dr.crs[idx]
estimators = CE.name[estimators]
if mode is None:
mode = valid_plot_modes[str(prev)][0]
_dpi = 112
if mode == "table":
return pn.pane.DataFrame(
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
align="center",
)
elif mode == "shift_table":
return pn.pane.DataFrame(
cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
align="center",
)
else:
return pn.pane.Matplotlib(
cr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
),
tight=True,
format="png",
sizing_mode="scale_height",
# sizing_mode="scale_both",
)
def create_avg_plots(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
):
estimators = CE.name[estimators]
if mode is None:
mode = valid_plot_modes["avg"][0]
if mode == "train_table":
return pn.pane.DataFrame(
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean(),
align="center",
)
elif mode == "test_table":
return pn.pane.DataFrame(
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
align="center",
)
elif mode == "shift_table":
return pn.pane.DataFrame(
dr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
align="center",
)
return pn.pane.Matplotlib(
dr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
),
tight=True,
format="png",
sizing_mode="scale_height",
# sizing_mode="scale_both",
)
def build_widgets(datasets: Dict[str, DatasetReport]):
available_datasets = list(datasets.keys())
dataset_widget = pn.widgets.Select(
name="dataset",
options=available_datasets,
align="center",
)
_dr = datasets[dataset_widget.value]
_data = _dr.data()
_metrics = _data.columns.unique(0)
_estimators = _data.columns.unique(1)
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
metric_widget = pn.widgets.Select(
name="metric",
value="acc",
options=valid_metrics,
align="center",
)
valid_estimators = [e for e in _estimators if e != "ref"]
estimators_widget = pn.widgets.CheckButtonGroup(
name="estimators",
options=valid_estimators,
value=valid_estimators,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
valid_views = [str(round(cr.train_prev[1] * 100)) for cr in _dr.crs]
view_widget = pn.widgets.RadioButtonGroup(
name="view",
options=valid_views + ["avg"],
value="avg",
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
)
@pn.depends(dataset_widget.param.value, watch=True)
def _update_from_dataset(_dataset):
l_dr = datasets[dataset_widget.value]
l_data = l_dr.data()
l_metrics = l_data.columns.unique(0)
l_estimators = l_data.columns.unique(1)
l_valid_estimators = [e for e in l_estimators if e != "ref"]
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
metric_widget.options = l_valid_metrics
metric_widget.value = l_valid_metrics[0]
estimators_widget.options = l_valid_estimators
estimators_widget.value = l_valid_estimators
view_widget.options = l_valid_views + ["avg"]
view_widget.value = "avg"
plot_mode_widget = pn.widgets.RadioButtonGroup(
name="mode",
value=valid_plot_modes["avg"][0],
options=valid_plot_modes["avg"],
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
@pn.depends(view_widget.param.value, watch=True)
def _update_from_view(_view):
_modes = valid_plot_modes[_view]
plot_mode_widget.options = _modes
plot_mode_widget.value = _modes[0]
widget_pane = pn.Column(
dataset_widget,
metric_widget,
pn.Row(
view_widget,
plot_mode_widget,
),
estimators_widget,
)
return (
widget_pane,
{
"dataset": dataset_widget,
"metric": metric_widget,
"view": view_widget,
"plot_mode": plot_mode_widget,
"estimators": estimators_widget,
},
)
def build_plot(
datasets: Dict[str, DatasetReport],
dst: str,
metric: str,
estimators: List[str],
view: str,
mode: str,
):
_dr = datasets[dst]
if view == "avg":
return create_avg_plots(_dr, mode=mode, metric=metric, estimators=estimators)
else:
prev = int(view)
return create_cr_plots(
_dr, mode=mode, metric=metric, estimators=estimators, prev=prev
)
def build_modal(datasets, dst, metric):
return pn.pane.Str(f"{dst}_{metric}")
def build_save_pane(datasets: Dict[str, DatasetReport], dst: str, metric: str):
return pn.pane.Str(f"{datasets[dst]}_{metric}")
def explore_datasets(root: Path | str):
if isinstance(root, str):
root = Path(root)
if root.name == "plot":
return []
if not root.exists():
return []
drs = []
for f in os.listdir(root):
if (root / f).is_dir():
drs += explore_datasets(root / f)
elif f == f"{root.name}.pickle":
drs.append(root / f)
# drs.append((str(root),))
return drs
class QuaccTestViewer(param.Parameterized):
dataset = param.Selector()
metric = param.Selector()
estimators = param.ListSelector()
plot_view = param.Selector()
mode = param.Selector()
modal_estimators = param.ListSelector()
modal_plot_view = param.ListSelector()
modal_mode_prev = param.ListSelector(
objects=valid_plot_modes[0], default=valid_plot_modes[0]
)
modal_mode_avg = param.ListSelector(
objects=valid_plot_modes["avg"], default=valid_plot_modes["avg"]
)
param_pane = param.Parameter()
plot_pane = param.Parameter()
modal_pane = param.Parameter()
def __init__(self, **params):
super().__init__(**params)
self.__setup_watchers()
self.__import_datasets()
# self._update_on_dataset()
self.__create_param_pane()
self.__create_modal_pane()
def __save_callback(self, event):
_home = utils.get_quacc_home()
_save_input_val = self.save_input.value_input
_config = "default" if len(_save_input_val) == 0 else _save_input_val
base_path = _home / "output" / self.dataset / _config
os.makedirs(base_path, exist_ok=True)
base_plot = base_path / "plot"
os.makedirs(base_plot, exist_ok=True)
l_dr = self.datasets_[self.dataset]
res = l_dr.to_md(
conf=_config,
metric=self.metric,
estimators=CE.name[self.modal_estimators],
dr_modes=self.modal_mode_avg,
cr_modes=self.modal_mode_prev,
plot_path=base_plot,
)
with open(base_path / f"{self.metric}.md", "w") as f:
f.write(res)
pn.state.notifications.success(f'"{_config}" successfully saved')
def __create_param_pane(self):
self.dataset_widget = pn.Param(
self,
show_name=False,
parameters=["dataset"],
widgets={"dataset": {"widget_type": pn.widgets.Select}},
)
self.metric_widget = pn.Param(
self,
show_name=False,
parameters=["metric"],
widgets={"metric": {"widget_type": pn.widgets.Select}},
)
self.estimators_widgets = pn.Param(
self,
show_name=False,
parameters=["estimators"],
widgets={
"estimators": {
"widget_type": pn.widgets.CheckButtonGroup,
"orientation": "vertical",
"sizing_mode": "scale_width",
"button_type": "primary",
"button_style": "outline",
}
},
)
self.plot_view_widget = pn.Param(
self,
show_name=False,
parameters=["plot_view"],
widgets={
"plot_view": {
"widget_type": pn.widgets.RadioButtonGroup,
"orientation": "vertical",
"button_type": "primary",
"button_style": "outline",
}
},
)
self.mode_widget = pn.Param(
self,
show_name=False,
parameters=["mode"],
widgets={
"mode": {
"widget_type": pn.widgets.RadioButtonGroup,
"orientation": "vertical",
"sizing_mode": "scale_width",
"button_type": "primary",
"button_style": "outline",
}
},
align="center",
)
self.param_pane = pn.Column(
self.dataset_widget,
self.metric_widget,
pn.Row(
self.plot_view_widget,
self.mode_widget,
),
self.estimators_widgets,
)
def __create_modal_pane(self):
self.modal_estimators_widgets = pn.Param(
self,
show_name=False,
parameters=["modal_estimators"],
widgets={
"modal_estimators": {
"widget_type": pn.widgets.CheckButtonGroup,
"orientation": "vertical",
"sizing_mode": "scale_width",
"button_type": "primary",
"button_style": "outline",
}
},
)
self.modal_plot_view_widget = pn.Param(
self,
show_name=False,
parameters=["modal_plot_view"],
widgets={
"modal_plot_view": {
"widget_type": pn.widgets.CheckButtonGroup,
"orientation": "vertical",
"button_type": "primary",
"button_style": "outline",
}
},
)
self.modal_mode_prev_widget = pn.Param(
self,
show_name=False,
parameters=["modal_mode_prev"],
widgets={
"modal_mode_prev": {
"widget_type": pn.widgets.CheckButtonGroup,
"orientation": "vertical",
"sizing_mode": "scale_width",
"button_type": "primary",
"button_style": "outline",
}
},
align="center",
)
self.modal_mode_avg_widget = pn.Param(
self,
show_name=False,
parameters=["modal_mode_avg"],
widgets={
"modal_mode_avg": {
"widget_type": pn.widgets.CheckButtonGroup,
"orientation": "vertical",
"sizing_mode": "scale_width",
"button_type": "primary",
"button_style": "outline",
}
},
align="center",
)
self.save_input = pn.widgets.TextInput(
name="Configuration Name", placeholder="default", sizing_mode="scale_width"
)
self.save_button = pn.widgets.Button(
name="Saverrr",
sizing_mode="scale_width",
button_style="solid",
button_type="success",
)
self.save_button.on_click(self.__save_callback)
_title_styles = {
"font-size": "14pt",
"font-weight": "bold",
}
self.modal_pane = pn.Column(
pn.Column(
pn.pane.Str("Avg. configuration", styles=_title_styles),
self.modal_mode_avg_widget,
pn.pane.Str("Train prevs. configuration", styles=_title_styles),
pn.Row(
self.modal_plot_view_widget,
self.modal_mode_prev_widget,
),
pn.pane.Str("Estimators configuration", styles=_title_styles),
self.modal_estimators_widgets,
self.save_input,
self.save_button,
width=450,
align="center",
scroll=True,
),
sizing_mode="stretch_both",
)
def __import_datasets(self):
__base_path = "output"
dataset_paths = sorted(
explore_datasets(__base_path), key=lambda t: (-len(t.parts), t)
)
self.datasets_ = {
str(dp.parent.relative_to(Path(__base_path))): DatasetReport.unpickle(dp)
for dp in dataset_paths
}
self.available_datasets = list(self.datasets_.keys())
self.param["dataset"].objects = self.available_datasets
self.dataset = self.available_datasets[0]
def __setup_watchers(self):
self.param.watch(
self._update_on_dataset,
["dataset"],
queued=True,
precedence=0,
)
self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1)
self.param.watch(
self._update_plot,
["dataset", "metric", "estimators", "plot_view", "mode"],
# ["metric", "estimators", "mode"],
onlychanged=False,
precedence=2,
)
self.param.watch(
self._update_on_estimators,
["estimators"],
queued=True,
precedence=3,
)
def _update_on_dataset(self, *events):
l_dr = self.datasets_[self.dataset]
l_data = l_dr.data()
l_metrics = l_data.columns.unique(0)
l_estimators = l_data.columns.unique(1)
l_valid_estimators = [e for e in l_estimators if e != "ref"]
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
self.param["metric"].objects = l_valid_metrics
self.metric = l_valid_metrics[0]
self.param["estimators"].objects = l_valid_estimators
self.estimators = l_valid_estimators
self.param["plot_view"].objects = ["avg"] + l_valid_views
self.plot_view = "avg"
self.param["mode"].objects = valid_plot_modes["avg"]
self.mode = valid_plot_modes["avg"][0]
self.param["modal_estimators"].objects = l_valid_estimators
self.modal_estimators = []
self.param["modal_plot_view"].objects = l_valid_views
self.modal_plot_view = l_valid_views.copy()
def _update_on_view(self, *events):
self.param["mode"].objects = valid_plot_modes[self.plot_view]
self.mode = valid_plot_modes[self.plot_view][0]
def _update_on_estimators(self, *events):
self.modal_estimators = self.estimators.copy()
def _update_plot(self, *events):
self.plot_pane = build_plot(
datasets=self.datasets_,
dst=self.dataset,
metric=self.metric,
estimators=self.estimators,
view=self.plot_view,
mode=self.mode,
)
def get_plot(self):
return self.plot_pane
def get_param_pane(self):
return self.param_pane
def serve(address="localhost"):
qtv = QuaccTestViewer()
@ -565,19 +15,36 @@ def serve(address="localhost"):
def save_callback(event):
app.open_modal()
def refresh_callback(event):
qtv.update_datasets()
save_button = pn.widgets.Button(
name="Save",
sizing_mode="scale_width",
# name="Save",
icon="device-floppy",
icon_size="16px",
# sizing_mode="scale_width",
button_style="solid",
button_type="success",
)
save_button.on_click(save_callback)
app = pn.template.MaterialTemplate(
refresh_button = pn.widgets.Button(
icon="refresh",
icon_size="16px",
button_style="solid",
)
refresh_button.on_click(refresh_callback)
app = pn.template.FastListTemplate(
title="quacc tests",
sidebar=[save_button, qtv.get_param_pane],
main=[qtv.get_plot],
sidebar=[
pn.FlexBox(save_button, refresh_button, flex_direction="row-reverse"),
qtv.get_param_pane,
],
main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")],
modal=[qtv.modal_pane],
theme=pn.theme.DarkTheme,
theme_toggle=False,
)
app.servable()

248
qcpanel/util.py Normal file
View File

@ -0,0 +1,248 @@
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List
import panel as pn
from quacc.evaluation.comp import CE
from quacc.evaluation.report import DatasetReport
_plot_sizing_mode = "stretch_both"
valid_plot_modes = defaultdict(
lambda: ["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"]
)
valid_plot_modes["avg"] = [
"delta_train",
"stdev_train",
"delta_test",
"stdev_test",
"shift",
"train_table",
"test_table",
"shift_table",
]
def create_cr_plots(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
prev=None,
):
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
idx = _prevs.index(prev)
cr = dr.crs[idx]
estimators = CE.name[estimators]
if mode is None:
mode = valid_plot_modes[str(prev)][0]
_dpi = 112
if mode == "table":
return pn.pane.DataFrame(
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
align="center",
)
elif mode == "shift_table":
return pn.pane.DataFrame(
cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
align="center",
)
else:
return pn.pane.Matplotlib(
cr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
),
tight=True,
format="png",
sizing_mode=_plot_sizing_mode,
# sizing_mode="scale_height",
# sizing_mode="scale_both",
)
def create_avg_plots(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
):
estimators = CE.name[estimators]
if mode is None:
mode = valid_plot_modes["avg"][0]
if mode == "train_table":
return pn.pane.DataFrame(
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean(),
align="center",
)
elif mode == "test_table":
return pn.pane.DataFrame(
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
align="center",
)
elif mode == "shift_table":
return pn.pane.DataFrame(
dr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
align="center",
)
return pn.pane.Matplotlib(
dr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
),
tight=True,
format="png",
# sizing_mode="scale_height",
sizing_mode=_plot_sizing_mode,
# sizing_mode="scale_both",
)
def build_widgets(datasets: Dict[str, DatasetReport]):
available_datasets = list(datasets.keys())
dataset_widget = pn.widgets.Select(
name="dataset",
options=available_datasets,
align="center",
)
_dr = datasets[dataset_widget.value]
_data = _dr.data()
_metrics = _data.columns.unique(0)
_estimators = _data.columns.unique(1)
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
metric_widget = pn.widgets.Select(
name="metric",
value="acc",
options=valid_metrics,
align="center",
)
valid_estimators = [e for e in _estimators if e != "ref"]
estimators_widget = pn.widgets.CheckButtonGroup(
name="estimators",
options=valid_estimators,
value=valid_estimators,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
valid_views = [str(round(cr.train_prev[1] * 100)) for cr in _dr.crs]
view_widget = pn.widgets.RadioButtonGroup(
name="view",
options=valid_views + ["avg"],
value="avg",
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
)
@pn.depends(dataset_widget.param.value, watch=True)
def _update_from_dataset(_dataset):
l_dr = datasets[dataset_widget.value]
l_data = l_dr.data()
l_metrics = l_data.columns.unique(0)
l_estimators = l_data.columns.unique(1)
l_valid_estimators = [e for e in l_estimators if e != "ref"]
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
metric_widget.options = l_valid_metrics
metric_widget.value = l_valid_metrics[0]
estimators_widget.options = l_valid_estimators
estimators_widget.value = l_valid_estimators
view_widget.options = l_valid_views + ["avg"]
view_widget.value = "avg"
plot_mode_widget = pn.widgets.RadioButtonGroup(
name="mode",
value=valid_plot_modes["avg"][0],
options=valid_plot_modes["avg"],
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
@pn.depends(view_widget.param.value, watch=True)
def _update_from_view(_view):
_modes = valid_plot_modes[_view]
plot_mode_widget.options = _modes
plot_mode_widget.value = _modes[0]
widget_pane = pn.Column(
dataset_widget,
metric_widget,
pn.Row(
view_widget,
plot_mode_widget,
),
estimators_widget,
)
return (
widget_pane,
{
"dataset": dataset_widget,
"metric": metric_widget,
"view": view_widget,
"plot_mode": plot_mode_widget,
"estimators": estimators_widget,
},
)
def build_plot(
datasets: Dict[str, DatasetReport],
dst: str,
metric: str,
estimators: List[str],
view: str,
mode: str,
):
_dr = datasets[dst]
if view == "avg":
return create_avg_plots(_dr, mode=mode, metric=metric, estimators=estimators)
else:
prev = int(view)
return create_cr_plots(
_dr, mode=mode, metric=metric, estimators=estimators, prev=prev
)
def explore_datasets(root: Path | str):
if isinstance(root, str):
root = Path(root)
if root.name == "plot":
return []
if not root.exists():
return []
drs = []
for f in os.listdir(root):
if (root / f).is_dir():
drs += explore_datasets(root / f)
elif f == f"{root.name}.pickle":
drs.append(root / f)
# drs.append((str(root),))
return drs

310
qcpanel/viewer.py Normal file
View File

@ -0,0 +1,310 @@
import os
from pathlib import Path
import panel as pn
import param
from qcpanel.util import build_plot, explore_datasets, valid_plot_modes
from quacc.evaluation.comp import CE
from quacc.evaluation.report import DatasetReport
class QuaccTestViewer(param.Parameterized):
dataset = param.Selector()
metric = param.Selector()
estimators = param.ListSelector()
plot_view = param.Selector()
mode = param.Selector()
modal_estimators = param.ListSelector()
modal_plot_view = param.ListSelector()
modal_mode_prev = param.ListSelector(
objects=valid_plot_modes[0], default=valid_plot_modes[0]
)
modal_mode_avg = param.ListSelector(
objects=valid_plot_modes["avg"], default=valid_plot_modes["avg"]
)
param_pane = param.Parameter()
plot_pane = param.Parameter()
modal_pane = param.Parameter()
def __init__(self, **params):
super().__init__(**params)
self.__setup_watchers()
self.update_datasets()
# self._update_on_dataset()
self.__create_param_pane()
self.__create_modal_pane()
def __save_callback(self, event):
_home = Path("output")
_save_input_val = self.save_input.value_input
_config = "default" if len(_save_input_val) == 0 else _save_input_val
base_path = _home / self.dataset / _config
os.makedirs(base_path, exist_ok=True)
base_plot = base_path / "plot"
os.makedirs(base_plot, exist_ok=True)
l_dr = self.datasets_[self.dataset]
res = l_dr.to_md(
conf=_config,
metric=self.metric,
estimators=CE.name[self.modal_estimators],
dr_modes=self.modal_mode_avg,
cr_modes=self.modal_mode_prev,
cr_prevs=self.modal_plot_view,
plot_path=base_plot,
)
with open(base_path / f"{self.metric}.md", "w") as f:
f.write(res)
pn.state.notifications.success(f'"{_config}" successfully saved')
def __create_param_pane(self):
self.dataset_widget = pn.Param(
self,
show_name=False,
parameters=["dataset"],
widgets={"dataset": {"widget_type": pn.widgets.Select}},
)
self.metric_widget = pn.Param(
self,
show_name=False,
parameters=["metric"],
widgets={"metric": {"widget_type": pn.widgets.Select}},
)
self.estimators_widgets = pn.Param(
self,
show_name=False,
parameters=["estimators"],
widgets={
"estimators": {
"widget_type": pn.widgets.CheckButtonGroup,
"orientation": "vertical",
"sizing_mode": "scale_width",
"button_type": "primary",
"button_style": "outline",
}
},
)
self.plot_view_widget = pn.Param(
self,
show_name=False,
parameters=["plot_view"],
widgets={
"plot_view": {
"widget_type": pn.widgets.RadioButtonGroup,
"orientation": "vertical",
"button_type": "primary",
"button_style": "outline",
}
},
)
self.mode_widget = pn.Param(
self,
show_name=False,
parameters=["mode"],
widgets={
"mode": {
"widget_type": pn.widgets.RadioButtonGroup,
"orientation": "vertical",
"sizing_mode": "scale_width",
"button_type": "primary",
"button_style": "outline",
}
},
align="center",
)
self.param_pane = pn.Column(
self.dataset_widget,
self.metric_widget,
pn.Row(
self.plot_view_widget,
self.mode_widget,
),
self.estimators_widgets,
)
def __create_modal_pane(self):
self.modal_estimators_widgets = pn.Param(
self,
show_name=False,
parameters=["modal_estimators"],
widgets={
"modal_estimators": {
"widget_type": pn.widgets.CheckButtonGroup,
"orientation": "vertical",
"sizing_mode": "scale_width",
"button_type": "primary",
"button_style": "outline",
}
},
)
self.modal_plot_view_widget = pn.Param(
self,
show_name=False,
parameters=["modal_plot_view"],
widgets={
"modal_plot_view": {
"widget_type": pn.widgets.CheckButtonGroup,
"orientation": "vertical",
"button_type": "primary",
"button_style": "outline",
}
},
)
self.modal_mode_prev_widget = pn.Param(
self,
show_name=False,
parameters=["modal_mode_prev"],
widgets={
"modal_mode_prev": {
"widget_type": pn.widgets.CheckButtonGroup,
"orientation": "vertical",
"sizing_mode": "scale_width",
"button_type": "primary",
"button_style": "outline",
}
},
align="center",
)
self.modal_mode_avg_widget = pn.Param(
self,
show_name=False,
parameters=["modal_mode_avg"],
widgets={
"modal_mode_avg": {
"widget_type": pn.widgets.CheckButtonGroup,
"orientation": "vertical",
"sizing_mode": "scale_width",
"button_type": "primary",
"button_style": "outline",
}
},
align="center",
)
self.save_input = pn.widgets.TextInput(
name="Configuration Name", placeholder="default", sizing_mode="scale_width"
)
self.save_button = pn.widgets.Button(
name="Save",
sizing_mode="scale_width",
button_style="solid",
button_type="success",
)
self.save_button.on_click(self.__save_callback)
_title_styles = {
"font-size": "14pt",
"font-weight": "bold",
}
self.modal_pane = pn.Column(
pn.Column(
pn.pane.Str("Avg. configuration", styles=_title_styles),
self.modal_mode_avg_widget,
pn.pane.Str("Train prevs. configuration", styles=_title_styles),
pn.Row(
self.modal_plot_view_widget,
self.modal_mode_prev_widget,
),
pn.pane.Str("Estimators configuration", styles=_title_styles),
self.modal_estimators_widgets,
self.save_input,
self.save_button,
pn.Spacer(height=20),
width=450,
align="center",
scroll=True,
),
sizing_mode="stretch_both",
)
def update_datasets(self):
__base_path = "output"
dataset_paths = sorted(
explore_datasets(__base_path), key=lambda t: (-len(t.parts), t)
)
self.datasets_ = {
str(dp.parent.relative_to(Path(__base_path))): DatasetReport.unpickle(dp)
for dp in dataset_paths
}
self.available_datasets = list(self.datasets_.keys())
self.param["dataset"].objects = self.available_datasets
self.dataset = self.available_datasets[0]
def __setup_watchers(self):
self.param.watch(
self._update_on_dataset,
["dataset"],
queued=True,
precedence=0,
)
self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1)
self.param.watch(
self._update_plot,
["dataset", "metric", "estimators", "plot_view", "mode"],
# ["metric", "estimators", "mode"],
onlychanged=False,
precedence=2,
)
self.param.watch(
self._update_on_estimators,
["estimators"],
queued=True,
precedence=3,
)
def _update_on_dataset(self, *events):
l_dr = self.datasets_[self.dataset]
l_data = l_dr.data()
l_metrics = l_data.columns.unique(0)
l_estimators = l_data.columns.unique(1)
l_valid_estimators = [e for e in l_estimators if e != "ref"]
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
self.param["metric"].objects = l_valid_metrics
self.metric = l_valid_metrics[0]
self.param["estimators"].objects = l_valid_estimators
self.estimators = l_valid_estimators
self.param["plot_view"].objects = ["avg"] + l_valid_views
self.plot_view = "avg"
self.param["mode"].objects = valid_plot_modes["avg"]
self.mode = valid_plot_modes["avg"][0]
self.param["modal_estimators"].objects = l_valid_estimators
self.modal_estimators = []
self.param["modal_plot_view"].objects = l_valid_views
self.modal_plot_view = l_valid_views.copy()
def _update_on_view(self, *events):
self.param["mode"].objects = valid_plot_modes[self.plot_view]
self.mode = valid_plot_modes[self.plot_view][0]
def _update_on_estimators(self, *events):
self.modal_estimators = self.estimators.copy()
def _update_plot(self, *events):
self.plot_pane = build_plot(
datasets=self.datasets_,
dst=self.dataset,
metric=self.metric,
estimators=self.estimators,
view=self.plot_view,
mode=self.mode,
)
def get_plot(self):
return self.plot_pane
def get_param_pane(self):
return self.param_pane