added stats_table

This commit is contained in:
Lorenzo Volpi 2023-11-27 03:26:41 +01:00
parent 4b7e236bc8
commit 451f5f38f1
1 changed files with 8 additions and 22 deletions

View File

@ -4,30 +4,13 @@ from pathlib import Path
import panel as pn import panel as pn
from quacc.evaluation.comp import CE from quacc.evaluation.estimators import CE
from quacc.evaluation.report import DatasetReport from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import ttest_rel
_plot_sizing_mode = "stretch_both" _plot_sizing_mode = "stretch_both"
valid_plot_modes = defaultdict( valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
lambda: [ valid_plot_modes["avg"] = DatasetReport._default_dr_modes
"delta_train",
"stdev_train",
"train_table",
"shift",
"shift_table",
"diagonal",
]
)
valid_plot_modes["avg"] = [
"delta_train",
"stdev_train",
"train_table",
"shift",
"shift_table",
"delta_test",
"stdev_test",
"test_table",
]
def create_plots( def create_plots(
@ -60,6 +43,9 @@ def create_plots(
.mean() .mean()
) )
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", "stats_table"):
_data = ttest_rel(dr, metric=metric, estimators=estimators)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", _ as plot_mode): case ("avg", _ as plot_mode):
_plot = dr.get_plots( _plot = dr.get_plots(
mode=mode, mode=mode,