diff --git a/quacc/plot/__init__.py b/quacc/plot/__init__.py
index c16c75e..433e7e1 100644
--- a/quacc/plot/__init__.py
+++ b/quacc/plot/__init__.py
@@ -1,7 +1,7 @@
-from quacc.plot.plot import (
+from quacc.legacy.plot.plot import (
get_backend,
plot_delta,
plot_diagonal,
- plot_shift,
plot_fit_scores,
+ plot_shift,
)
diff --git a/quacc/plot/plotly.py b/quacc/plot/plotly.py
index 52a514d..2fb6978 100644
--- a/quacc/plot/plotly.py
+++ b/quacc/plot/plotly.py
@@ -1,330 +1,217 @@
-from collections import defaultdict
-from pathlib import Path
+import os
import numpy as np
import plotly
import plotly.graph_objects as go
-from quacc.evaluation.estimators import CE, _renames
-from quacc.plot.base import BasePlot
-
-
-class PlotCfg:
- def __init__(self, mode, lwidth, font=None, legend=None, template="seaborn"):
- self.mode = mode
- self.lwidth = lwidth
- self.legend = {} if legend is None else legend
- self.font = {} if font is None else font
- self.template = template
-
-
-web_cfg = PlotCfg("lines+markers", 2)
-png_cfg_old = PlotCfg(
- "lines",
- 5,
- legend=dict(
- orientation="h",
- yanchor="bottom",
- xanchor="right",
- y=1.02,
- x=1,
- font=dict(size=24),
- ),
- font=dict(size=24),
- # template="ggplot2",
-)
-png_cfg = PlotCfg(
- "lines",
- 5,
- legend=dict(
- font=dict(
- family="DejaVu Sans",
- size=24,
- ),
- ),
- font=dict(size=24),
- # template="ggplot2",
-)
-
-_cfg = png_cfg
-
-
-class PlotlyPlot(BasePlot):
- __themes = defaultdict(
- lambda: {
- "template": _cfg.template,
- }
- )
- __themes = __themes | {
- "dark": {
- "template": "plotly_dark",
- },
+MODE = "lines"
+L_WIDTH = 5
+LEGEND = {
+ "font": {
+ "family": "DejaVu Sans",
+ "size": 24,
}
+}
+FONT = {"size": 24}
+TEMPLATE = "ggplot2"
- def __init__(self, theme=None):
- self.theme = PlotlyPlot.__themes[theme]
- self.rename = True
- def hex_to_rgb(self, hex: str, t: float | None = None):
- hex = hex.lstrip("#")
- rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
- if t is not None:
- rgb.append(t)
- return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
+def _save_or_return(fig, basedir, dataset_name, measure_name, plot_type):
+ if basedir is not None:
+ plotsubdir = dataset_name
+ os.path.join(basedir, "plots", measure_name, plotsubdir, plot_type + ".svg")
- def get_colors(self, num):
- match num:
- case v if v > 10:
- __colors = plotly.colors.qualitative.Light24
- case _:
- __colors = plotly.colors.qualitative.G10
+ return fig
- def __generator(cs):
- while True:
- for c in cs:
- yield c
- return __generator(__colors)
+def _update_layout(fig, title, x_label, y_label, **kwargs):
+ fig.update_layout(
+ xaxis_title=x_label,
+ yaxis_title=y_label,
+ template=TEMPLATE,
+ font=FONT,
+ legend=LEGEND,
+ **kwargs,
+ )
- def update_layout(self, fig, title, x_label, y_label):
- fig.update_layout(
- # title=title,
- xaxis_title=x_label,
- yaxis_title=y_label,
- template=self.theme["template"],
- font=_cfg.font,
- legend=_cfg.legend,
- )
- def save_fig(self, fig, base_path, title) -> Path:
- return None
+def _hex_to_rgb(self, hex: str, t: float | None = None):
+ hex = hex.lstrip("#")
+ rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
+ if t is not None:
+ rgb.append(t)
+ return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
- def rename_plots(
- self,
- columns,
- ):
- if not self.rename:
- return columns
- new_columns = []
- for c in columns:
- nc = c
- for old, new in _renames.items():
- if c.startswith(old):
- nc = new + c[len(old) :]
+def _get_colors(self, num):
+ match num:
+ case v if v > 10:
+ __colors = plotly.colors.qualitative.Light24
+ case _:
+ __colors = plotly.colors.qualitative.G10
- new_columns.append(nc)
+ def __generator(cs):
+ while True:
+ for c in cs:
+ yield c
- return np.array(new_columns)
+ return __generator(__colors)
- def plot_delta(
- self,
- base_prevs,
- columns,
- data,
- *,
- stdevs=None,
- pos_class=1,
- title="default",
- x_label="prevs.",
- y_label="error",
- legend=True,
- ) -> go.Figure:
- fig = go.Figure()
- if isinstance(base_prevs[0], float):
- base_prevs = np.around([(1 - bp, bp) for bp in base_prevs], decimals=4)
- x = [str(tuple(bp)) for bp in base_prevs]
- named_data = {c: d for c, d in zip(columns, data)}
- r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))}
- line_colors = self.get_colors(len(columns))
- # for name, delta in zip(columns, data):
- columns = np.array(CE.name.sort(columns))
- for name in columns:
- delta = named_data[name]
- r_name = r_columns[name]
- color = next(line_colors)
- _line = [
+
+def _get_ref_limits(true_accs: np.ndarray, estim_accs: dict[str, np.ndarray]):
+ """get lmits of reference line"""
+
+ _edges = (
+ np.min([np.min(true_accs), np.min(estim_accs)]),
+ np.max([np.max(true_accs), np.max(estim_accs)]),
+ )
+ _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
+
+
+def plot_diagonal(
+ method_names,
+ true_accs,
+ estim_accs,
+ *,
+ measure_name="vanilla_accuracy",
+ dataset_name=None,
+ basedir=None,
+) -> go.Figure:
+ fig = go.Figure()
+ x = true_accs
+ line_colors = _get_colors(len(method_names))
+ _lims = _get_ref_limits(true_accs, estim_accs)
+
+ for name, estim in zip(method_names, estim_accs):
+ color = next(line_colors)
+ slope, interc = np.polyfit(x, estim, 1)
+ fig.add_traces(
+ [
go.Scatter(
x=x,
- y=delta,
- mode=_cfg.mode,
- name=r_name,
- line=dict(color=self.hex_to_rgb(color), width=_cfg.lwidth),
- hovertemplate="prev.: %{x}
error: %{y:,.4f}",
+ y=estim,
+ customdata=np.stack((estim - x,), axis=-1),
+ mode="markers",
+ name=name,
+ marker=dict(color=_hex_to_rgb(color, t=0.5)),
+ hovertemplate="true acc: %{x:,.4f}
estim. acc: %{y:,.4f}
acc err.: %{customdata[0]:,.4f}",
+ ),
+ ]
+ )
+ fig.add_trace(
+ go.Scatter(
+ x=_lims[0],
+ y=_lims[1],
+ mode="lines",
+ name="reference",
+ showlegend=False,
+ line=dict(color=_hex_to_rgb("#000000"), dash="dash"),
+ )
+ )
+
+ _update_layout(
+ fig,
+ x_label=f"True {measure_name}",
+ y_label=f"Estimated {measure_name}",
+ autosize=False,
+ width=1300,
+ height=1000,
+ yaxis_scaleanchor="x",
+ yaxis_scaleratio=1.0,
+ yaxis_range=[-0.1, 1.1],
+ )
+ return _save_or_return(fig, basedir, dataset_name, measure_name, "diagonal")
+
+
+def plot_delta(
+ method_names: list[str],
+ prevs: np.ndarray,
+ acc_errs: np.ndarray,
+ *,
+ stdevs: np.ndarray | None = None,
+ prev_name="Test",
+ measure_name="Vanilla Accuracy",
+ dataset_name=None,
+ basedir=None,
+) -> go.Figure:
+ fig = go.Figure()
+ x = [str(bp) for bp in prevs]
+ line_colors = _get_colors(len(method_names))
+ if stdevs is None:
+ stdevs = [None] * len(method_names)
+ for name, delta, stdev in zip(method_names, acc_errs, stdevs):
+ color = next(line_colors)
+ _line = [
+ go.Scatter(
+ x=x,
+ y=delta,
+ mode=MODE,
+ name=name,
+ line=dict(color=_hex_to_rgb(color), width=L_WIDTH),
+ hovertemplate="prev.: %{x}
error: %{y:,.4f}",
+ )
+ ]
+ _error = []
+ if stdev is not None:
+ _error = [
+ go.Scatter(
+ x=np.concatenate([x, x[::-1]]),
+ y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
+ name=name,
+ fill="toself",
+ fillcolor=_hex_to_rgb(color, t=0.2),
+ line=dict(color="rgba(255, 255, 255, 0)"),
+ hoverinfo="skip",
+ showlegend=False,
)
]
- _error = []
- if stdevs is not None:
- _col_idx = np.where(columns == name)[0]
- stdev = stdevs[_col_idx].flatten()
- _error = [
- go.Scatter(
- x=np.concatenate([x, x[::-1]]),
- y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
- name=int(_col_idx[0]),
- fill="toself",
- fillcolor=self.hex_to_rgb(color, t=0.2),
- line=dict(color="rgba(255, 255, 255, 0)"),
- hoverinfo="skip",
- showlegend=False,
- )
- ]
- fig.add_traces(_line + _error)
+ fig.add_traces(_line + _error)
- self.update_layout(fig, title, x_label, y_label)
- return fig
+ _update_layout(
+ fig,
+ x_label=f"{prev_name} Prevalence",
+ y_label=f"Prediction Error for {measure_name}",
+ )
+ return _save_or_return(
+ fig, basedir, dataset_name, measure_name, "delta" if stdevs is None else "stdev"
+ )
- def plot_diagonal(
- self,
- reference,
- columns,
- data,
- *,
- pos_class=1,
- title="default",
- x_label="true",
- y_label="estim.",
- fixed_lim=False,
- legend=True,
- ) -> go.Figure:
- fig = go.Figure()
- x = reference
- line_colors = self.get_colors(len(columns))
- if fixed_lim:
- _lims = np.array([[0.0, 1.0], [0.0, 1.0]])
- else:
- _edges = (
- np.min([np.min(x), np.min(data)]),
- np.max([np.max(x), np.max(data)]),
- )
- _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
-
- named_data = {c: d for c, d in zip(columns, data)}
- r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))}
- columns = np.array(CE.name.sort(columns))
- for name in columns:
- val = named_data[name]
- r_name = r_columns[name]
- color = next(line_colors)
- slope, interc = np.polyfit(x, val, 1)
- # y_lr = np.array([slope * _x + interc for _x in _lims[0]])
- fig.add_traces(
- [
- go.Scatter(
- x=x,
- y=val,
- customdata=np.stack((val - x,), axis=-1),
- mode="markers",
- name=r_name,
- marker=dict(color=self.hex_to_rgb(color, t=0.5)),
- hovertemplate="true acc: %{x:,.4f}
estim. acc: %{y:,.4f}
acc err.: %{customdata[0]:,.4f}",
- # showlegend=False,
- ),
- # go.Scatter(
- # x=[x[-1]],
- # y=[val[-1]],
- # mode="markers",
- # marker=dict(color=self.hex_to_rgb(color), size=8),
- # name=r_name,
- # ),
- # go.Scatter(
- # x=_lims[0],
- # y=y_lr,
- # mode="lines",
- # name=name,
- # line=dict(color=self.hex_to_rgb(color), width=3),
- # showlegend=False,
- # ),
- ]
- )
- fig.add_trace(
- go.Scatter(
- x=_lims[0],
- y=_lims[1],
- mode="lines",
- name="reference",
- showlegend=False,
- line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
- )
- )
-
- self.update_layout(fig, title, x_label, y_label)
- fig.update_layout(
- autosize=False,
- width=1300,
- height=1000,
- yaxis_scaleanchor="x",
- yaxis_scaleratio=1.0,
- yaxis_range=[-0.1, 1.1],
- )
- return fig
-
- def plot_shift(
- self,
- shift_prevs,
- columns,
- data,
- *,
- counts=None,
- pos_class=1,
- title="default",
- x_label="true",
- y_label="estim.",
- legend=True,
- ) -> go.Figure:
- fig = go.Figure()
- # x = shift_prevs[:, pos_class]
- x = shift_prevs
- line_colors = self.get_colors(len(columns))
- named_data = {c: d for c, d in zip(columns, data)}
- r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))}
- columns = np.array(CE.name.sort(columns))
- for name in columns:
- delta = named_data[name]
- r_name = r_columns[name]
- col_idx = (columns == name).nonzero()[0][0]
- color = next(line_colors)
- fig.add_trace(
- go.Scatter(
- x=x,
- y=delta,
- customdata=np.stack((counts[col_idx],), axis=-1),
- mode=_cfg.mode,
- name=r_name,
- line=dict(color=self.hex_to_rgb(color), width=_cfg.lwidth),
- hovertemplate="shift: %{x}
error: %{y}"
- + "
count: %{customdata[0]}"
- if counts is not None
- else "",
- )
- )
-
- self.update_layout(fig, title, x_label, y_label)
- return fig
-
- def plot_fit_scores(
- self,
- train_prevs,
- scores,
- *,
- pos_class=1,
- title="default",
- x_label="prev.",
- y_label="position",
- legend=True,
- ) -> go.Figure:
- fig = go.Figure()
- # x = train_prevs
- x = [str(tuple(bp)) for bp in train_prevs]
+def plot_shift(
+ method_names: list[str],
+ prevs: np.ndarray,
+ acc_errs: np.ndarray,
+ *,
+ counts: np.ndarray | None = None,
+ measure_name="Vanilla Accuracy",
+ dataset_name=None,
+ basedir=None,
+) -> go.Figure:
+ fig = go.Figure()
+ x = prevs
+ line_colors = _get_colors(len(method_names))
+ if counts is None:
+ counts = [None] * len(method_names)
+ for name, delta, count in zip(method_names, acc_errs, counts):
+ color = next(line_colors)
fig.add_trace(
go.Scatter(
x=x,
- y=scores,
- mode="lines+markers",
- showlegend=False,
- ),
+ y=delta,
+ customdata=np.stack((count,), axis=-1),
+ mode=MODE,
+ name=name,
+ line=dict(color=_hex_to_rgb(color), width=L_WIDTH),
+ hovertemplate="shift: %{x}
error: %{y}"
+ + "
count: %{customdata[0]}"
+ if count is not None
+ else "",
+ )
)
- self.update_layout(fig, title, x_label, y_label)
- return fig
+ _update_layout(
+ fig,
+ x_label="Amount of Prior Probability Shift",
+ y_label=f"Prediction Error for {measure_name}",
+ )
+ return _save_or_return(fig, basedir, dataset_name, measure_name, "shift")