dash updated

Lorenzo Volpi 2023-12-21 16:47:07 +01:00
parent e3b42e0648
commit a5c54a93b7
1 changed file with 85 additions and 37 deletions


@@ -2,6 +2,7 @@ import json
 import os
 from collections import defaultdict
 from json import JSONDecodeError
+from operator import index
 from pathlib import Path
 from typing import List
 from urllib.parse import parse_qsl, quote, urlencode, urlparse
@@ -9,7 +10,7 @@ from urllib.parse import parse_qsl, quote, urlencode, urlparse
 import dash_bootstrap_components as dbc
 import numpy as np
 from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
-from dash.dash_table.Format import Format, Scheme
+from dash.dash_table.Format import Align, Format, Scheme
 
 from quacc import plot
 from quacc.evaluation.estimators import CE
@@ -21,6 +22,10 @@ valid_plot_modes["avg"] = DatasetReport._default_dr_modes
 root_folder = "output"
 
 
+def _get_prev_str(prev: np.ndarray):
+    return str(tuple(np.around(prev, decimals=2)))
+
+
 def get_datasets(root: str | Path) -> List[DatasetReport]:
     def load_dataset(dataset):
         dataset = Path(dataset)
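
Note on the hunk above: the new _get_prev_str helper replaces the old integer-percentage view keys (str(round(train_prev[1] * 100))) used in the hunks below, so a view is now identified by the full rounded prevalence tuple. A minimal sketch of what the helper returns, assuming NumPy 1.x (current when this commit was made):

    import numpy as np

    def _get_prev_str(prev: np.ndarray):
        return str(tuple(np.around(prev, decimals=2)))

    # A binary training prevalence becomes its rounded tuple string:
    print(_get_prev_str(np.array([0.251, 0.749])))  # -> "(0.25, 0.75)"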
@@ -63,7 +68,7 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
                 backend=_backend,
             )
         case (_, _):
-            cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
+            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
             return cr.get_plots(
                 mode=mode,
                 metric=metric,
@@ -76,53 +81,105 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
 
 def get_table(dr: DatasetReport, metric, estimators, view, mode):
     estimators = CE.name[estimators]
-    _prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
     match (view, mode):
         case ("avg", "train_table"):
-            return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
+            # return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
+            return dr.train_table(metric=metric, estimators=estimators)
         case ("avg", "test_table"):
-            return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+            # return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+            return dr.test_table(metric=metric, estimators=estimators)
         case ("avg", "shift_table"):
-            return (
-                dr.shift_data(metric=metric, estimators=estimators)
-                .groupby(level=0)
-                .mean()
-            )
+            # return (
+            #     dr.shift_data(metric=metric, estimators=estimators)
+            #     .groupby(level=0)
+            #     .mean()
+            # )
+            return dr.shift_table(metric=metric, estimators=estimators)
         case ("avg", "stats_table"):
             return wilcoxon(dr, metric=metric, estimators=estimators)
         case (_, "train_table"):
-            cr = dr.crs[_prevs.index(view)]
-            return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
+            # return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+            return cr.train_table(metric=metric, estimators=estimators)
        case (_, "shift_table"):
-            cr = dr.crs[_prevs.index(view)]
-            return (
-                cr.shift_data(metric=metric, estimators=estimators)
-                .groupby(level=0)
-                .mean()
-            )
+            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
+            # return (
+            #     cr.shift_data(metric=metric, estimators=estimators)
+            #     .groupby(level=0)
+            #     .mean()
+            # )
+            return cr.shift_table(metric=metric, estimators=estimators)
         case (_, "stats_table"):
-            cr = dr.crs[_prevs.index(view)]
+            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
             return wilcoxon(cr, metric=metric, estimators=estimators)
 
 
-def get_DataTable(df):
+def get_DataTable(df, mode):
     _primary = "#0d6efd"
     if df.empty:
         return None
 
+    _index_name = dict(
+        train_table="test prev.",
+        test_table="train prev.",
+        shift_table="shift",
+        stats_table="method",
+    )
+
     df = df.reset_index()
     columns = {
         c: dict(
             id=c,
-            name=c,
+            name=_index_name[mode] if c == "index" else c,
             type="numeric",
             format=Format(precision=6, scheme=Scheme.exponent, nully="nan"),
         )
         for c in df.columns
     }
-    columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
+    # columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
+    columns["index"]["format"] = Format()
     columns = list(columns.values())
     data = df.to_dict("records")
+    for d in data:
+        if isinstance(d["index"], tuple | list | np.ndarray):
+            d["index"] = "(" + ", ".join([f"{v:.2f}" for v in d["index"]]) + ")"
+        elif isinstance(d["index"], float):
+            d["index"] = f"{d['index']:.2f}"
+
+    _style_cell = {
+        "padding": "0 12px",
+        "border": "0",
+        "border-bottom": f"1px solid {_primary}",
+    }
+    _style_cell_conditional = [
+        {
+            "if": {"column_id": "index"},
+            "text_align": "center",
+        },
+    ]
+
+    _style_data_conditional = []
+    if mode != "stats_table":
+        _style_data_conditional += [
+            {
+                "if": {"column_id": "index", "row_index": len(data) - 1},
+                "font_weight": "bold",
+            },
+            {
+                "if": {"row_index": len(data) - 1},
+                "background_color": "#0d6efd",
+                "color": "white",
+            },
+        ]
+
+    _style_table = {
+        "margin": "6vh 15px",
+        "padding": "15px",
+        "maxWidth": "80vw",
+        "overflowX": "auto",
+        "border": f"0px solid {_primary}",
+        "border-radius": "6px",
+    }
 
     return html.Div(
         [
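
Note on the get_DataTable changes in this hunk and the next: the function now receives the table mode and uses it to relabel the reset index column (test prev., train prev., shift or method), to render tuple and float index values as two-decimal strings, and to bold and color the last row for every mode except stats_table. A self-contained sketch of the index-formatting step, using made-up rows rather than data from the repository:

    import numpy as np

    # Same type checks as in the diff: tuples/arrays become "(x.xx, y.yy)",
    # bare floats become "x.xx"; string indexes are left untouched.
    rows = [{"index": (0.25, 0.75)}, {"index": np.array([0.1, 0.9])}, {"index": 0.5}]
    for d in rows:
        if isinstance(d["index"], tuple | list | np.ndarray):
            d["index"] = "(" + ", ".join([f"{v:.2f}" for v in d["index"]]) + ")"
        elif isinstance(d["index"], float):
            d["index"] = f"{d['index']:.2f}"

    print([d["index"] for d in rows])  # ['(0.25, 0.75)', '(0.10, 0.90)', '0.50']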
@@ -130,19 +187,10 @@ def get_DataTable(df):
                data=data,
                columns=columns,
                id="table1",
-                style_cell={
-                    "padding": "0 12px",
-                    "border": "0",
-                    "border-bottom": f"1px solid {_primary}",
-                },
-                style_table={
-                    "margin": "6vh 15px",
-                    "padding": "15px",
-                    "maxWidth": "80vw",
-                    "overflowX": "auto",
-                    "border": f"0px solid {_primary}",
-                    "border-radius": "6px",
-                },
+                style_cell=_style_cell,
+                style_cell_conditional=_style_cell_conditional,
+                style_data_conditional=_style_data_conditional,
+                style_table=_style_table,
            )
        ],
        style={
@@ -361,7 +409,7 @@ def update_estimators(href, dataset, metric, curr_estimators, root):
 def update_view(href, dataset, curr_view, root):
     dr = get_dr(root, dataset)
     old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
-    valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
+    valid_views = ["avg"] + [_get_prev_str(cr.train_prev) for cr in dr.crs]
     new_view = old_view if old_view in valid_views else valid_views[0]
     return valid_views, new_view
@@ -412,7 +460,7 @@ def update_content(dataset, metric, estimators, view, mode, root):
                view=view,
                mode=mode,
            )
-            dt = get_DataTable(df)
+            dt = get_DataTable(df, mode)
            app_content = [] if dt is None else [dt]
        case _:
            fig = get_fig(