dash app updated

This commit is contained in:
Lorenzo Volpi 2023-11-30 03:11:09 +01:00
parent 5847c217ed
commit 186c3cb5f6
1 changed files with 216 additions and 42 deletions

View File

@ -1,17 +1,21 @@
import json
import os
from collections import defaultdict
from json import JSONDecodeError
from pathlib import Path
from typing import List
from urllib.parse import parse_qsl, quote, urlencode, urlparse
import dash_bootstrap_components as dbc
import flask
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import Dash, Input, Output, State, callback, dash_table, dcc, html
from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
from dash.dash_table.Format import Format, Scheme
from quacc import plot
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import ttest_rel
backend = plot.get_backend("plotly")
@ -39,8 +43,17 @@ def get_datasets(root: str | Path) -> List[DatasetReport]:
return dr_paths
dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
return {str(drp.parent): DatasetReport.unpickle(drp) for drp in dr_paths}
return list(
map(
lambda p: str(p.parent),
sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t)),
)
)
def load_dataset(dataset):
dataset = Path(dataset)
return DatasetReport.unpickle(dataset / f"{dataset.name}.pickle")
def get_fig(dr: DatasetReport, metric, estimators, view, mode):
@ -67,10 +80,102 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode):
)
datasets = get_datasets("output")
def get_table(dr: DatasetReport, metric, estimators, view, mode):
estimators = CE.name[estimators]
_prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
match (view, mode):
case ("avg", "train_table"):
return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
case ("avg", "test_table"):
return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
case ("avg", "shift_table"):
return (
dr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
case ("avg", "stats_table"):
return ttest_rel(dr, metric=metric, estimators=estimators)
case (_, "train_table"):
cr = dr.crs[_prevs.index(view)]
return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
case (_, "shift_table"):
cr = dr.crs[_prevs.index(view)]
return (
cr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
def get_DataTable(df):
_primary = "#0d6efd"
if df.empty:
return None
df = df.reset_index()
columns = {
c: dict(
id=c,
name=c,
type="numeric",
format=Format(precision=6, scheme=Scheme.exponent),
)
for c in df.columns
}
columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
columns = list(columns.values())
data = df.to_dict("records")
return html.Div(
[
dash_table.DataTable(
data=data,
columns=columns,
id="table1",
style_cell={
"padding": "0 12px",
"border": "0",
"border-bottom": f"1px solid {_primary}",
},
style_table={
"margin": "6vh 15px",
"padding": "15px",
"maxWidth": "80vw",
"overflowX": "auto",
"border": f"0px solid {_primary}",
"border-radius": "6px",
},
)
],
style={
"display": "flex",
"flex-direction": "column",
# "justify-content": "center",
"align-items": "center",
"height": "100vh",
},
)
def get_Graph(fig):
if fig is None:
return None
return dcc.Graph(
id="graph1",
figure=fig,
style={
"margin": 0,
"height": "100vh",
},
)
# datasets = get_datasets("output")
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
# app.config.suppress_callback_exceptions = True
sidebar_style = {
"top": 0,
"left": 0,
@ -82,22 +187,25 @@ sidebar_style = {
}
content_style = {
# "margin-left": "18vw",
"flex": 9,
"flex": 10,
}
sidebar = html.Div(
children=[
def parse_href(href: str):
parse_result = urlparse(href)
params = parse_qsl(parse_result.query)
return dict(params)
def get_sidebar():
return [
html.H4("Parameters:", style={"margin-bottom": "1vw"}),
dbc.Select(
options=list(datasets.keys()),
value=list(datasets.keys())[0],
# options=list(datasets.keys()),
# value=list(datasets.keys())[0],
id="dataset",
),
dbc.Select(
# clearable=False,
# searchable=False,
id="metric",
style={"margin-top": "1vh"},
),
@ -132,35 +240,57 @@ sidebar = html.Div(
],
className="radio-group-wide",
),
],
style=sidebar_style,
id="app-sidebar",
)
]
content = html.Div(
children=[
dcc.Graph(
style={"margin": 0, "height": "100vh"},
id="graph1",
),
],
style=content_style,
)
app.layout = html.Div(
children=[sidebar, content],
[
dcc.Location(id="url", refresh=False),
html.Div(
[
html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
html.Div(id="app_content", style=content_style),
],
id="page_layout",
style={"display": "flex", "flexDirection": "row"},
),
]
)
def apply_param(href, triggered_id, id, curr):
match triggered_id:
case "url":
params = parse_href(href)
return params.get(id, None)
case _:
return curr
@callback(
Output("dataset", "value"),
Output("dataset", "options"),
Input("url", "href"),
)
def update_dataset(href):
datasets = get_datasets("output")
params = parse_href(href)
old_dataset = params.get("dataset", None)
new_dataset = old_dataset if old_dataset in datasets else datasets[0]
return new_dataset, datasets
@callback(
Output("metric", "options"),
Output("metric", "value"),
Input("url", "href"),
Input("dataset", "value"),
State("metric", "value"),
)
def update_metrics(dataset, old_metric):
dr = datasets[dataset]
def update_metrics(href, dataset, curr_metric):
dr = load_dataset(dataset)
old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
return valid_metrics, new_metric
@ -169,12 +299,19 @@ def update_metrics(dataset, old_metric):
@callback(
Output("estimators", "options"),
Output("estimators", "value"),
Input("url", "href"),
Input("dataset", "value"),
Input("metric", "value"),
State("estimators", "value"),
)
def update_estimators(dataset, metric, old_estimators):
dr = datasets[dataset]
def update_estimators(href, dataset, metric, curr_estimators):
dr = load_dataset(dataset)
old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
if isinstance(old_estimators, str):
try:
old_estimators = json.loads(old_estimators)
except JSONDecodeError:
old_estimators = []
valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy()
new_estimators = valid_estimators[
np.isin(valid_estimators, old_estimators)
@ -185,11 +322,13 @@ def update_estimators(dataset, metric, old_estimators):
@callback(
Output("view", "options"),
Output("view", "value"),
Input("url", "href"),
Input("dataset", "value"),
State("view", "value"),
)
def update_view(dataset, old_view):
dr = datasets[dataset]
def update_view(href, dataset, curr_view):
dr = load_dataset(dataset)
old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
new_view = old_view if old_view in valid_views else valid_views[0]
return valid_views, new_view
@ -198,26 +337,61 @@ def update_view(dataset, old_view):
@callback(
Output("mode", "options"),
Output("mode", "value"),
Input("url", "href"),
Input("view", "value"),
State("mode", "value"),
)
def update_mode(view, old_mode):
valid_modes = [m for m in valid_plot_modes[view] if not m.endswith("table")]
def update_mode(href, view, curr_mode):
old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode)
valid_modes = valid_plot_modes[view]
new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
return valid_modes, new_mode
@callback(
Output("graph1", "figure"),
Output("app_content", "children"),
Output("url", "search"),
Input("dataset", "value"),
Input("metric", "value"),
Input("estimators", "value"),
Input("view", "value"),
Input("mode", "value"),
)
def update_graph(dataset, metric, estimators, view, mode):
dr = datasets[dataset]
return get_fig(dr=dr, metric=metric, estimators=estimators, view=view, mode=mode)
def update_content(dataset, metric, estimators, view, mode):
search = urlencode(
dict(
dataset=dataset,
metric=metric,
estimators=json.dumps(estimators),
view=view,
mode=mode,
),
quote_via=quote,
)
dr = load_dataset(dataset)
match mode:
case m if m.endswith("table"):
df = get_table(
dr=dr,
metric=metric,
estimators=estimators,
view=view,
mode=mode,
)
dt = get_DataTable(df)
app_content = [] if dt is None else [dt]
case _:
fig = get_fig(
dr=dr,
metric=metric,
estimators=estimators,
view=view,
mode=mode,
)
g = get_Graph(fig)
app_content = [] if g is None else [g]
return app_content, f"?{search}"
def run():