dash app updated

Lorenzo Volpi 2023-11-30 03:11:09 +01:00
parent 5847c217ed
commit 186c3cb5f6
1 changed file with 216 additions and 42 deletions


@@ -1,17 +1,21 @@
+import json
 import os
 from collections import defaultdict
+from json import JSONDecodeError
 from pathlib import Path
 from typing import List
+from urllib.parse import parse_qsl, quote, urlencode, urlparse
 
 import dash_bootstrap_components as dbc
+import flask
 import numpy as np
-import pandas as pd
-import plotly.graph_objects as go
-from dash import Dash, Input, Output, State, callback, dash_table, dcc, html
+from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
+from dash.dash_table.Format import Format, Scheme
 
 from quacc import plot
 from quacc.evaluation.estimators import CE
 from quacc.evaluation.report import CompReport, DatasetReport
+from quacc.evaluation.stats import ttest_rel
 
 backend = plot.get_backend("plotly")
@@ -39,8 +43,17 @@ def get_datasets(root: str | Path) -> List[DatasetReport]:
         return dr_paths
 
-    dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
-    return {str(drp.parent): DatasetReport.unpickle(drp) for drp in dr_paths}
+    return list(
+        map(
+            lambda p: str(p.parent),
+            sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t)),
+        )
+    )
+
+
+def load_dataset(dataset):
+    dataset = Path(dataset)
+    return DatasetReport.unpickle(dataset / f"{dataset.name}.pickle")
 
 
 def get_fig(dr: DatasetReport, metric, estimators, view, mode):
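The dataset list is now built lazily: get_datasets() only returns directory names, and a report is unpickled on demand by load_dataset(). A minimal stand-alone sketch of the same pattern, assuming (as load_dataset above does) that each dataset directory contains a pickle named after the directory; the helper name here is hypothetical:

from pathlib import Path

def list_dataset_dirs(root: str) -> list[str]:
    # analogue of the new get_datasets(): collect directory names only,
    # without unpickling anything up front
    return sorted(str(p.parent) for p in Path(root).glob("**/*.pickle"))

# a single report is then loaded only when a callback needs it, e.g.:
#   d = Path(dirs[0]); dr = DatasetReport.unpickle(d / f"{d.name}.pickle")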
@@ -67,10 +80,102 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode):
     )
 
 
-datasets = get_datasets("output")
+def get_table(dr: DatasetReport, metric, estimators, view, mode):
+    estimators = CE.name[estimators]
+    _prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
+    match (view, mode):
+        case ("avg", "train_table"):
+            return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
+        case ("avg", "test_table"):
+            return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+        case ("avg", "shift_table"):
+            return (
+                dr.shift_data(metric=metric, estimators=estimators)
+                .groupby(level=0)
+                .mean()
+            )
+        case ("avg", "stats_table"):
+            return ttest_rel(dr, metric=metric, estimators=estimators)
+        case (_, "train_table"):
+            cr = dr.crs[_prevs.index(view)]
+            return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+        case (_, "shift_table"):
+            cr = dr.crs[_prevs.index(view)]
+            return (
+                cr.shift_data(metric=metric, estimators=estimators)
+                .groupby(level=0)
+                .mean()
+            )
+
+
+def get_DataTable(df):
+    _primary = "#0d6efd"
+    if df.empty:
+        return None
+
+    df = df.reset_index()
+    columns = {
+        c: dict(
+            id=c,
+            name=c,
+            type="numeric",
+            format=Format(precision=6, scheme=Scheme.exponent),
+        )
+        for c in df.columns
+    }
+    columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
+    columns = list(columns.values())
+    data = df.to_dict("records")
+
+    return html.Div(
+        [
+            dash_table.DataTable(
+                data=data,
+                columns=columns,
+                id="table1",
+                style_cell={
+                    "padding": "0 12px",
+                    "border": "0",
+                    "border-bottom": f"1px solid {_primary}",
+                },
+                style_table={
+                    "margin": "6vh 15px",
+                    "padding": "15px",
+                    "maxWidth": "80vw",
+                    "overflowX": "auto",
+                    "border": f"0px solid {_primary}",
+                    "border-radius": "6px",
+                },
+            )
+        ],
+        style={
+            "display": "flex",
+            "flex-direction": "column",
+            # "justify-content": "center",
+            "align-items": "center",
+            "height": "100vh",
+        },
+    )
+
+
+def get_Graph(fig):
+    if fig is None:
+        return None
+
+    return dcc.Graph(
+        id="graph1",
+        figure=fig,
+        style={
+            "margin": 0,
+            "height": "100vh",
+        },
+    )
+
+
+# datasets = get_datasets("output")
 app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
+# app.config.suppress_callback_exceptions = True
 
 sidebar_style = {
     "top": 0,
     "left": 0,
@@ -82,22 +187,25 @@ sidebar_style = {
 }
 
 content_style = {
-    # "margin-left": "18vw",
-    "flex": 9,
+    "flex": 10,
 }
 
 
-sidebar = html.Div(
-    children=[
+def parse_href(href: str):
+    parse_result = urlparse(href)
+    params = parse_qsl(parse_result.query)
+    return dict(params)
+
+
+def get_sidebar():
+    return [
         html.H4("Parameters:", style={"margin-bottom": "1vw"}),
         dbc.Select(
-            options=list(datasets.keys()),
-            value=list(datasets.keys())[0],
+            # options=list(datasets.keys()),
+            # value=list(datasets.keys())[0],
             id="dataset",
         ),
         dbc.Select(
-            # clearable=False,
-            # searchable=False,
             id="metric",
             style={"margin-top": "1vh"},
         ),
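parse_href() is what lets the app restore its state from a shared link: it takes the full href reported by the dcc.Location component and returns the query parameters as a dict. A quick check of the same parsing logic with a hypothetical URL:

from urllib.parse import parse_qsl, urlparse

href = "http://127.0.0.1:8050/?dataset=output%2Fimdb&metric=acc"
params = dict(parse_qsl(urlparse(href).query))
assert params == {"dataset": "output/imdb", "metric": "acc"}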
@@ -132,35 +240,57 @@ sidebar = html.Div(
         ],
         className="radio-group-wide",
     ),
-    ],
-    style=sidebar_style,
-    id="app-sidebar",
-)
-
-content = html.Div(
-    children=[
-        dcc.Graph(
-            style={"margin": 0, "height": "100vh"},
-            id="graph1",
-        ),
-    ],
-    style=content_style,
-)
+    ]
+
 
 app.layout = html.Div(
-    children=[sidebar, content],
-    style={"display": "flex", "flexDirection": "row"},
+    [
+        dcc.Location(id="url", refresh=False),
+        html.Div(
+            [
+                html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
+                html.Div(id="app_content", style=content_style),
+            ],
+            id="page_layout",
+            style={"display": "flex", "flexDirection": "row"},
+        ),
+    ]
 )
+
+
+def apply_param(href, triggered_id, id, curr):
+    match triggered_id:
+        case "url":
+            params = parse_href(href)
+            return params.get(id, None)
+        case _:
+            return curr
+
+
+@callback(
+    Output("dataset", "value"),
+    Output("dataset", "options"),
+    Input("url", "href"),
+)
+def update_dataset(href):
+    datasets = get_datasets("output")
+    params = parse_href(href)
+    old_dataset = params.get("dataset", None)
+    new_dataset = old_dataset if old_dataset in datasets else datasets[0]
+    return new_dataset, datasets
 
 
 @callback(
     Output("metric", "options"),
     Output("metric", "value"),
+    Input("url", "href"),
     Input("dataset", "value"),
     State("metric", "value"),
 )
-def update_metrics(dataset, old_metric):
-    dr = datasets[dataset]
+def update_metrics(href, dataset, curr_metric):
+    dr = load_dataset(dataset)
+    old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
     valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
     new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
     return valid_metrics, new_metric
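apply_param() decides which side wins when one of these callbacks fires: if the dcc.Location component ("url") triggered it, the value stored in the query string is restored; for any other trigger the component's current value is kept. With the definition above in scope and a hypothetical href:

href = "http://127.0.0.1:8050/?metric=acc"

apply_param(href, "url", "metric", "f1")      # -> "acc": restored from the URL
apply_param(href, "dataset", "metric", "f1")  # -> "f1": trigger was not the URL, keep current value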
@@ -169,12 +299,19 @@ def update_metrics(dataset, old_metric):
 @callback(
     Output("estimators", "options"),
     Output("estimators", "value"),
+    Input("url", "href"),
     Input("dataset", "value"),
     Input("metric", "value"),
     State("estimators", "value"),
 )
-def update_estimators(dataset, metric, old_estimators):
-    dr = datasets[dataset]
+def update_estimators(href, dataset, metric, curr_estimators):
+    dr = load_dataset(dataset)
+    old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
+    if isinstance(old_estimators, str):
+        try:
+            old_estimators = json.loads(old_estimators)
+        except JSONDecodeError:
+            old_estimators = []
     valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy()
     new_estimators = valid_estimators[
         np.isin(valid_estimators, old_estimators)
@@ -185,11 +322,13 @@ def update_estimators(dataset, metric, old_estimators):
 @callback(
     Output("view", "options"),
     Output("view", "value"),
+    Input("url", "href"),
     Input("dataset", "value"),
     State("view", "value"),
 )
-def update_view(dataset, old_view):
-    dr = datasets[dataset]
+def update_view(href, dataset, curr_view):
+    dr = load_dataset(dataset)
+    old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
     valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
     new_view = old_view if old_view in valid_views else valid_views[0]
     return valid_views, new_view
@@ -198,26 +337,61 @@ def update_view(dataset, old_view):
 @callback(
     Output("mode", "options"),
     Output("mode", "value"),
+    Input("url", "href"),
     Input("view", "value"),
     State("mode", "value"),
 )
-def update_mode(view, old_mode):
-    valid_modes = [m for m in valid_plot_modes[view] if not m.endswith("table")]
+def update_mode(href, view, curr_mode):
+    old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode)
+    valid_modes = valid_plot_modes[view]
     new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
     return valid_modes, new_mode
 
 
 @callback(
-    Output("graph1", "figure"),
+    Output("app_content", "children"),
+    Output("url", "search"),
     Input("dataset", "value"),
     Input("metric", "value"),
    Input("estimators", "value"),
     Input("view", "value"),
     Input("mode", "value"),
 )
-def update_graph(dataset, metric, estimators, view, mode):
-    dr = datasets[dataset]
-    return get_fig(dr=dr, metric=metric, estimators=estimators, view=view, mode=mode)
+def update_content(dataset, metric, estimators, view, mode):
+    search = urlencode(
+        dict(
+            dataset=dataset,
+            metric=metric,
+            estimators=json.dumps(estimators),
+            view=view,
+            mode=mode,
+        ),
+        quote_via=quote,
+    )
+    dr = load_dataset(dataset)
+    match mode:
+        case m if m.endswith("table"):
+            df = get_table(
+                dr=dr,
+                metric=metric,
+                estimators=estimators,
+                view=view,
+                mode=mode,
+            )
+            dt = get_DataTable(df)
+            app_content = [] if dt is None else [dt]
+        case _:
+            fig = get_fig(
+                dr=dr,
+                metric=metric,
+                estimators=estimators,
+                view=view,
+                mode=mode,
+            )
+            g = get_Graph(fig)
+            app_content = [] if g is None else [g]
+
+    return app_content, f"?{search}"
 
 
 def run():
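Taken together, update_content() now serializes the whole control state (dataset, metric, estimators, view, mode) into the URL's query string, with the estimator list JSON-encoded, and the update_* callbacks read it back on page load via parse_href() and json.loads(), so a plot or table configuration can be shared as a link. A self-contained sketch of that round trip, with made-up estimator names and mode:

import json
from urllib.parse import parse_qsl, quote, urlencode

state = {
    "dataset": "output/imdb",
    "metric": "acc",
    "estimators": json.dumps(["SLD", "CC"]),
    "view": "avg",
    "mode": "delta_train",
}
search = urlencode(state, quote_via=quote)
# -> "dataset=output/imdb&metric=acc&estimators=%5B%22SLD%22%2C%20%22CC%22%5D&view=avg&mode=delta_train"

decoded = dict(parse_qsl(search))
assert decoded["view"] == "avg"
assert json.loads(decoded["estimators"]) == ["SLD", "CC"]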