diff --git a/qcdash/app.py b/qcdash/app.py
index b1fbf67..81f2019 100644
--- a/qcdash/app.py
+++ b/qcdash/app.py
@@ -2,6 +2,7 @@ import json
 import os
 from collections import defaultdict
 from json import JSONDecodeError
+from operator import index
 from pathlib import Path
 from typing import List
 from urllib.parse import parse_qsl, quote, urlencode, urlparse
@@ -9,7 +10,7 @@ from urllib.parse import parse_qsl, quote, urlencode, urlparse
 import dash_bootstrap_components as dbc
 import numpy as np
 from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
-from dash.dash_table.Format import Format, Scheme
+from dash.dash_table.Format import Align, Format, Scheme
 
 from quacc import plot
 from quacc.evaluation.estimators import CE
@@ -21,6 +22,10 @@ valid_plot_modes["avg"] = DatasetReport._default_dr_modes
 root_folder = "output"
 
 
+def _get_prev_str(prev: np.ndarray):
+    return str(tuple(np.around(prev, decimals=2)))
+
+
 def get_datasets(root: str | Path) -> List[DatasetReport]:
     def load_dataset(dataset):
         dataset = Path(dataset)
@@ -63,7 +68,7 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
                 backend=_backend,
             )
         case (_, _):
-            cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
+            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
             return cr.get_plots(
                 mode=mode,
                 metric=metric,
@@ -76,53 +81,105 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
 
 def get_table(dr: DatasetReport, metric, estimators, view, mode):
     estimators = CE.name[estimators]
-    _prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
     match (view, mode):
         case ("avg", "train_table"):
-            return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
+            # return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
+            return dr.train_table(metric=metric, estimators=estimators)
         case ("avg", "test_table"):
-            return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+            # return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+            return dr.test_table(metric=metric, estimators=estimators)
         case ("avg", "shift_table"):
-            return (
-                dr.shift_data(metric=metric, estimators=estimators)
-                .groupby(level=0)
-                .mean()
-            )
+            # return (
+            #     dr.shift_data(metric=metric, estimators=estimators)
+            #     .groupby(level=0)
+            #     .mean()
+            # )
+            return dr.shift_table(metric=metric, estimators=estimators)
         case ("avg", "stats_table"):
             return wilcoxon(dr, metric=metric, estimators=estimators)
         case (_, "train_table"):
-            cr = dr.crs[_prevs.index(view)]
-            return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
+            # return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+            return cr.train_table(metric=metric, estimators=estimators)
         case (_, "shift_table"):
-            cr = dr.crs[_prevs.index(view)]
-            return (
-                cr.shift_data(metric=metric, estimators=estimators)
-                .groupby(level=0)
-                .mean()
-            )
+            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
+            # return (
+            #     cr.shift_data(metric=metric, estimators=estimators)
+            #     .groupby(level=0)
+            #     .mean()
+            # )
+            return cr.shift_table(metric=metric, estimators=estimators)
         case (_, "stats_table"):
-            cr = dr.crs[_prevs.index(view)]
+            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
             return wilcoxon(cr, metric=metric, estimators=estimators)
 
 
-def get_DataTable(df):
+def get_DataTable(df, mode):
     _primary = "#0d6efd"
     if df.empty:
         return None
 
+    _index_name = dict(
+        train_table="test prev.",
+        test_table="train prev.",
+        shift_table="shift",
+        stats_table="method",
+    )
     df = df.reset_index()
     columns = {
         c: dict(
             id=c,
-            name=c,
+            name=_index_name[mode] if c == "index" else c,
             type="numeric",
             format=Format(precision=6, scheme=Scheme.exponent, nully="nan"),
         )
         for c in df.columns
     }
-    columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
+    # columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
+    columns["index"]["format"] = Format()
     columns = list(columns.values())
     data = df.to_dict("records")
+    for d in data:
+        if isinstance(d["index"], tuple | list | np.ndarray):
+            d["index"] = "(" + ", ".join([f"{v:.2f}" for v in d["index"]]) + ")"
+        elif isinstance(d["index"], float):
+            d["index"] = f"{d['index']:.2f}"
+
+    _style_cell = {
+        "padding": "0 12px",
+        "border": "0",
+        "border-bottom": f"1px solid {_primary}",
+    }
+
+    _style_cell_conditional = [
+        {
+            "if": {"column_id": "index"},
+            "text_align": "center",
+        },
+    ]
+
+    _style_data_conditional = []
+    if mode != "stats_table":
+        _style_data_conditional += [
+            {
+                "if": {"column_id": "index", "row_index": len(data) - 1},
+                "font_weight": "bold",
+            },
+            {
+                "if": {"row_index": len(data) - 1},
+                "background_color": "#0d6efd",
+                "color": "white",
+            },
+        ]
+
+    _style_table = {
+        "margin": "6vh 15px",
+        "padding": "15px",
+        "maxWidth": "80vw",
+        "overflowX": "auto",
+        "border": f"0px solid {_primary}",
+        "border-radius": "6px",
+    }
 
     return html.Div(
         [
@@ -130,19 +187,10 @@ def get_DataTable(df):
                 data=data,
                 columns=columns,
                 id="table1",
-                style_cell={
-                    "padding": "0 12px",
-                    "border": "0",
-                    "border-bottom": f"1px solid {_primary}",
-                },
-                style_table={
-                    "margin": "6vh 15px",
-                    "padding": "15px",
-                    "maxWidth": "80vw",
-                    "overflowX": "auto",
-                    "border": f"0px solid {_primary}",
-                    "border-radius": "6px",
-                },
+                style_cell=_style_cell,
+                style_cell_conditional=_style_cell_conditional,
+                style_data_conditional=_style_data_conditional,
+                style_table=_style_table,
             )
         ],
         style={
@@ -361,7 +409,7 @@ def update_estimators(href, dataset, metric, curr_estimators, root):
 def update_view(href, dataset, curr_view, root):
     dr = get_dr(root, dataset)
     old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
-    valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
+    valid_views = ["avg"] + [_get_prev_str(cr.train_prev) for cr in dr.crs]
     new_view = old_view if old_view in valid_views else valid_views[0]
     return valid_views, new_view
 
@@ -412,7 +460,7 @@ def update_content(dataset, metric, estimators, view, mode, root):
                 view=view,
                 mode=mode,
             )
-            dt = get_DataTable(df)
+            dt = get_DataTable(df, mode)
             app_content = [] if dt is None else [dt]
         case _:
             fig = get_fig(