dash app updated
This commit is contained in:
parent
5847c217ed
commit
186c3cb5f6
256
qcdash/app.py
256
qcdash/app.py
|
@ -1,17 +1,21 @@
|
|||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from urllib.parse import parse_qsl, quote, urlencode, urlparse
|
||||
|
||||
import dash_bootstrap_components as dbc
|
||||
import flask
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
from dash import Dash, Input, Output, State, callback, dash_table, dcc, html
|
||||
from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
|
||||
from dash.dash_table.Format import Format, Scheme
|
||||
|
||||
from quacc import plot
|
||||
from quacc.evaluation.estimators import CE
|
||||
from quacc.evaluation.report import CompReport, DatasetReport
|
||||
from quacc.evaluation.stats import ttest_rel
|
||||
|
||||
backend = plot.get_backend("plotly")
|
||||
|
||||
|
@ -39,8 +43,17 @@ def get_datasets(root: str | Path) -> List[DatasetReport]:
|
|||
|
||||
return dr_paths
|
||||
|
||||
dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
|
||||
return {str(drp.parent): DatasetReport.unpickle(drp) for drp in dr_paths}
|
||||
return list(
|
||||
map(
|
||||
lambda p: str(p.parent),
|
||||
sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load_dataset(dataset):
|
||||
dataset = Path(dataset)
|
||||
return DatasetReport.unpickle(dataset / f"{dataset.name}.pickle")
|
||||
|
||||
|
||||
def get_fig(dr: DatasetReport, metric, estimators, view, mode):
|
||||
|
@ -67,10 +80,102 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode):
|
|||
)
|
||||
|
||||
|
||||
datasets = get_datasets("output")
|
||||
def get_table(dr: DatasetReport, metric, estimators, view, mode):
|
||||
estimators = CE.name[estimators]
|
||||
_prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
|
||||
match (view, mode):
|
||||
case ("avg", "train_table"):
|
||||
return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
|
||||
case ("avg", "test_table"):
|
||||
return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||
case ("avg", "shift_table"):
|
||||
return (
|
||||
dr.shift_data(metric=metric, estimators=estimators)
|
||||
.groupby(level=0)
|
||||
.mean()
|
||||
)
|
||||
case ("avg", "stats_table"):
|
||||
return ttest_rel(dr, metric=metric, estimators=estimators)
|
||||
case (_, "train_table"):
|
||||
cr = dr.crs[_prevs.index(view)]
|
||||
return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||
case (_, "shift_table"):
|
||||
cr = dr.crs[_prevs.index(view)]
|
||||
return (
|
||||
cr.shift_data(metric=metric, estimators=estimators)
|
||||
.groupby(level=0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
|
||||
def get_DataTable(df):
|
||||
_primary = "#0d6efd"
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
df = df.reset_index()
|
||||
columns = {
|
||||
c: dict(
|
||||
id=c,
|
||||
name=c,
|
||||
type="numeric",
|
||||
format=Format(precision=6, scheme=Scheme.exponent),
|
||||
)
|
||||
for c in df.columns
|
||||
}
|
||||
columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
|
||||
columns = list(columns.values())
|
||||
data = df.to_dict("records")
|
||||
|
||||
return html.Div(
|
||||
[
|
||||
dash_table.DataTable(
|
||||
data=data,
|
||||
columns=columns,
|
||||
id="table1",
|
||||
style_cell={
|
||||
"padding": "0 12px",
|
||||
"border": "0",
|
||||
"border-bottom": f"1px solid {_primary}",
|
||||
},
|
||||
style_table={
|
||||
"margin": "6vh 15px",
|
||||
"padding": "15px",
|
||||
"maxWidth": "80vw",
|
||||
"overflowX": "auto",
|
||||
"border": f"0px solid {_primary}",
|
||||
"border-radius": "6px",
|
||||
},
|
||||
)
|
||||
],
|
||||
style={
|
||||
"display": "flex",
|
||||
"flex-direction": "column",
|
||||
# "justify-content": "center",
|
||||
"align-items": "center",
|
||||
"height": "100vh",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_Graph(fig):
|
||||
if fig is None:
|
||||
return None
|
||||
|
||||
return dcc.Graph(
|
||||
id="graph1",
|
||||
figure=fig,
|
||||
style={
|
||||
"margin": 0,
|
||||
"height": "100vh",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# datasets = get_datasets("output")
|
||||
|
||||
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
|
||||
|
||||
# app.config.suppress_callback_exceptions = True
|
||||
sidebar_style = {
|
||||
"top": 0,
|
||||
"left": 0,
|
||||
|
@ -82,22 +187,25 @@ sidebar_style = {
|
|||
}
|
||||
|
||||
content_style = {
|
||||
# "margin-left": "18vw",
|
||||
"flex": 9,
|
||||
"flex": 10,
|
||||
}
|
||||
|
||||
|
||||
sidebar = html.Div(
|
||||
children=[
|
||||
def parse_href(href: str):
|
||||
parse_result = urlparse(href)
|
||||
params = parse_qsl(parse_result.query)
|
||||
return dict(params)
|
||||
|
||||
|
||||
def get_sidebar():
|
||||
return [
|
||||
html.H4("Parameters:", style={"margin-bottom": "1vw"}),
|
||||
dbc.Select(
|
||||
options=list(datasets.keys()),
|
||||
value=list(datasets.keys())[0],
|
||||
# options=list(datasets.keys()),
|
||||
# value=list(datasets.keys())[0],
|
||||
id="dataset",
|
||||
),
|
||||
dbc.Select(
|
||||
# clearable=False,
|
||||
# searchable=False,
|
||||
id="metric",
|
||||
style={"margin-top": "1vh"},
|
||||
),
|
||||
|
@ -132,35 +240,57 @@ sidebar = html.Div(
|
|||
],
|
||||
className="radio-group-wide",
|
||||
),
|
||||
],
|
||||
style=sidebar_style,
|
||||
id="app-sidebar",
|
||||
)
|
||||
]
|
||||
|
||||
content = html.Div(
|
||||
children=[
|
||||
dcc.Graph(
|
||||
style={"margin": 0, "height": "100vh"},
|
||||
id="graph1",
|
||||
),
|
||||
],
|
||||
style=content_style,
|
||||
)
|
||||
|
||||
app.layout = html.Div(
|
||||
children=[sidebar, content],
|
||||
[
|
||||
dcc.Location(id="url", refresh=False),
|
||||
html.Div(
|
||||
[
|
||||
html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
|
||||
html.Div(id="app_content", style=content_style),
|
||||
],
|
||||
id="page_layout",
|
||||
style={"display": "flex", "flexDirection": "row"},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def apply_param(href, triggered_id, id, curr):
|
||||
match triggered_id:
|
||||
case "url":
|
||||
params = parse_href(href)
|
||||
return params.get(id, None)
|
||||
case _:
|
||||
return curr
|
||||
|
||||
|
||||
@callback(
|
||||
Output("dataset", "value"),
|
||||
Output("dataset", "options"),
|
||||
Input("url", "href"),
|
||||
)
|
||||
def update_dataset(href):
|
||||
datasets = get_datasets("output")
|
||||
|
||||
params = parse_href(href)
|
||||
old_dataset = params.get("dataset", None)
|
||||
new_dataset = old_dataset if old_dataset in datasets else datasets[0]
|
||||
return new_dataset, datasets
|
||||
|
||||
|
||||
@callback(
|
||||
Output("metric", "options"),
|
||||
Output("metric", "value"),
|
||||
Input("url", "href"),
|
||||
Input("dataset", "value"),
|
||||
State("metric", "value"),
|
||||
)
|
||||
def update_metrics(dataset, old_metric):
|
||||
dr = datasets[dataset]
|
||||
def update_metrics(href, dataset, curr_metric):
|
||||
dr = load_dataset(dataset)
|
||||
old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
|
||||
valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
|
||||
new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
|
||||
return valid_metrics, new_metric
|
||||
|
@ -169,12 +299,19 @@ def update_metrics(dataset, old_metric):
|
|||
@callback(
|
||||
Output("estimators", "options"),
|
||||
Output("estimators", "value"),
|
||||
Input("url", "href"),
|
||||
Input("dataset", "value"),
|
||||
Input("metric", "value"),
|
||||
State("estimators", "value"),
|
||||
)
|
||||
def update_estimators(dataset, metric, old_estimators):
|
||||
dr = datasets[dataset]
|
||||
def update_estimators(href, dataset, metric, curr_estimators):
|
||||
dr = load_dataset(dataset)
|
||||
old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
|
||||
if isinstance(old_estimators, str):
|
||||
try:
|
||||
old_estimators = json.loads(old_estimators)
|
||||
except JSONDecodeError:
|
||||
old_estimators = []
|
||||
valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy()
|
||||
new_estimators = valid_estimators[
|
||||
np.isin(valid_estimators, old_estimators)
|
||||
|
@ -185,11 +322,13 @@ def update_estimators(dataset, metric, old_estimators):
|
|||
@callback(
|
||||
Output("view", "options"),
|
||||
Output("view", "value"),
|
||||
Input("url", "href"),
|
||||
Input("dataset", "value"),
|
||||
State("view", "value"),
|
||||
)
|
||||
def update_view(dataset, old_view):
|
||||
dr = datasets[dataset]
|
||||
def update_view(href, dataset, curr_view):
|
||||
dr = load_dataset(dataset)
|
||||
old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
|
||||
valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
|
||||
new_view = old_view if old_view in valid_views else valid_views[0]
|
||||
return valid_views, new_view
|
||||
|
@ -198,26 +337,61 @@ def update_view(dataset, old_view):
|
|||
@callback(
|
||||
Output("mode", "options"),
|
||||
Output("mode", "value"),
|
||||
Input("url", "href"),
|
||||
Input("view", "value"),
|
||||
State("mode", "value"),
|
||||
)
|
||||
def update_mode(view, old_mode):
|
||||
valid_modes = [m for m in valid_plot_modes[view] if not m.endswith("table")]
|
||||
def update_mode(href, view, curr_mode):
|
||||
old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode)
|
||||
valid_modes = valid_plot_modes[view]
|
||||
new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
|
||||
return valid_modes, new_mode
|
||||
|
||||
|
||||
@callback(
|
||||
Output("graph1", "figure"),
|
||||
Output("app_content", "children"),
|
||||
Output("url", "search"),
|
||||
Input("dataset", "value"),
|
||||
Input("metric", "value"),
|
||||
Input("estimators", "value"),
|
||||
Input("view", "value"),
|
||||
Input("mode", "value"),
|
||||
)
|
||||
def update_graph(dataset, metric, estimators, view, mode):
|
||||
dr = datasets[dataset]
|
||||
return get_fig(dr=dr, metric=metric, estimators=estimators, view=view, mode=mode)
|
||||
def update_content(dataset, metric, estimators, view, mode):
|
||||
search = urlencode(
|
||||
dict(
|
||||
dataset=dataset,
|
||||
metric=metric,
|
||||
estimators=json.dumps(estimators),
|
||||
view=view,
|
||||
mode=mode,
|
||||
),
|
||||
quote_via=quote,
|
||||
)
|
||||
dr = load_dataset(dataset)
|
||||
match mode:
|
||||
case m if m.endswith("table"):
|
||||
df = get_table(
|
||||
dr=dr,
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
view=view,
|
||||
mode=mode,
|
||||
)
|
||||
dt = get_DataTable(df)
|
||||
app_content = [] if dt is None else [dt]
|
||||
case _:
|
||||
fig = get_fig(
|
||||
dr=dr,
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
view=view,
|
||||
mode=mode,
|
||||
)
|
||||
g = get_Graph(fig)
|
||||
app_content = [] if g is None else [g]
|
||||
|
||||
return app_content, f"?{search}"
|
||||
|
||||
|
||||
def run():
|
||||
|
|
Loading…
Reference in New Issue