diff --git a/qcdash/app.py b/qcdash/app.py index 3a67a52..bce90f4 100644 --- a/qcdash/app.py +++ b/qcdash/app.py @@ -1,17 +1,21 @@ +import json import os from collections import defaultdict +from json import JSONDecodeError from pathlib import Path from typing import List +from urllib.parse import parse_qsl, quote, urlencode, urlparse import dash_bootstrap_components as dbc +import flask import numpy as np -import pandas as pd -import plotly.graph_objects as go -from dash import Dash, Input, Output, State, callback, dash_table, dcc, html +from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html +from dash.dash_table.Format import Format, Scheme from quacc import plot from quacc.evaluation.estimators import CE from quacc.evaluation.report import CompReport, DatasetReport +from quacc.evaluation.stats import ttest_rel backend = plot.get_backend("plotly") @@ -39,8 +43,17 @@ def get_datasets(root: str | Path) -> List[DatasetReport]: return dr_paths - dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t)) - return {str(drp.parent): DatasetReport.unpickle(drp) for drp in dr_paths} + return list( + map( + lambda p: str(p.parent), + sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t)), + ) + ) + + +def load_dataset(dataset): + dataset = Path(dataset) + return DatasetReport.unpickle(dataset / f"{dataset.name}.pickle") def get_fig(dr: DatasetReport, metric, estimators, view, mode): @@ -67,10 +80,102 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode): ) -datasets = get_datasets("output") +def get_table(dr: DatasetReport, metric, estimators, view, mode): + estimators = CE.name[estimators] + _prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs] + match (view, mode): + case ("avg", "train_table"): + return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean() + case ("avg", "test_table"): + return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean() + case ("avg", "shift_table"): + return ( + dr.shift_data(metric=metric, estimators=estimators) + .groupby(level=0) + .mean() + ) + case ("avg", "stats_table"): + return ttest_rel(dr, metric=metric, estimators=estimators) + case (_, "train_table"): + cr = dr.crs[_prevs.index(view)] + return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean() + case (_, "shift_table"): + cr = dr.crs[_prevs.index(view)] + return ( + cr.shift_data(metric=metric, estimators=estimators) + .groupby(level=0) + .mean() + ) + + +def get_DataTable(df): + _primary = "#0d6efd" + if df.empty: + return None + + df = df.reset_index() + columns = { + c: dict( + id=c, + name=c, + type="numeric", + format=Format(precision=6, scheme=Scheme.exponent), + ) + for c in df.columns + } + columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed) + columns = list(columns.values()) + data = df.to_dict("records") + + return html.Div( + [ + dash_table.DataTable( + data=data, + columns=columns, + id="table1", + style_cell={ + "padding": "0 12px", + "border": "0", + "border-bottom": f"1px solid {_primary}", + }, + style_table={ + "margin": "6vh 15px", + "padding": "15px", + "maxWidth": "80vw", + "overflowX": "auto", + "border": f"0px solid {_primary}", + "border-radius": "6px", + }, + ) + ], + style={ + "display": "flex", + "flex-direction": "column", + # "justify-content": "center", + "align-items": "center", + "height": "100vh", + }, + ) + + +def get_Graph(fig): + if fig is None: + return None + + return dcc.Graph( + id="graph1", + figure=fig, + style={ + "margin": 0, + "height": "100vh", + }, + ) + + +# datasets = get_datasets("output") app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) - +# app.config.suppress_callback_exceptions = True sidebar_style = { "top": 0, "left": 0, @@ -82,22 +187,25 @@ sidebar_style = { } content_style = { - # "margin-left": "18vw", - "flex": 9, + "flex": 10, } -sidebar = html.Div( - children=[ +def parse_href(href: str): + parse_result = urlparse(href) + params = parse_qsl(parse_result.query) + return dict(params) + + +def get_sidebar(): + return [ html.H4("Parameters:", style={"margin-bottom": "1vw"}), dbc.Select( - options=list(datasets.keys()), - value=list(datasets.keys())[0], + # options=list(datasets.keys()), + # value=list(datasets.keys())[0], id="dataset", ), dbc.Select( - # clearable=False, - # searchable=False, id="metric", style={"margin-top": "1vh"}, ), @@ -132,35 +240,57 @@ sidebar = html.Div( ], className="radio-group-wide", ), - ], - style=sidebar_style, - id="app-sidebar", -) + ] -content = html.Div( - children=[ - dcc.Graph( - style={"margin": 0, "height": "100vh"}, - id="graph1", - ), - ], - style=content_style, -) app.layout = html.Div( - children=[sidebar, content], - style={"display": "flex", "flexDirection": "row"}, + [ + dcc.Location(id="url", refresh=False), + html.Div( + [ + html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style), + html.Div(id="app_content", style=content_style), + ], + id="page_layout", + style={"display": "flex", "flexDirection": "row"}, + ), + ] ) +def apply_param(href, triggered_id, id, curr): + match triggered_id: + case "url": + params = parse_href(href) + return params.get(id, None) + case _: + return curr + + +@callback( + Output("dataset", "value"), + Output("dataset", "options"), + Input("url", "href"), +) +def update_dataset(href): + datasets = get_datasets("output") + + params = parse_href(href) + old_dataset = params.get("dataset", None) + new_dataset = old_dataset if old_dataset in datasets else datasets[0] + return new_dataset, datasets + + @callback( Output("metric", "options"), Output("metric", "value"), + Input("url", "href"), Input("dataset", "value"), State("metric", "value"), ) -def update_metrics(dataset, old_metric): - dr = datasets[dataset] +def update_metrics(href, dataset, curr_metric): + dr = load_dataset(dataset) + old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric) valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")] new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0] return valid_metrics, new_metric @@ -169,12 +299,19 @@ def update_metrics(dataset, old_metric): @callback( Output("estimators", "options"), Output("estimators", "value"), + Input("url", "href"), Input("dataset", "value"), Input("metric", "value"), State("estimators", "value"), ) -def update_estimators(dataset, metric, old_estimators): - dr = datasets[dataset] +def update_estimators(href, dataset, metric, curr_estimators): + dr = load_dataset(dataset) + old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators) + if isinstance(old_estimators, str): + try: + old_estimators = json.loads(old_estimators) + except JSONDecodeError: + old_estimators = [] valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy() new_estimators = valid_estimators[ np.isin(valid_estimators, old_estimators) @@ -185,11 +322,13 @@ def update_estimators(dataset, metric, old_estimators): @callback( Output("view", "options"), Output("view", "value"), + Input("url", "href"), Input("dataset", "value"), State("view", "value"), ) -def update_view(dataset, old_view): - dr = datasets[dataset] +def update_view(href, dataset, curr_view): + dr = load_dataset(dataset) + old_view = apply_param(href, ctx.triggered_id, "view", curr_view) valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs] new_view = old_view if old_view in valid_views else valid_views[0] return valid_views, new_view @@ -198,26 +337,61 @@ def update_view(dataset, old_view): @callback( Output("mode", "options"), Output("mode", "value"), + Input("url", "href"), Input("view", "value"), State("mode", "value"), ) -def update_mode(view, old_mode): - valid_modes = [m for m in valid_plot_modes[view] if not m.endswith("table")] +def update_mode(href, view, curr_mode): + old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode) + valid_modes = valid_plot_modes[view] new_mode = old_mode if old_mode in valid_modes else valid_modes[0] return valid_modes, new_mode @callback( - Output("graph1", "figure"), + Output("app_content", "children"), + Output("url", "search"), Input("dataset", "value"), Input("metric", "value"), Input("estimators", "value"), Input("view", "value"), Input("mode", "value"), ) -def update_graph(dataset, metric, estimators, view, mode): - dr = datasets[dataset] - return get_fig(dr=dr, metric=metric, estimators=estimators, view=view, mode=mode) +def update_content(dataset, metric, estimators, view, mode): + search = urlencode( + dict( + dataset=dataset, + metric=metric, + estimators=json.dumps(estimators), + view=view, + mode=mode, + ), + quote_via=quote, + ) + dr = load_dataset(dataset) + match mode: + case m if m.endswith("table"): + df = get_table( + dr=dr, + metric=metric, + estimators=estimators, + view=view, + mode=mode, + ) + dt = get_DataTable(df) + app_content = [] if dt is None else [dt] + case _: + fig = get_fig( + dr=dr, + metric=metric, + estimators=estimators, + view=view, + mode=mode, + ) + g = get_Graph(fig) + app_content = [] if g is None else [g] + + return app_content, f"?{search}" def run():