diff --git a/qcpanel/run.py b/qcpanel/run.py index afca524..8a3958f 100644 --- a/qcpanel/run.py +++ b/qcpanel/run.py @@ -1,563 +1,13 @@ import argparse -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List import panel as pn -import param -from quacc import utils -from quacc.evaluation.comp import CE -from quacc.evaluation.report import DatasetReport +from qcpanel.viewer import QuaccTestViewer -pn.config.design = pn.theme.Bootstrap -pn.config.theme = "dark" +# pn.config.design = pn.theme.Bootstrap +# pn.config.theme = "dark" pn.config.notifications = True -valid_plot_modes = defaultdict( - lambda: ["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"] -) -valid_plot_modes["avg"] = [ - "delta_train", - "stdev_train", - "delta_test", - "stdev_test", - "shift", - "train_table", - "test_table", - "shift_table", -] - - -def create_cr_plots( - dr: DatasetReport, - mode="delta", - metric="acc", - estimators=None, - prev=None, -): - _prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] - idx = _prevs.index(prev) - cr = dr.crs[idx] - estimators = CE.name[estimators] - if mode is None: - mode = valid_plot_modes[str(prev)][0] - _dpi = 112 - if mode == "table": - return pn.pane.DataFrame( - cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(), - align="center", - ) - elif mode == "shift_table": - return pn.pane.DataFrame( - cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(), - align="center", - ) - else: - return pn.pane.Matplotlib( - cr.get_plots( - mode=mode, - metric=metric, - estimators=estimators, - conf="panel", - return_fig=True, - ), - tight=True, - format="png", - sizing_mode="scale_height", - # sizing_mode="scale_both", - ) - - -def create_avg_plots( - dr: DatasetReport, - mode="delta", - metric="acc", - estimators=None, -): - estimators = CE.name[estimators] - if mode is None: - mode = valid_plot_modes["avg"][0] - - if mode == "train_table": - return pn.pane.DataFrame( - dr.data(metric=metric, estimators=estimators).groupby(level=1).mean(), - align="center", - ) - elif mode == "test_table": - return pn.pane.DataFrame( - dr.data(metric=metric, estimators=estimators).groupby(level=0).mean(), - align="center", - ) - elif mode == "shift_table": - return pn.pane.DataFrame( - dr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(), - align="center", - ) - return pn.pane.Matplotlib( - dr.get_plots( - mode=mode, - metric=metric, - estimators=estimators, - conf="panel", - return_fig=True, - ), - tight=True, - format="png", - sizing_mode="scale_height", - # sizing_mode="scale_both", - ) - - -def build_widgets(datasets: Dict[str, DatasetReport]): - available_datasets = list(datasets.keys()) - dataset_widget = pn.widgets.Select( - name="dataset", - options=available_datasets, - align="center", - ) - - _dr = datasets[dataset_widget.value] - _data = _dr.data() - _metrics = _data.columns.unique(0) - _estimators = _data.columns.unique(1) - - valid_metrics = [m for m in _metrics if not m.endswith("_score")] - metric_widget = pn.widgets.Select( - name="metric", - value="acc", - options=valid_metrics, - align="center", - ) - - valid_estimators = [e for e in _estimators if e != "ref"] - estimators_widget = pn.widgets.CheckButtonGroup( - name="estimators", - options=valid_estimators, - value=valid_estimators, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - sizing_mode="scale_width", - ) - - valid_views = [str(round(cr.train_prev[1] * 100)) for cr in _dr.crs] - view_widget = pn.widgets.RadioButtonGroup( - name="view", - options=valid_views + ["avg"], - value="avg", - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - ) - - @pn.depends(dataset_widget.param.value, watch=True) - def _update_from_dataset(_dataset): - l_dr = datasets[dataset_widget.value] - l_data = l_dr.data() - l_metrics = l_data.columns.unique(0) - l_estimators = l_data.columns.unique(1) - - l_valid_estimators = [e for e in l_estimators if e != "ref"] - l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")] - l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs] - - metric_widget.options = l_valid_metrics - metric_widget.value = l_valid_metrics[0] - - estimators_widget.options = l_valid_estimators - estimators_widget.value = l_valid_estimators - - view_widget.options = l_valid_views + ["avg"] - view_widget.value = "avg" - - plot_mode_widget = pn.widgets.RadioButtonGroup( - name="mode", - value=valid_plot_modes["avg"][0], - options=valid_plot_modes["avg"], - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - sizing_mode="scale_width", - ) - - @pn.depends(view_widget.param.value, watch=True) - def _update_from_view(_view): - _modes = valid_plot_modes[_view] - plot_mode_widget.options = _modes - plot_mode_widget.value = _modes[0] - - widget_pane = pn.Column( - dataset_widget, - metric_widget, - pn.Row( - view_widget, - plot_mode_widget, - ), - estimators_widget, - ) - - return ( - widget_pane, - { - "dataset": dataset_widget, - "metric": metric_widget, - "view": view_widget, - "plot_mode": plot_mode_widget, - "estimators": estimators_widget, - }, - ) - - -def build_plot( - datasets: Dict[str, DatasetReport], - dst: str, - metric: str, - estimators: List[str], - view: str, - mode: str, -): - _dr = datasets[dst] - if view == "avg": - return create_avg_plots(_dr, mode=mode, metric=metric, estimators=estimators) - else: - prev = int(view) - return create_cr_plots( - _dr, mode=mode, metric=metric, estimators=estimators, prev=prev - ) - - -def build_modal(datasets, dst, metric): - return pn.pane.Str(f"{dst}_{metric}") - - -def build_save_pane(datasets: Dict[str, DatasetReport], dst: str, metric: str): - return pn.pane.Str(f"{datasets[dst]}_{metric}") - - -def explore_datasets(root: Path | str): - if isinstance(root, str): - root = Path(root) - - if root.name == "plot": - return [] - - if not root.exists(): - return [] - - drs = [] - for f in os.listdir(root): - if (root / f).is_dir(): - drs += explore_datasets(root / f) - elif f == f"{root.name}.pickle": - drs.append(root / f) - # drs.append((str(root),)) - - return drs - - -class QuaccTestViewer(param.Parameterized): - dataset = param.Selector() - metric = param.Selector() - estimators = param.ListSelector() - plot_view = param.Selector() - mode = param.Selector() - - modal_estimators = param.ListSelector() - modal_plot_view = param.ListSelector() - modal_mode_prev = param.ListSelector( - objects=valid_plot_modes[0], default=valid_plot_modes[0] - ) - modal_mode_avg = param.ListSelector( - objects=valid_plot_modes["avg"], default=valid_plot_modes["avg"] - ) - - param_pane = param.Parameter() - plot_pane = param.Parameter() - modal_pane = param.Parameter() - - def __init__(self, **params): - super().__init__(**params) - - self.__setup_watchers() - self.__import_datasets() - # self._update_on_dataset() - self.__create_param_pane() - self.__create_modal_pane() - - def __save_callback(self, event): - _home = utils.get_quacc_home() - _save_input_val = self.save_input.value_input - _config = "default" if len(_save_input_val) == 0 else _save_input_val - base_path = _home / "output" / self.dataset / _config - os.makedirs(base_path, exist_ok=True) - base_plot = base_path / "plot" - os.makedirs(base_plot, exist_ok=True) - - l_dr = self.datasets_[self.dataset] - res = l_dr.to_md( - conf=_config, - metric=self.metric, - estimators=CE.name[self.modal_estimators], - dr_modes=self.modal_mode_avg, - cr_modes=self.modal_mode_prev, - plot_path=base_plot, - ) - with open(base_path / f"{self.metric}.md", "w") as f: - f.write(res) - - pn.state.notifications.success(f'"{_config}" successfully saved') - - def __create_param_pane(self): - self.dataset_widget = pn.Param( - self, - show_name=False, - parameters=["dataset"], - widgets={"dataset": {"widget_type": pn.widgets.Select}}, - ) - self.metric_widget = pn.Param( - self, - show_name=False, - parameters=["metric"], - widgets={"metric": {"widget_type": pn.widgets.Select}}, - ) - self.estimators_widgets = pn.Param( - self, - show_name=False, - parameters=["estimators"], - widgets={ - "estimators": { - "widget_type": pn.widgets.CheckButtonGroup, - "orientation": "vertical", - "sizing_mode": "scale_width", - "button_type": "primary", - "button_style": "outline", - } - }, - ) - self.plot_view_widget = pn.Param( - self, - show_name=False, - parameters=["plot_view"], - widgets={ - "plot_view": { - "widget_type": pn.widgets.RadioButtonGroup, - "orientation": "vertical", - "button_type": "primary", - "button_style": "outline", - } - }, - ) - self.mode_widget = pn.Param( - self, - show_name=False, - parameters=["mode"], - widgets={ - "mode": { - "widget_type": pn.widgets.RadioButtonGroup, - "orientation": "vertical", - "sizing_mode": "scale_width", - "button_type": "primary", - "button_style": "outline", - } - }, - align="center", - ) - self.param_pane = pn.Column( - self.dataset_widget, - self.metric_widget, - pn.Row( - self.plot_view_widget, - self.mode_widget, - ), - self.estimators_widgets, - ) - - def __create_modal_pane(self): - self.modal_estimators_widgets = pn.Param( - self, - show_name=False, - parameters=["modal_estimators"], - widgets={ - "modal_estimators": { - "widget_type": pn.widgets.CheckButtonGroup, - "orientation": "vertical", - "sizing_mode": "scale_width", - "button_type": "primary", - "button_style": "outline", - } - }, - ) - self.modal_plot_view_widget = pn.Param( - self, - show_name=False, - parameters=["modal_plot_view"], - widgets={ - "modal_plot_view": { - "widget_type": pn.widgets.CheckButtonGroup, - "orientation": "vertical", - "button_type": "primary", - "button_style": "outline", - } - }, - ) - self.modal_mode_prev_widget = pn.Param( - self, - show_name=False, - parameters=["modal_mode_prev"], - widgets={ - "modal_mode_prev": { - "widget_type": pn.widgets.CheckButtonGroup, - "orientation": "vertical", - "sizing_mode": "scale_width", - "button_type": "primary", - "button_style": "outline", - } - }, - align="center", - ) - self.modal_mode_avg_widget = pn.Param( - self, - show_name=False, - parameters=["modal_mode_avg"], - widgets={ - "modal_mode_avg": { - "widget_type": pn.widgets.CheckButtonGroup, - "orientation": "vertical", - "sizing_mode": "scale_width", - "button_type": "primary", - "button_style": "outline", - } - }, - align="center", - ) - - self.save_input = pn.widgets.TextInput( - name="Configuration Name", placeholder="default", sizing_mode="scale_width" - ) - self.save_button = pn.widgets.Button( - name="Saverrr", - sizing_mode="scale_width", - button_style="solid", - button_type="success", - ) - self.save_button.on_click(self.__save_callback) - - _title_styles = { - "font-size": "14pt", - "font-weight": "bold", - } - self.modal_pane = pn.Column( - pn.Column( - pn.pane.Str("Avg. configuration", styles=_title_styles), - self.modal_mode_avg_widget, - pn.pane.Str("Train prevs. configuration", styles=_title_styles), - pn.Row( - self.modal_plot_view_widget, - self.modal_mode_prev_widget, - ), - pn.pane.Str("Estimators configuration", styles=_title_styles), - self.modal_estimators_widgets, - self.save_input, - self.save_button, - width=450, - align="center", - scroll=True, - ), - sizing_mode="stretch_both", - ) - - def __import_datasets(self): - __base_path = "output" - dataset_paths = sorted( - explore_datasets(__base_path), key=lambda t: (-len(t.parts), t) - ) - self.datasets_ = { - str(dp.parent.relative_to(Path(__base_path))): DatasetReport.unpickle(dp) - for dp in dataset_paths - } - - self.available_datasets = list(self.datasets_.keys()) - self.param["dataset"].objects = self.available_datasets - self.dataset = self.available_datasets[0] - - def __setup_watchers(self): - self.param.watch( - self._update_on_dataset, - ["dataset"], - queued=True, - precedence=0, - ) - self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1) - self.param.watch( - self._update_plot, - ["dataset", "metric", "estimators", "plot_view", "mode"], - # ["metric", "estimators", "mode"], - onlychanged=False, - precedence=2, - ) - self.param.watch( - self._update_on_estimators, - ["estimators"], - queued=True, - precedence=3, - ) - - def _update_on_dataset(self, *events): - l_dr = self.datasets_[self.dataset] - l_data = l_dr.data() - l_metrics = l_data.columns.unique(0) - l_estimators = l_data.columns.unique(1) - - l_valid_estimators = [e for e in l_estimators if e != "ref"] - l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")] - l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs] - - self.param["metric"].objects = l_valid_metrics - self.metric = l_valid_metrics[0] - - self.param["estimators"].objects = l_valid_estimators - self.estimators = l_valid_estimators - - self.param["plot_view"].objects = ["avg"] + l_valid_views - self.plot_view = "avg" - - self.param["mode"].objects = valid_plot_modes["avg"] - self.mode = valid_plot_modes["avg"][0] - - self.param["modal_estimators"].objects = l_valid_estimators - self.modal_estimators = [] - - self.param["modal_plot_view"].objects = l_valid_views - self.modal_plot_view = l_valid_views.copy() - - def _update_on_view(self, *events): - self.param["mode"].objects = valid_plot_modes[self.plot_view] - self.mode = valid_plot_modes[self.plot_view][0] - - def _update_on_estimators(self, *events): - self.modal_estimators = self.estimators.copy() - - def _update_plot(self, *events): - self.plot_pane = build_plot( - datasets=self.datasets_, - dst=self.dataset, - metric=self.metric, - estimators=self.estimators, - view=self.plot_view, - mode=self.mode, - ) - - def get_plot(self): - return self.plot_pane - - def get_param_pane(self): - return self.param_pane - def serve(address="localhost"): qtv = QuaccTestViewer() @@ -565,19 +15,36 @@ def serve(address="localhost"): def save_callback(event): app.open_modal() + def refresh_callback(event): + qtv.update_datasets() + save_button = pn.widgets.Button( - name="Save", - sizing_mode="scale_width", + # name="Save", + icon="device-floppy", + icon_size="16px", + # sizing_mode="scale_width", button_style="solid", button_type="success", ) save_button.on_click(save_callback) - app = pn.template.MaterialTemplate( + refresh_button = pn.widgets.Button( + icon="refresh", + icon_size="16px", + button_style="solid", + ) + refresh_button.on_click(refresh_callback) + + app = pn.template.FastListTemplate( title="quacc tests", - sidebar=[save_button, qtv.get_param_pane], - main=[qtv.get_plot], + sidebar=[ + pn.FlexBox(save_button, refresh_button, flex_direction="row-reverse"), + qtv.get_param_pane, + ], + main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")], modal=[qtv.modal_pane], + theme=pn.theme.DarkTheme, + theme_toggle=False, ) app.servable() diff --git a/qcpanel/util.py b/qcpanel/util.py new file mode 100644 index 0000000..27e767f --- /dev/null +++ b/qcpanel/util.py @@ -0,0 +1,248 @@ +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List + +import panel as pn + +from quacc.evaluation.comp import CE +from quacc.evaluation.report import DatasetReport + +_plot_sizing_mode = "stretch_both" +valid_plot_modes = defaultdict( + lambda: ["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"] +) +valid_plot_modes["avg"] = [ + "delta_train", + "stdev_train", + "delta_test", + "stdev_test", + "shift", + "train_table", + "test_table", + "shift_table", +] + + +def create_cr_plots( + dr: DatasetReport, + mode="delta", + metric="acc", + estimators=None, + prev=None, +): + _prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] + idx = _prevs.index(prev) + cr = dr.crs[idx] + estimators = CE.name[estimators] + if mode is None: + mode = valid_plot_modes[str(prev)][0] + _dpi = 112 + if mode == "table": + return pn.pane.DataFrame( + cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(), + align="center", + ) + elif mode == "shift_table": + return pn.pane.DataFrame( + cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(), + align="center", + ) + else: + return pn.pane.Matplotlib( + cr.get_plots( + mode=mode, + metric=metric, + estimators=estimators, + conf="panel", + return_fig=True, + ), + tight=True, + format="png", + sizing_mode=_plot_sizing_mode, + # sizing_mode="scale_height", + # sizing_mode="scale_both", + ) + + +def create_avg_plots( + dr: DatasetReport, + mode="delta", + metric="acc", + estimators=None, +): + estimators = CE.name[estimators] + if mode is None: + mode = valid_plot_modes["avg"][0] + + if mode == "train_table": + return pn.pane.DataFrame( + dr.data(metric=metric, estimators=estimators).groupby(level=1).mean(), + align="center", + ) + elif mode == "test_table": + return pn.pane.DataFrame( + dr.data(metric=metric, estimators=estimators).groupby(level=0).mean(), + align="center", + ) + elif mode == "shift_table": + return pn.pane.DataFrame( + dr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(), + align="center", + ) + return pn.pane.Matplotlib( + dr.get_plots( + mode=mode, + metric=metric, + estimators=estimators, + conf="panel", + return_fig=True, + ), + tight=True, + format="png", + # sizing_mode="scale_height", + sizing_mode=_plot_sizing_mode, + # sizing_mode="scale_both", + ) + + +def build_widgets(datasets: Dict[str, DatasetReport]): + available_datasets = list(datasets.keys()) + dataset_widget = pn.widgets.Select( + name="dataset", + options=available_datasets, + align="center", + ) + + _dr = datasets[dataset_widget.value] + _data = _dr.data() + _metrics = _data.columns.unique(0) + _estimators = _data.columns.unique(1) + + valid_metrics = [m for m in _metrics if not m.endswith("_score")] + metric_widget = pn.widgets.Select( + name="metric", + value="acc", + options=valid_metrics, + align="center", + ) + + valid_estimators = [e for e in _estimators if e != "ref"] + estimators_widget = pn.widgets.CheckButtonGroup( + name="estimators", + options=valid_estimators, + value=valid_estimators, + button_style="outline", + button_type="primary", + align="center", + orientation="vertical", + sizing_mode="scale_width", + ) + + valid_views = [str(round(cr.train_prev[1] * 100)) for cr in _dr.crs] + view_widget = pn.widgets.RadioButtonGroup( + name="view", + options=valid_views + ["avg"], + value="avg", + button_style="outline", + button_type="primary", + align="center", + orientation="vertical", + ) + + @pn.depends(dataset_widget.param.value, watch=True) + def _update_from_dataset(_dataset): + l_dr = datasets[dataset_widget.value] + l_data = l_dr.data() + l_metrics = l_data.columns.unique(0) + l_estimators = l_data.columns.unique(1) + + l_valid_estimators = [e for e in l_estimators if e != "ref"] + l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")] + l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs] + + metric_widget.options = l_valid_metrics + metric_widget.value = l_valid_metrics[0] + + estimators_widget.options = l_valid_estimators + estimators_widget.value = l_valid_estimators + + view_widget.options = l_valid_views + ["avg"] + view_widget.value = "avg" + + plot_mode_widget = pn.widgets.RadioButtonGroup( + name="mode", + value=valid_plot_modes["avg"][0], + options=valid_plot_modes["avg"], + button_style="outline", + button_type="primary", + align="center", + orientation="vertical", + sizing_mode="scale_width", + ) + + @pn.depends(view_widget.param.value, watch=True) + def _update_from_view(_view): + _modes = valid_plot_modes[_view] + plot_mode_widget.options = _modes + plot_mode_widget.value = _modes[0] + + widget_pane = pn.Column( + dataset_widget, + metric_widget, + pn.Row( + view_widget, + plot_mode_widget, + ), + estimators_widget, + ) + + return ( + widget_pane, + { + "dataset": dataset_widget, + "metric": metric_widget, + "view": view_widget, + "plot_mode": plot_mode_widget, + "estimators": estimators_widget, + }, + ) + + +def build_plot( + datasets: Dict[str, DatasetReport], + dst: str, + metric: str, + estimators: List[str], + view: str, + mode: str, +): + _dr = datasets[dst] + if view == "avg": + return create_avg_plots(_dr, mode=mode, metric=metric, estimators=estimators) + else: + prev = int(view) + return create_cr_plots( + _dr, mode=mode, metric=metric, estimators=estimators, prev=prev + ) + + +def explore_datasets(root: Path | str): + if isinstance(root, str): + root = Path(root) + + if root.name == "plot": + return [] + + if not root.exists(): + return [] + + drs = [] + for f in os.listdir(root): + if (root / f).is_dir(): + drs += explore_datasets(root / f) + elif f == f"{root.name}.pickle": + drs.append(root / f) + # drs.append((str(root),)) + + return drs diff --git a/qcpanel/viewer.py b/qcpanel/viewer.py new file mode 100644 index 0000000..640745d --- /dev/null +++ b/qcpanel/viewer.py @@ -0,0 +1,310 @@ +import os +from pathlib import Path + +import panel as pn +import param + +from qcpanel.util import build_plot, explore_datasets, valid_plot_modes +from quacc.evaluation.comp import CE +from quacc.evaluation.report import DatasetReport + + +class QuaccTestViewer(param.Parameterized): + dataset = param.Selector() + metric = param.Selector() + estimators = param.ListSelector() + plot_view = param.Selector() + mode = param.Selector() + + modal_estimators = param.ListSelector() + modal_plot_view = param.ListSelector() + modal_mode_prev = param.ListSelector( + objects=valid_plot_modes[0], default=valid_plot_modes[0] + ) + modal_mode_avg = param.ListSelector( + objects=valid_plot_modes["avg"], default=valid_plot_modes["avg"] + ) + + param_pane = param.Parameter() + plot_pane = param.Parameter() + modal_pane = param.Parameter() + + def __init__(self, **params): + super().__init__(**params) + + self.__setup_watchers() + self.update_datasets() + # self._update_on_dataset() + self.__create_param_pane() + self.__create_modal_pane() + + def __save_callback(self, event): + _home = Path("output") + _save_input_val = self.save_input.value_input + _config = "default" if len(_save_input_val) == 0 else _save_input_val + base_path = _home / self.dataset / _config + os.makedirs(base_path, exist_ok=True) + base_plot = base_path / "plot" + os.makedirs(base_plot, exist_ok=True) + + l_dr = self.datasets_[self.dataset] + res = l_dr.to_md( + conf=_config, + metric=self.metric, + estimators=CE.name[self.modal_estimators], + dr_modes=self.modal_mode_avg, + cr_modes=self.modal_mode_prev, + cr_prevs=self.modal_plot_view, + plot_path=base_plot, + ) + with open(base_path / f"{self.metric}.md", "w") as f: + f.write(res) + + pn.state.notifications.success(f'"{_config}" successfully saved') + + def __create_param_pane(self): + self.dataset_widget = pn.Param( + self, + show_name=False, + parameters=["dataset"], + widgets={"dataset": {"widget_type": pn.widgets.Select}}, + ) + self.metric_widget = pn.Param( + self, + show_name=False, + parameters=["metric"], + widgets={"metric": {"widget_type": pn.widgets.Select}}, + ) + self.estimators_widgets = pn.Param( + self, + show_name=False, + parameters=["estimators"], + widgets={ + "estimators": { + "widget_type": pn.widgets.CheckButtonGroup, + "orientation": "vertical", + "sizing_mode": "scale_width", + "button_type": "primary", + "button_style": "outline", + } + }, + ) + self.plot_view_widget = pn.Param( + self, + show_name=False, + parameters=["plot_view"], + widgets={ + "plot_view": { + "widget_type": pn.widgets.RadioButtonGroup, + "orientation": "vertical", + "button_type": "primary", + "button_style": "outline", + } + }, + ) + self.mode_widget = pn.Param( + self, + show_name=False, + parameters=["mode"], + widgets={ + "mode": { + "widget_type": pn.widgets.RadioButtonGroup, + "orientation": "vertical", + "sizing_mode": "scale_width", + "button_type": "primary", + "button_style": "outline", + } + }, + align="center", + ) + self.param_pane = pn.Column( + self.dataset_widget, + self.metric_widget, + pn.Row( + self.plot_view_widget, + self.mode_widget, + ), + self.estimators_widgets, + ) + + def __create_modal_pane(self): + self.modal_estimators_widgets = pn.Param( + self, + show_name=False, + parameters=["modal_estimators"], + widgets={ + "modal_estimators": { + "widget_type": pn.widgets.CheckButtonGroup, + "orientation": "vertical", + "sizing_mode": "scale_width", + "button_type": "primary", + "button_style": "outline", + } + }, + ) + self.modal_plot_view_widget = pn.Param( + self, + show_name=False, + parameters=["modal_plot_view"], + widgets={ + "modal_plot_view": { + "widget_type": pn.widgets.CheckButtonGroup, + "orientation": "vertical", + "button_type": "primary", + "button_style": "outline", + } + }, + ) + self.modal_mode_prev_widget = pn.Param( + self, + show_name=False, + parameters=["modal_mode_prev"], + widgets={ + "modal_mode_prev": { + "widget_type": pn.widgets.CheckButtonGroup, + "orientation": "vertical", + "sizing_mode": "scale_width", + "button_type": "primary", + "button_style": "outline", + } + }, + align="center", + ) + self.modal_mode_avg_widget = pn.Param( + self, + show_name=False, + parameters=["modal_mode_avg"], + widgets={ + "modal_mode_avg": { + "widget_type": pn.widgets.CheckButtonGroup, + "orientation": "vertical", + "sizing_mode": "scale_width", + "button_type": "primary", + "button_style": "outline", + } + }, + align="center", + ) + + self.save_input = pn.widgets.TextInput( + name="Configuration Name", placeholder="default", sizing_mode="scale_width" + ) + self.save_button = pn.widgets.Button( + name="Save", + sizing_mode="scale_width", + button_style="solid", + button_type="success", + ) + self.save_button.on_click(self.__save_callback) + + _title_styles = { + "font-size": "14pt", + "font-weight": "bold", + } + self.modal_pane = pn.Column( + pn.Column( + pn.pane.Str("Avg. configuration", styles=_title_styles), + self.modal_mode_avg_widget, + pn.pane.Str("Train prevs. configuration", styles=_title_styles), + pn.Row( + self.modal_plot_view_widget, + self.modal_mode_prev_widget, + ), + pn.pane.Str("Estimators configuration", styles=_title_styles), + self.modal_estimators_widgets, + self.save_input, + self.save_button, + pn.Spacer(height=20), + width=450, + align="center", + scroll=True, + ), + sizing_mode="stretch_both", + ) + + def update_datasets(self): + __base_path = "output" + dataset_paths = sorted( + explore_datasets(__base_path), key=lambda t: (-len(t.parts), t) + ) + self.datasets_ = { + str(dp.parent.relative_to(Path(__base_path))): DatasetReport.unpickle(dp) + for dp in dataset_paths + } + + self.available_datasets = list(self.datasets_.keys()) + self.param["dataset"].objects = self.available_datasets + self.dataset = self.available_datasets[0] + + def __setup_watchers(self): + self.param.watch( + self._update_on_dataset, + ["dataset"], + queued=True, + precedence=0, + ) + self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1) + self.param.watch( + self._update_plot, + ["dataset", "metric", "estimators", "plot_view", "mode"], + # ["metric", "estimators", "mode"], + onlychanged=False, + precedence=2, + ) + self.param.watch( + self._update_on_estimators, + ["estimators"], + queued=True, + precedence=3, + ) + + def _update_on_dataset(self, *events): + l_dr = self.datasets_[self.dataset] + l_data = l_dr.data() + l_metrics = l_data.columns.unique(0) + l_estimators = l_data.columns.unique(1) + + l_valid_estimators = [e for e in l_estimators if e != "ref"] + l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")] + l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs] + + self.param["metric"].objects = l_valid_metrics + self.metric = l_valid_metrics[0] + + self.param["estimators"].objects = l_valid_estimators + self.estimators = l_valid_estimators + + self.param["plot_view"].objects = ["avg"] + l_valid_views + self.plot_view = "avg" + + self.param["mode"].objects = valid_plot_modes["avg"] + self.mode = valid_plot_modes["avg"][0] + + self.param["modal_estimators"].objects = l_valid_estimators + self.modal_estimators = [] + + self.param["modal_plot_view"].objects = l_valid_views + self.modal_plot_view = l_valid_views.copy() + + def _update_on_view(self, *events): + self.param["mode"].objects = valid_plot_modes[self.plot_view] + self.mode = valid_plot_modes[self.plot_view][0] + + def _update_on_estimators(self, *events): + self.modal_estimators = self.estimators.copy() + + def _update_plot(self, *events): + self.plot_pane = build_plot( + datasets=self.datasets_, + dst=self.dataset, + metric=self.metric, + estimators=self.estimators, + view=self.plot_view, + mode=self.mode, + ) + + def get_plot(self): + return self.plot_pane + + def get_param_pane(self): + return self.param_pane