plotly methods fixed, plot saving implemented

This commit is contained in:
Lorenzo Volpi 2024-04-05 15:57:05 +02:00
parent 43056e76a8
commit a3ffd689b1
1 changed files with 32 additions and 20 deletions

View File

@ -1,9 +1,11 @@
import os from pathlib import Path
import numpy as np import numpy as np
import plotly import plotly
import plotly.graph_objects as go import plotly.graph_objects as go
from quacc.utils.commons import get_plots_path
MODE = "lines" MODE = "lines"
L_WIDTH = 5 L_WIDTH = 5
LEGEND = { LEGEND = {
@ -16,12 +18,14 @@ FONT = {"size": 24}
TEMPLATE = "ggplot2" TEMPLATE = "ggplot2"
def _save_or_return(fig, basedir, dataset_name, measure_name, plot_type): def _save_or_return(
if basedir is not None: fig: go.Figure, basedir, cls_name, acc_name, dataset_name, plot_type
plotsubdir = dataset_name ) -> go.Figure | None:
os.path.join(basedir, "plots", measure_name, plotsubdir, plot_type + ".svg") if basedir is None:
return fig
return fig path = get_plots_path(basedir, cls_name, acc_name, dataset_name, plot_type)
fig.write_image(path)
def _update_layout(fig, title, x_label, y_label, **kwargs): def _update_layout(fig, title, x_label, y_label, **kwargs):
@ -72,9 +76,10 @@ def plot_diagonal(
method_names, method_names,
true_accs, true_accs,
estim_accs, estim_accs,
cls_name,
acc_name,
dataset_name,
*, *,
measure_name="vanilla_accuracy",
dataset_name=None,
basedir=None, basedir=None,
) -> go.Figure: ) -> go.Figure:
fig = go.Figure() fig = go.Figure()
@ -111,8 +116,8 @@ def plot_diagonal(
_update_layout( _update_layout(
fig, fig,
x_label=f"True {measure_name}", x_label=f"True {acc_name}",
y_label=f"Estimated {measure_name}", y_label=f"Estimated {acc_name}",
autosize=False, autosize=False,
width=1300, width=1300,
height=1000, height=1000,
@ -120,18 +125,19 @@ def plot_diagonal(
yaxis_scaleratio=1.0, yaxis_scaleratio=1.0,
yaxis_range=[-0.1, 1.1], yaxis_range=[-0.1, 1.1],
) )
return _save_or_return(fig, basedir, dataset_name, measure_name, "diagonal") return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal")
def plot_delta( def plot_delta(
method_names: list[str], method_names: list[str],
prevs: np.ndarray, prevs: np.ndarray,
acc_errs: np.ndarray, acc_errs: np.ndarray,
cls_name,
acc_mame,
dataset_name,
prev_name,
*, *,
stdevs: np.ndarray | None = None, stdevs: np.ndarray | None = None,
prev_name="Test",
measure_name="Vanilla Accuracy",
dataset_name=None,
basedir=None, basedir=None,
) -> go.Figure: ) -> go.Figure:
fig = go.Figure() fig = go.Figure()
@ -170,10 +176,15 @@ def plot_delta(
_update_layout( _update_layout(
fig, fig,
x_label=f"{prev_name} Prevalence", x_label=f"{prev_name} Prevalence",
y_label=f"Prediction Error for {measure_name}", y_label=f"Prediction Error for {acc_mame}",
) )
return _save_or_return( return _save_or_return(
fig, basedir, dataset_name, measure_name, "delta" if stdevs is None else "stdev" fig,
basedir,
cls_name,
acc_mame,
dataset_name,
"delta" if stdevs is None else "stdev",
) )
@ -181,10 +192,11 @@ def plot_shift(
method_names: list[str], method_names: list[str],
prevs: np.ndarray, prevs: np.ndarray,
acc_errs: np.ndarray, acc_errs: np.ndarray,
cls_name,
acc_name,
dataset_name,
*, *,
counts: np.ndarray | None = None, counts: np.ndarray | None = None,
measure_name="Vanilla Accuracy",
dataset_name=None,
basedir=None, basedir=None,
) -> go.Figure: ) -> go.Figure:
fig = go.Figure() fig = go.Figure()
@ -212,6 +224,6 @@ def plot_shift(
_update_layout( _update_layout(
fig, fig,
x_label="Amount of Prior Probability Shift", x_label="Amount of Prior Probability Shift",
y_label=f"Prediction Error for {measure_name}", y_label=f"Prediction Error for {acc_name}",
) )
return _save_or_return(fig, basedir, dataset_name, measure_name, "shift") return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "shift")