plots refactoring started

This commit is contained in:
Lorenzo Volpi 2024-04-04 17:03:52 +02:00
parent 4a06c83256
commit b8e43c02f2
2 changed files with 189 additions and 302 deletions

View File

@ -1,7 +1,7 @@
from quacc.plot.plot import ( from quacc.legacy.plot.plot import (
get_backend, get_backend,
plot_delta, plot_delta,
plot_diagonal, plot_diagonal,
plot_shift,
plot_fit_scores, plot_fit_scores,
plot_shift,
) )

View File

@ -1,330 +1,217 @@
from collections import defaultdict import os
from pathlib import Path
import numpy as np import numpy as np
import plotly import plotly
import plotly.graph_objects as go import plotly.graph_objects as go
from quacc.evaluation.estimators import CE, _renames MODE = "lines"
from quacc.plot.base import BasePlot L_WIDTH = 5
LEGEND = {
"font": {
class PlotCfg: "family": "DejaVu Sans",
def __init__(self, mode, lwidth, font=None, legend=None, template="seaborn"): "size": 24,
self.mode = mode
self.lwidth = lwidth
self.legend = {} if legend is None else legend
self.font = {} if font is None else font
self.template = template
web_cfg = PlotCfg("lines+markers", 2)
png_cfg_old = PlotCfg(
"lines",
5,
legend=dict(
orientation="h",
yanchor="bottom",
xanchor="right",
y=1.02,
x=1,
font=dict(size=24),
),
font=dict(size=24),
# template="ggplot2",
)
png_cfg = PlotCfg(
"lines",
5,
legend=dict(
font=dict(
family="DejaVu Sans",
size=24,
),
),
font=dict(size=24),
# template="ggplot2",
)
_cfg = png_cfg
class PlotlyPlot(BasePlot):
__themes = defaultdict(
lambda: {
"template": _cfg.template,
}
)
__themes = __themes | {
"dark": {
"template": "plotly_dark",
},
} }
}
FONT = {"size": 24}
TEMPLATE = "ggplot2"
def __init__(self, theme=None):
self.theme = PlotlyPlot.__themes[theme]
self.rename = True
def hex_to_rgb(self, hex: str, t: float | None = None): def _save_or_return(fig, basedir, dataset_name, measure_name, plot_type):
hex = hex.lstrip("#") if basedir is not None:
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]] plotsubdir = dataset_name
if t is not None: os.path.join(basedir, "plots", measure_name, plotsubdir, plot_type + ".svg")
rgb.append(t)
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
def get_colors(self, num): return fig
match num:
case v if v > 10:
__colors = plotly.colors.qualitative.Light24
case _:
__colors = plotly.colors.qualitative.G10
def __generator(cs):
while True:
for c in cs:
yield c
return __generator(__colors) def _update_layout(fig, title, x_label, y_label, **kwargs):
fig.update_layout(
xaxis_title=x_label,
yaxis_title=y_label,
template=TEMPLATE,
font=FONT,
legend=LEGEND,
**kwargs,
)
def update_layout(self, fig, title, x_label, y_label):
fig.update_layout(
# title=title,
xaxis_title=x_label,
yaxis_title=y_label,
template=self.theme["template"],
font=_cfg.font,
legend=_cfg.legend,
)
def save_fig(self, fig, base_path, title) -> Path: def _hex_to_rgb(self, hex: str, t: float | None = None):
return None hex = hex.lstrip("#")
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
if t is not None:
rgb.append(t)
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
def rename_plots(
self,
columns,
):
if not self.rename:
return columns
new_columns = [] def _get_colors(self, num):
for c in columns: match num:
nc = c case v if v > 10:
for old, new in _renames.items(): __colors = plotly.colors.qualitative.Light24
if c.startswith(old): case _:
nc = new + c[len(old) :] __colors = plotly.colors.qualitative.G10
new_columns.append(nc) def __generator(cs):
while True:
for c in cs:
yield c
return np.array(new_columns) return __generator(__colors)
def plot_delta(
self, def _get_ref_limits(true_accs: np.ndarray, estim_accs: dict[str, np.ndarray]):
base_prevs, """get lmits of reference line"""
columns,
data, _edges = (
*, np.min([np.min(true_accs), np.min(estim_accs)]),
stdevs=None, np.max([np.max(true_accs), np.max(estim_accs)]),
pos_class=1, )
title="default", _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
x_label="prevs.",
y_label="error",
legend=True, def plot_diagonal(
) -> go.Figure: method_names,
fig = go.Figure() true_accs,
if isinstance(base_prevs[0], float): estim_accs,
base_prevs = np.around([(1 - bp, bp) for bp in base_prevs], decimals=4) *,
x = [str(tuple(bp)) for bp in base_prevs] measure_name="vanilla_accuracy",
named_data = {c: d for c, d in zip(columns, data)} dataset_name=None,
r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))} basedir=None,
line_colors = self.get_colors(len(columns)) ) -> go.Figure:
# for name, delta in zip(columns, data): fig = go.Figure()
columns = np.array(CE.name.sort(columns)) x = true_accs
for name in columns: line_colors = _get_colors(len(method_names))
delta = named_data[name] _lims = _get_ref_limits(true_accs, estim_accs)
r_name = r_columns[name]
color = next(line_colors) for name, estim in zip(method_names, estim_accs):
_line = [ color = next(line_colors)
slope, interc = np.polyfit(x, estim, 1)
fig.add_traces(
[
go.Scatter( go.Scatter(
x=x, x=x,
y=delta, y=estim,
mode=_cfg.mode, customdata=np.stack((estim - x,), axis=-1),
name=r_name, mode="markers",
line=dict(color=self.hex_to_rgb(color), width=_cfg.lwidth), name=name,
hovertemplate="prev.: %{x}<br>error: %{y:,.4f}", marker=dict(color=_hex_to_rgb(color, t=0.5)),
hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
),
]
)
fig.add_trace(
go.Scatter(
x=_lims[0],
y=_lims[1],
mode="lines",
name="reference",
showlegend=False,
line=dict(color=_hex_to_rgb("#000000"), dash="dash"),
)
)
_update_layout(
fig,
x_label=f"True {measure_name}",
y_label=f"Estimated {measure_name}",
autosize=False,
width=1300,
height=1000,
yaxis_scaleanchor="x",
yaxis_scaleratio=1.0,
yaxis_range=[-0.1, 1.1],
)
return _save_or_return(fig, basedir, dataset_name, measure_name, "diagonal")
def plot_delta(
method_names: list[str],
prevs: np.ndarray,
acc_errs: np.ndarray,
*,
stdevs: np.ndarray | None = None,
prev_name="Test",
measure_name="Vanilla Accuracy",
dataset_name=None,
basedir=None,
) -> go.Figure:
fig = go.Figure()
x = [str(bp) for bp in prevs]
line_colors = _get_colors(len(method_names))
if stdevs is None:
stdevs = [None] * len(method_names)
for name, delta, stdev in zip(method_names, acc_errs, stdevs):
color = next(line_colors)
_line = [
go.Scatter(
x=x,
y=delta,
mode=MODE,
name=name,
line=dict(color=_hex_to_rgb(color), width=L_WIDTH),
hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
)
]
_error = []
if stdev is not None:
_error = [
go.Scatter(
x=np.concatenate([x, x[::-1]]),
y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
name=name,
fill="toself",
fillcolor=_hex_to_rgb(color, t=0.2),
line=dict(color="rgba(255, 255, 255, 0)"),
hoverinfo="skip",
showlegend=False,
) )
] ]
_error = [] fig.add_traces(_line + _error)
if stdevs is not None:
_col_idx = np.where(columns == name)[0]
stdev = stdevs[_col_idx].flatten()
_error = [
go.Scatter(
x=np.concatenate([x, x[::-1]]),
y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
name=int(_col_idx[0]),
fill="toself",
fillcolor=self.hex_to_rgb(color, t=0.2),
line=dict(color="rgba(255, 255, 255, 0)"),
hoverinfo="skip",
showlegend=False,
)
]
fig.add_traces(_line + _error)
self.update_layout(fig, title, x_label, y_label) _update_layout(
return fig fig,
x_label=f"{prev_name} Prevalence",
y_label=f"Prediction Error for {measure_name}",
)
return _save_or_return(
fig, basedir, dataset_name, measure_name, "delta" if stdevs is None else "stdev"
)
def plot_diagonal(
self,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
fixed_lim=False,
legend=True,
) -> go.Figure:
fig = go.Figure()
x = reference
line_colors = self.get_colors(len(columns))
if fixed_lim: def plot_shift(
_lims = np.array([[0.0, 1.0], [0.0, 1.0]]) method_names: list[str],
else: prevs: np.ndarray,
_edges = ( acc_errs: np.ndarray,
np.min([np.min(x), np.min(data)]), *,
np.max([np.max(x), np.max(data)]), counts: np.ndarray | None = None,
) measure_name="Vanilla Accuracy",
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]]) dataset_name=None,
basedir=None,
named_data = {c: d for c, d in zip(columns, data)} ) -> go.Figure:
r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))} fig = go.Figure()
columns = np.array(CE.name.sort(columns)) x = prevs
for name in columns: line_colors = _get_colors(len(method_names))
val = named_data[name] if counts is None:
r_name = r_columns[name] counts = [None] * len(method_names)
color = next(line_colors) for name, delta, count in zip(method_names, acc_errs, counts):
slope, interc = np.polyfit(x, val, 1) color = next(line_colors)
# y_lr = np.array([slope * _x + interc for _x in _lims[0]])
fig.add_traces(
[
go.Scatter(
x=x,
y=val,
customdata=np.stack((val - x,), axis=-1),
mode="markers",
name=r_name,
marker=dict(color=self.hex_to_rgb(color, t=0.5)),
hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
# showlegend=False,
),
# go.Scatter(
# x=[x[-1]],
# y=[val[-1]],
# mode="markers",
# marker=dict(color=self.hex_to_rgb(color), size=8),
# name=r_name,
# ),
# go.Scatter(
# x=_lims[0],
# y=y_lr,
# mode="lines",
# name=name,
# line=dict(color=self.hex_to_rgb(color), width=3),
# showlegend=False,
# ),
]
)
fig.add_trace(
go.Scatter(
x=_lims[0],
y=_lims[1],
mode="lines",
name="reference",
showlegend=False,
line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
)
)
self.update_layout(fig, title, x_label, y_label)
fig.update_layout(
autosize=False,
width=1300,
height=1000,
yaxis_scaleanchor="x",
yaxis_scaleratio=1.0,
yaxis_range=[-0.1, 1.1],
)
return fig
def plot_shift(
self,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
) -> go.Figure:
fig = go.Figure()
# x = shift_prevs[:, pos_class]
x = shift_prevs
line_colors = self.get_colors(len(columns))
named_data = {c: d for c, d in zip(columns, data)}
r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))}
columns = np.array(CE.name.sort(columns))
for name in columns:
delta = named_data[name]
r_name = r_columns[name]
col_idx = (columns == name).nonzero()[0][0]
color = next(line_colors)
fig.add_trace(
go.Scatter(
x=x,
y=delta,
customdata=np.stack((counts[col_idx],), axis=-1),
mode=_cfg.mode,
name=r_name,
line=dict(color=self.hex_to_rgb(color), width=_cfg.lwidth),
hovertemplate="shift: %{x}<br>error: %{y}"
+ "<br>count: %{customdata[0]}"
if counts is not None
else "",
)
)
self.update_layout(fig, title, x_label, y_label)
return fig
def plot_fit_scores(
self,
train_prevs,
scores,
*,
pos_class=1,
title="default",
x_label="prev.",
y_label="position",
legend=True,
) -> go.Figure:
fig = go.Figure()
# x = train_prevs
x = [str(tuple(bp)) for bp in train_prevs]
fig.add_trace( fig.add_trace(
go.Scatter( go.Scatter(
x=x, x=x,
y=scores, y=delta,
mode="lines+markers", customdata=np.stack((count,), axis=-1),
showlegend=False, mode=MODE,
), name=name,
line=dict(color=_hex_to_rgb(color), width=L_WIDTH),
hovertemplate="shift: %{x}<br>error: %{y}"
+ "<br>count: %{customdata[0]}"
if count is not None
else "",
)
) )
self.update_layout(fig, title, x_label, y_label) _update_layout(
return fig fig,
x_label="Amount of Prior Probability Shift",
y_label=f"Prediction Error for {measure_name}",
)
return _save_or_return(fig, basedir, dataset_name, measure_name, "shift")