diff --git a/qcdash/app.py b/qcdash/app.py index 32b24b3..e3d46d1 100644 --- a/qcdash/app.py +++ b/qcdash/app.py @@ -16,8 +16,6 @@ from quacc.evaluation.estimators import CE from quacc.evaluation.report import CompReport, DatasetReport from quacc.evaluation.stats import wilcoxon -backend = plot.get_backend("plotly") - valid_plot_modes = defaultdict(lambda: CompReport._default_modes) valid_plot_modes["avg"] = DatasetReport._default_dr_modes @@ -50,7 +48,9 @@ def get_datasets(root: str | Path) -> List[DatasetReport]: return {str(drp.parent): load_dataset(drp) for drp in dr_paths} -def get_fig(dr: DatasetReport, metric, estimators, view, mode): +def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None): + _backend = backend or plot.get_backend("plotly") + print(_backend) estimators = CE.name[estimators] match (view, mode): case ("avg", _): @@ -60,7 +60,7 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode): estimators=estimators, conf="plotly", save_fig=False, - backend=backend, + backend=_backend, ) case (_, _): cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)] @@ -70,7 +70,7 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode): estimators=estimators, conf="plotly", save_fig=False, - backend=backend, + backend=_backend, ) diff --git a/quacc/plot/plot.py b/quacc/plot/plot.py index 1bd2369..1f70aa6 100644 --- a/quacc/plot/plot.py +++ b/quacc/plot/plot.py @@ -5,8 +5,8 @@ from quacc.plot.plotly import PlotlyPlot __backend: BasePlot = MplPlot() -def get_backend(be, theme=None): - match be: +def get_backend(name, theme=None): + match name: case "matplotlib" | "mpl": return MplPlot() case "plotly":