plot backend refacored

This commit is contained in:
Lorenzo Volpi 2023-12-01 10:48:50 +01:00
parent 932b235167
commit e69e8381e3
2 changed files with 7 additions and 7 deletions

View File

@ -16,8 +16,6 @@ from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import wilcoxon from quacc.evaluation.stats import wilcoxon
backend = plot.get_backend("plotly")
valid_plot_modes = defaultdict(lambda: CompReport._default_modes) valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes valid_plot_modes["avg"] = DatasetReport._default_dr_modes
@ -50,7 +48,9 @@ def get_datasets(root: str | Path) -> List[DatasetReport]:
return {str(drp.parent): load_dataset(drp) for drp in dr_paths} return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
def get_fig(dr: DatasetReport, metric, estimators, view, mode): def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
_backend = backend or plot.get_backend("plotly")
print(_backend)
estimators = CE.name[estimators] estimators = CE.name[estimators]
match (view, mode): match (view, mode):
case ("avg", _): case ("avg", _):
@ -60,7 +60,7 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode):
estimators=estimators, estimators=estimators,
conf="plotly", conf="plotly",
save_fig=False, save_fig=False,
backend=backend, backend=_backend,
) )
case (_, _): case (_, _):
cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)] cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
@ -70,7 +70,7 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode):
estimators=estimators, estimators=estimators,
conf="plotly", conf="plotly",
save_fig=False, save_fig=False,
backend=backend, backend=_backend,
) )

View File

@ -5,8 +5,8 @@ from quacc.plot.plotly import PlotlyPlot
__backend: BasePlot = MplPlot() __backend: BasePlot = MplPlot()
def get_backend(be, theme=None): def get_backend(name, theme=None):
match be: match name:
case "matplotlib" | "mpl": case "matplotlib" | "mpl":
return MplPlot() return MplPlot()
case "plotly": case "plotly":