diff --git a/qcdash/app.py b/qcdash/app.py index bce90f4..d91dae4 100644 --- a/qcdash/app.py +++ b/qcdash/app.py @@ -7,7 +7,6 @@ from typing import List from urllib.parse import parse_qsl, quote, urlencode, urlparse import dash_bootstrap_components as dbc -import flask import numpy as np from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html from dash.dash_table.Format import Format, Scheme @@ -24,6 +23,10 @@ valid_plot_modes["avg"] = DatasetReport._default_dr_modes def get_datasets(root: str | Path) -> List[DatasetReport]: + def load_dataset(dataset): + dataset = Path(dataset) + return DatasetReport.unpickle(dataset) + def explore_datasets(root: str | Path) -> List[Path]: if isinstance(root, str): root = Path(root) @@ -43,17 +46,8 @@ def get_datasets(root: str | Path) -> List[DatasetReport]: return dr_paths - return list( - map( - lambda p: str(p.parent), - sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t)), - ) - ) - - -def load_dataset(dataset): - dataset = Path(dataset) - return DatasetReport.unpickle(dataset / f"{dataset.name}.pickle") + dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t)) + return {str(drp.parent): load_dataset(drp) for drp in dr_paths} def get_fig(dr: DatasetReport, metric, estimators, view, mode): @@ -172,7 +166,7 @@ def get_Graph(fig): ) -# datasets = get_datasets("output") +datasets = get_datasets("output") app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) # app.config.suppress_callback_exceptions = True @@ -183,11 +177,14 @@ sidebar_style = { "padding": "1vw", "padding-top": "2vw", "margin": "0px", - "flex": 2, + "flex": 1, + "overflow": "scroll", + "height": "100vh", } content_style = { - "flex": 10, + "flex": 5, + "maxWidth": "84vw", } @@ -245,6 +242,7 @@ def get_sidebar(): app.layout = html.Div( [ + dcc.Interval(id="reload", interval=10 * 60 * 1000), dcc.Location(id="url", refresh=False), html.Div( [ @@ -257,6 +255,8 @@ app.layout = html.Div( ] ) +server = app.server + def apply_param(href, triggered_id, id, curr): match triggered_id: @@ -271,14 +271,25 @@ def apply_param(href, triggered_id, id, curr): Output("dataset", "value"), Output("dataset", "options"), Input("url", "href"), + Input("reload", "n_intervals"), + State("dataset", "value"), ) -def update_dataset(href): - datasets = get_datasets("output") +def update_dataset(href, n_intervals, dataset): + match ctx.triggered_id: + case "reload": + new_datasets = get_datasets("output") + global datasets + datasets = new_datasets + req_dataset = dataset + case "url": + params = parse_href(href) + req_dataset = params.get("dataset", None) - params = parse_href(href) - old_dataset = params.get("dataset", None) - new_dataset = old_dataset if old_dataset in datasets else datasets[0] - return new_dataset, datasets + available_datasets = list(datasets.keys()) + new_dataset = ( + req_dataset if req_dataset in available_datasets else available_datasets[0] + ) + return new_dataset, available_datasets @callback( @@ -289,7 +300,7 @@ def update_dataset(href): State("metric", "value"), ) def update_metrics(href, dataset, curr_metric): - dr = load_dataset(dataset) + dr = datasets[dataset] old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric) valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")] new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0] @@ -305,7 +316,7 @@ def update_metrics(href, dataset, curr_metric): State("estimators", "value"), ) def update_estimators(href, dataset, metric, curr_estimators): - dr = load_dataset(dataset) + dr = datasets[dataset] old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators) if isinstance(old_estimators, str): try: @@ -327,7 +338,7 @@ def update_estimators(href, dataset, metric, curr_estimators): State("view", "value"), ) def update_view(href, dataset, curr_view): - dr = load_dataset(dataset) + dr = datasets[dataset] old_view = apply_param(href, ctx.triggered_id, "view", curr_view) valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs] new_view = old_view if old_view in valid_views else valid_views[0] @@ -368,7 +379,7 @@ def update_content(dataset, metric, estimators, view, mode): ), quote_via=quote, ) - dr = load_dataset(dataset) + dr = datasets[dataset] match mode: case m if m.endswith("table"): df = get_table(