plotly methods fixed, plot saving implemented
This commit is contained in:
parent
43056e76a8
commit
a3ffd689b1
|
|
@ -1,9 +1,11 @@
|
||||||
import os
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import plotly
|
import plotly
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
|
|
||||||
|
from quacc.utils.commons import get_plots_path
|
||||||
|
|
||||||
MODE = "lines"
|
MODE = "lines"
|
||||||
L_WIDTH = 5
|
L_WIDTH = 5
|
||||||
LEGEND = {
|
LEGEND = {
|
||||||
|
|
@ -16,12 +18,14 @@ FONT = {"size": 24}
|
||||||
TEMPLATE = "ggplot2"
|
TEMPLATE = "ggplot2"
|
||||||
|
|
||||||
|
|
||||||
def _save_or_return(fig, basedir, dataset_name, measure_name, plot_type):
|
def _save_or_return(
|
||||||
if basedir is not None:
|
fig: go.Figure, basedir, cls_name, acc_name, dataset_name, plot_type
|
||||||
plotsubdir = dataset_name
|
) -> go.Figure | None:
|
||||||
os.path.join(basedir, "plots", measure_name, plotsubdir, plot_type + ".svg")
|
if basedir is None:
|
||||||
|
return fig
|
||||||
|
|
||||||
return fig
|
path = get_plots_path(basedir, cls_name, acc_name, dataset_name, plot_type)
|
||||||
|
fig.write_image(path)
|
||||||
|
|
||||||
|
|
||||||
def _update_layout(fig, title, x_label, y_label, **kwargs):
|
def _update_layout(fig, title, x_label, y_label, **kwargs):
|
||||||
|
|
@ -72,9 +76,10 @@ def plot_diagonal(
|
||||||
method_names,
|
method_names,
|
||||||
true_accs,
|
true_accs,
|
||||||
estim_accs,
|
estim_accs,
|
||||||
|
cls_name,
|
||||||
|
acc_name,
|
||||||
|
dataset_name,
|
||||||
*,
|
*,
|
||||||
measure_name="vanilla_accuracy",
|
|
||||||
dataset_name=None,
|
|
||||||
basedir=None,
|
basedir=None,
|
||||||
) -> go.Figure:
|
) -> go.Figure:
|
||||||
fig = go.Figure()
|
fig = go.Figure()
|
||||||
|
|
@ -111,8 +116,8 @@ def plot_diagonal(
|
||||||
|
|
||||||
_update_layout(
|
_update_layout(
|
||||||
fig,
|
fig,
|
||||||
x_label=f"True {measure_name}",
|
x_label=f"True {acc_name}",
|
||||||
y_label=f"Estimated {measure_name}",
|
y_label=f"Estimated {acc_name}",
|
||||||
autosize=False,
|
autosize=False,
|
||||||
width=1300,
|
width=1300,
|
||||||
height=1000,
|
height=1000,
|
||||||
|
|
@ -120,18 +125,19 @@ def plot_diagonal(
|
||||||
yaxis_scaleratio=1.0,
|
yaxis_scaleratio=1.0,
|
||||||
yaxis_range=[-0.1, 1.1],
|
yaxis_range=[-0.1, 1.1],
|
||||||
)
|
)
|
||||||
return _save_or_return(fig, basedir, dataset_name, measure_name, "diagonal")
|
return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal")
|
||||||
|
|
||||||
|
|
||||||
def plot_delta(
|
def plot_delta(
|
||||||
method_names: list[str],
|
method_names: list[str],
|
||||||
prevs: np.ndarray,
|
prevs: np.ndarray,
|
||||||
acc_errs: np.ndarray,
|
acc_errs: np.ndarray,
|
||||||
|
cls_name,
|
||||||
|
acc_mame,
|
||||||
|
dataset_name,
|
||||||
|
prev_name,
|
||||||
*,
|
*,
|
||||||
stdevs: np.ndarray | None = None,
|
stdevs: np.ndarray | None = None,
|
||||||
prev_name="Test",
|
|
||||||
measure_name="Vanilla Accuracy",
|
|
||||||
dataset_name=None,
|
|
||||||
basedir=None,
|
basedir=None,
|
||||||
) -> go.Figure:
|
) -> go.Figure:
|
||||||
fig = go.Figure()
|
fig = go.Figure()
|
||||||
|
|
@ -170,10 +176,15 @@ def plot_delta(
|
||||||
_update_layout(
|
_update_layout(
|
||||||
fig,
|
fig,
|
||||||
x_label=f"{prev_name} Prevalence",
|
x_label=f"{prev_name} Prevalence",
|
||||||
y_label=f"Prediction Error for {measure_name}",
|
y_label=f"Prediction Error for {acc_mame}",
|
||||||
)
|
)
|
||||||
return _save_or_return(
|
return _save_or_return(
|
||||||
fig, basedir, dataset_name, measure_name, "delta" if stdevs is None else "stdev"
|
fig,
|
||||||
|
basedir,
|
||||||
|
cls_name,
|
||||||
|
acc_mame,
|
||||||
|
dataset_name,
|
||||||
|
"delta" if stdevs is None else "stdev",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -181,10 +192,11 @@ def plot_shift(
|
||||||
method_names: list[str],
|
method_names: list[str],
|
||||||
prevs: np.ndarray,
|
prevs: np.ndarray,
|
||||||
acc_errs: np.ndarray,
|
acc_errs: np.ndarray,
|
||||||
|
cls_name,
|
||||||
|
acc_name,
|
||||||
|
dataset_name,
|
||||||
*,
|
*,
|
||||||
counts: np.ndarray | None = None,
|
counts: np.ndarray | None = None,
|
||||||
measure_name="Vanilla Accuracy",
|
|
||||||
dataset_name=None,
|
|
||||||
basedir=None,
|
basedir=None,
|
||||||
) -> go.Figure:
|
) -> go.Figure:
|
||||||
fig = go.Figure()
|
fig = go.Figure()
|
||||||
|
|
@ -212,6 +224,6 @@ def plot_shift(
|
||||||
_update_layout(
|
_update_layout(
|
||||||
fig,
|
fig,
|
||||||
x_label="Amount of Prior Probability Shift",
|
x_label="Amount of Prior Probability Shift",
|
||||||
y_label=f"Prediction Error for {measure_name}",
|
y_label=f"Prediction Error for {acc_name}",
|
||||||
)
|
)
|
||||||
return _save_or_return(fig, basedir, dataset_name, measure_name, "shift")
|
return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "shift")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue