diff --git a/qcpanel/run.py b/qcpanel/run.py index 8a3958f..bb94a70 100644 --- a/qcpanel/run.py +++ b/qcpanel/run.py @@ -9,8 +9,23 @@ from qcpanel.viewer import QuaccTestViewer pn.config.notifications = True -def serve(address="localhost"): - qtv = QuaccTestViewer() +def app_instance(): + param_init = { + k: v + for k, v in pn.state.location.query_params.items() + if k in ["dataset", "metric", "plot_view", "mode", "estimators"] + } + qtv = QuaccTestViewer(param_init=param_init) + pn.state.location.sync( + qtv, + { + "dataset": "dataset", + "metric": "metric", + "plot_view": "plot_view", + "mode": "mode", + "estimators": "estimators", + }, + ) def save_callback(event): app.open_modal() @@ -48,13 +63,17 @@ def serve(address="localhost"): ) app.servable() + return app + + +def serve(address="localhost"): __port = 33420 __allowed = [address] if address == "localhost": __allowed.append("127.0.0.1") pn.serve( - app, + app_instance, autoreload=True, port=__port, show=False, @@ -76,4 +95,4 @@ def run(): if __name__ == "__main__": - serve() + run() diff --git a/qcpanel/util.py b/qcpanel/util.py index 728932c..7b7d69f 100644 --- a/qcpanel/util.py +++ b/qcpanel/util.py @@ -1,7 +1,6 @@ import os from collections import defaultdict from pathlib import Path -from typing import Dict, List import panel as pn @@ -10,118 +9,112 @@ from quacc.evaluation.report import DatasetReport _plot_sizing_mode = "stretch_both" valid_plot_modes = defaultdict( - lambda: ["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"] + lambda: [ + "delta_train", + "stdev_train", + "train_table", + "shift", + "shift_table", + "diagonal", + ] ) valid_plot_modes["avg"] = [ "delta_train", "stdev_train", + "train_table", + "shift", + "shift_table", "delta_test", "stdev_test", - "shift", - "train_table", "test_table", - "shift_table", ] -def create_cr_plots( +def create_plots( dr: DatasetReport, mode="delta", metric="acc", estimators=None, - prev=None, + plot_view=None, ): _prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] - idx = _prevs.index(prev) - cr = dr.crs[idx] estimators = CE.name[estimators] if mode is None: - mode = valid_plot_modes[str(prev)][0] + mode = valid_plot_modes[plot_view][0] _dpi = 112 - if mode == "table": - return pn.pane.DataFrame( - cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(), - align="center", - ) - elif mode == "shift_table": - return pn.pane.DataFrame( - cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(), - align="center", - ) - else: - return pn.pane.Matplotlib( - cr.get_plots( + match (plot_view, mode): + case ("avg", "train_table"): + _data = ( + dr.data(metric=metric, estimators=estimators).groupby(level=1).mean() + ) + return pn.pane.DataFrame(_data, align="center") if not _data.empty else None + case ("avg", "test_table"): + _data = ( + dr.data(metric=metric, estimators=estimators).groupby(level=0).mean() + ) + return pn.pane.DataFrame(_data, align="center") if not _data.empty else None + case ("avg", "shift_table"): + _data = ( + dr.shift_data(metric=metric, estimators=estimators) + .groupby(level=0) + .mean() + ) + return pn.pane.DataFrame(_data, align="center") if not _data.empty else None + case ("avg", _ as plot_mode): + _plot = dr.get_plots( mode=mode, metric=metric, estimators=estimators, conf="panel", return_fig=True, - ), - tight=True, - format="png", - sizing_mode=_plot_sizing_mode, - # sizing_mode="scale_height", - # sizing_mode="scale_both", - ) - - -def create_avg_plots( - dr: DatasetReport, - mode="delta", - metric="acc", - estimators=None, -): - estimators = CE.name[estimators] - if mode is None: - mode = valid_plot_modes["avg"][0] - - if mode == "train_table": - return pn.pane.DataFrame( - dr.data(metric=metric, estimators=estimators).groupby(level=1).mean(), - align="center", - ) - elif mode == "test_table": - return pn.pane.DataFrame( - dr.data(metric=metric, estimators=estimators).groupby(level=0).mean(), - align="center", - ) - elif mode == "shift_table": - return pn.pane.DataFrame( - dr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(), - align="center", - ) - return pn.pane.Matplotlib( - dr.get_plots( - mode=mode, - metric=metric, - estimators=estimators, - conf="panel", - return_fig=True, - ), - tight=True, - format="png", - # sizing_mode="scale_height", - sizing_mode=_plot_sizing_mode, - # sizing_mode="scale_both", - ) - - -def build_plot( - datasets: Dict[str, DatasetReport], - dst: str, - metric: str, - estimators: List[str], - view: str, - mode: str, -): - _dr = datasets[dst] - if view == "avg": - return create_avg_plots(_dr, mode=mode, metric=metric, estimators=estimators) - else: - prev = int(view) - return create_cr_plots( - _dr, mode=mode, metric=metric, estimators=estimators, prev=prev - ) + ) + return ( + pn.pane.Matplotlib( + _plot, + tight=True, + format="png", + # sizing_mode="scale_height", + sizing_mode=_plot_sizing_mode, + # sizing_mode="scale_both", + ) + if _plot is not None + else None + ) + case (_, "train_table"): + cr = dr.crs[_prevs.index(int(plot_view))] + _data = ( + cr.data(metric=metric, estimators=estimators).groupby(level=0).mean() + ) + return pn.pane.DataFrame(_data, align="center") if not _data.empty else None + case (_, "shift_table"): + cr = dr.crs[_prevs.index(int(plot_view))] + _data = ( + cr.shift_data(metric=metric, estimators=estimators) + .groupby(level=0) + .mean() + ) + return pn.pane.DataFrame(_data, align="center") if not _data.empty else None + case (_, _ as plot_mode): + cr = dr.crs[_prevs.index(int(plot_view))] + _plot = cr.get_plots( + mode=plot_mode, + metric=metric, + estimators=estimators, + conf="panel", + return_fig=True, + ) + return ( + pn.pane.Matplotlib( + _plot, + tight=True, + format="png", + sizing_mode=_plot_sizing_mode, + # sizing_mode="scale_height", + # sizing_mode="scale_both", + ) + if _plot is not None + else None + ) def explore_datasets(root: Path | str): diff --git a/qcpanel/viewer.py b/qcpanel/viewer.py index 640745d..b298151 100644 --- a/qcpanel/viewer.py +++ b/qcpanel/viewer.py @@ -1,10 +1,12 @@ import os from pathlib import Path +import numpy as np +import pandas as pd import panel as pn import param -from qcpanel.util import build_plot, explore_datasets, valid_plot_modes +from qcpanel.util import create_plots, explore_datasets, valid_plot_modes from quacc.evaluation.comp import CE from quacc.evaluation.report import DatasetReport @@ -29,15 +31,24 @@ class QuaccTestViewer(param.Parameterized): plot_pane = param.Parameter() modal_pane = param.Parameter() - def __init__(self, **params): + def __init__(self, param_init=None, **params): super().__init__(**params) + self.param_init = param_init self.__setup_watchers() self.update_datasets() # self._update_on_dataset() self.__create_param_pane() self.__create_modal_pane() + def __get_param_init(self, val): + __b = val in self.param_init + if __b: + setattr(self, val, self.param_init[val]) + del self.param_init[val] + + return __b + def __save_callback(self, event): _home = Path("output") _save_input_val = self.save_input.value_input @@ -233,8 +244,14 @@ class QuaccTestViewer(param.Parameterized): } self.available_datasets = list(self.datasets_.keys()) + _old_dataset = self.dataset self.param["dataset"].objects = self.available_datasets - self.dataset = self.available_datasets[0] + if not self.__get_param_init("dataset"): + self.dataset = ( + _old_dataset + if _old_dataset in self.available_datasets + else self.available_datasets[0] + ) def __setup_watchers(self): self.param.watch( @@ -244,41 +261,57 @@ class QuaccTestViewer(param.Parameterized): precedence=0, ) self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1) + self.param.watch(self._update_on_metric, ["metric"], queued=True, precedence=2) self.param.watch( self._update_plot, ["dataset", "metric", "estimators", "plot_view", "mode"], # ["metric", "estimators", "mode"], onlychanged=False, - precedence=2, + precedence=3, ) self.param.watch( self._update_on_estimators, ["estimators"], queued=True, - precedence=3, + precedence=4, ) def _update_on_dataset(self, *events): l_dr = self.datasets_[self.dataset] l_data = l_dr.data() + l_metrics = l_data.columns.unique(0) - l_estimators = l_data.columns.unique(1) - - l_valid_estimators = [e for e in l_estimators if e != "ref"] l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")] - l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs] - + _old_metric = self.metric self.param["metric"].objects = l_valid_metrics - self.metric = l_valid_metrics[0] + if not self.__get_param_init("metric"): + self.metric = ( + _old_metric if _old_metric in l_valid_metrics else l_valid_metrics[0] + ) + _old_estimators = self.estimators + l_valid_estimators = l_dr.data(metric=self.metric).columns.unique(0).to_numpy() + _new_estimators = l_valid_estimators[ + np.isin(l_valid_estimators, _old_estimators) + ].tolist() self.param["estimators"].objects = l_valid_estimators - self.estimators = l_valid_estimators + if not self.__get_param_init("estimators"): + self.estimators = _new_estimators - self.param["plot_view"].objects = ["avg"] + l_valid_views - self.plot_view = "avg" + l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs] + l_valid_views = ["avg"] + l_valid_views + _old_view = self.plot_view + self.param["plot_view"].objects = l_valid_views + if not self.__get_param_init("plot_view"): + self.plot_view = _old_view if _old_view in l_valid_views else "avg" - self.param["mode"].objects = valid_plot_modes["avg"] - self.mode = valid_plot_modes["avg"][0] + self.param["mode"].objects = valid_plot_modes[self.plot_view] + if not self.__get_param_init("mode"): + _old_mode = self.mode + if _old_mode in valid_plot_modes[self.plot_view]: + self.mode = _old_mode + else: + self.mode = valid_plot_modes[self.plot_view][0] self.param["modal_estimators"].objects = l_valid_estimators self.modal_estimators = [] @@ -287,21 +320,49 @@ class QuaccTestViewer(param.Parameterized): self.modal_plot_view = l_valid_views.copy() def _update_on_view(self, *events): + _old_mode = self.mode self.param["mode"].objects = valid_plot_modes[self.plot_view] - self.mode = valid_plot_modes[self.plot_view][0] + if _old_mode in valid_plot_modes[self.plot_view]: + self.mode = _old_mode + else: + self.mode = valid_plot_modes[self.plot_view][0] + + def _update_on_metric(self, *events): + _old_estimators = self.estimators + + l_dr = self.datasets_[self.dataset] + l_data: pd.DataFrame = l_dr.data(metric=self.metric) + l_valid_estimators: np.ndarray = l_data.columns.unique(0).to_numpy() + _new_estimators = l_valid_estimators[ + np.isin(l_valid_estimators, _old_estimators) + ].tolist() + self.param["estimators"].objects = l_valid_estimators + self.estimators = _new_estimators def _update_on_estimators(self, *events): self.modal_estimators = self.estimators.copy() def _update_plot(self, *events): - self.plot_pane = build_plot( - datasets=self.datasets_, - dst=self.dataset, - metric=self.metric, - estimators=self.estimators, - view=self.plot_view, - mode=self.mode, + __svg = pn.pane.SVG( + """ + + + + """, + sizing_mode="stretch_both", ) + if len(self.estimators) == 0: + self.plot_pane = __svg + else: + _dr = self.datasets_[self.dataset] + __plot = create_plots( + _dr, + mode=self.mode, + metric=self.metric, + estimators=self.estimators, + plot_view=self.plot_view, + ) + self.plot_pane = __svg if __plot is None else __plot def get_plot(self): return self.plot_pane