report adapred to qcpanel save feature

This commit is contained in:
Lorenzo Volpi 2023-11-16 01:34:24 +01:00
parent f89f92a758
commit d9bc268bb0
1 changed files with 115 additions and 47 deletions

View File

@ -6,7 +6,6 @@ import numpy as np
import pandas as pd
from quacc import plot
from quacc.environment import env
from quacc.utils import fmt_line_md
@ -145,6 +144,14 @@ class CompReport:
avg_p.loc["avg", :] = f_data.mean()
return avg_p
def shift_table(
self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame:
f_data = self.shift_data(metric=metric, estimators=estimators)
avg_p = f_data.groupby(level=0).mean()
avg_p.loc["avg", :] = f_data.mean()
return avg_p
def get_plots(
self,
mode="delta",
@ -152,6 +159,7 @@ class CompReport:
estimators=None,
conf="default",
return_fig=False,
base_path=None,
) -> List[Tuple[str, Path]]:
if mode == "delta":
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
@ -163,6 +171,7 @@ class CompReport:
name=conf,
train_prev=self.train_prev,
return_fig=return_fig,
base_path=base_path,
)
elif mode == "delta_stdev":
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
@ -176,6 +185,7 @@ class CompReport:
train_prev=self.train_prev,
stdevs=st_data.T.to_numpy(),
return_fig=return_fig,
base_path=base_path,
)
elif mode == "diagonal":
f_data = self.data(metric=metric + "_score", estimators=estimators)
@ -189,6 +199,7 @@ class CompReport:
name=conf,
train_prev=self.train_prev,
return_fig=return_fig,
base_path=base_path,
)
elif mode == "shift":
_shift_data = self.shift_data(metric=metric, estimators=estimators)
@ -207,30 +218,44 @@ class CompReport:
train_prev=self.train_prev,
counts=shift_counts.T.to_numpy(),
return_fig=return_fig,
base_path=base_path,
)
def to_md(self, conf="default", metric="acc", estimators=None, stdev=False) -> str:
def to_md(
self,
conf="default",
metric="acc",
estimators=None,
modes=["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"],
plot_path=None,
) -> str:
res = f"## {int(np.around(self.train_prev, decimals=2)[1]*100)}% positives\n"
res += fmt_line_md(f"train: {str(self.train_prev)}")
res += fmt_line_md(f"validation: {str(self.valid_prev)}")
for k, v in self.times.items():
res += fmt_line_md(f"{k}: {v:.3f}s")
res += "\n"
res += self.table(metric=metric, estimators=estimators).to_html() + "\n\n"
if "table" in modes:
res += "### table\n"
res += self.table(metric=metric, estimators=estimators).to_html() + "\n\n"
if "shift_table" in modes:
res += "### shift table\n"
res += (
self.shift_table(metric=metric, estimators=estimators).to_html()
+ "\n\n"
)
plot_modes = np.array(["delta", "diagonal", "shift"], dtype="object")
if stdev:
whd = np.where(plot_modes == "delta")[0]
if len(whd) > 0:
plot_modes = np.insert(plot_modes, whd + 1, "delta_stdev")
plot_modes = [m for m in modes if m not in ["table", "shift_table"]]
for mode in plot_modes:
res += f"### {mode}\n"
op = self.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
)
res += f"![plot_{mode}]({op.relative_to(env.OUT_DIR).as_posix()})\n"
res += f"![plot_{mode}]({op.relative_to(op.parents[1]).as_posix()})\n"
return res
@ -304,6 +329,7 @@ class DatasetReport:
estimators=None,
conf="default",
return_fig=False,
base_path=None,
):
if mode == "delta_train":
_data = self.data(metric, estimators) if data is None else data
@ -320,6 +346,7 @@ class DatasetReport:
train_prev=None,
avg="train",
return_fig=return_fig,
base_path=base_path,
)
elif mode == "stdev_train":
_data = self.data(metric, estimators) if data is None else data
@ -338,6 +365,7 @@ class DatasetReport:
stdevs=stdev_on_train.T.to_numpy(),
avg="train",
return_fig=return_fig,
base_path=base_path,
)
elif mode == "delta_test":
_data = self.data(metric, estimators) if data is None else data
@ -352,6 +380,7 @@ class DatasetReport:
train_prev=None,
avg="test",
return_fig=return_fig,
base_path=base_path,
)
elif mode == "stdev_test":
_data = self.data(metric, estimators) if data is None else data
@ -368,6 +397,7 @@ class DatasetReport:
stdevs=stdev_on_test.T.to_numpy(),
avg="test",
return_fig=return_fig,
base_path=base_path,
)
elif mode == "shift":
_shift_data = self.shift_data(metric, estimators) if data is None else data
@ -383,12 +413,37 @@ class DatasetReport:
train_prev=None,
counts=count_shift.T.to_numpy(),
return_fig=return_fig,
base_path=base_path,
)
def to_md(self, conf="default", metric="acc", estimators=[], stdev=False):
def to_md(
self,
conf="default",
metric="acc",
estimators=[],
dr_modes=[
"delta_train",
"stdev_train",
"delta_test",
"stdev_test",
"shift",
"train_table",
"test_table",
"shift_table",
],
cr_modes=[
"delta",
"delta_stdev",
"diagonal",
"shift",
"table",
"shift_table",
],
plot_path=None,
):
res = f"# {self.name}\n\n"
for cr in self.crs:
res += f"{cr.to_md(conf, metric=metric, estimators=estimators, stdev=stdev)}\n\n"
res += f"{cr.to_md(conf, metric=metric, estimators=estimators, modes=cr_modes, plot_path=plot_path)}\n\n"
_data = self.data(metric=metric, estimators=estimators)
_shift_data = self.shift_data(metric=metric, estimators=estimators)
@ -398,68 +453,81 @@ class DatasetReport:
######################## avg on train ########################
res += "### avg on train\n"
avg_on_train_tbl = _data.groupby(level=1).mean()
avg_on_train_tbl.loc["avg", :] = _data.mean()
if "train_table" in dr_modes:
avg_on_train_tbl = _data.groupby(level=1).mean()
avg_on_train_tbl.loc["avg", :] = _data.mean()
res += avg_on_train_tbl.to_html() + "\n\n"
res += avg_on_train_tbl.to_html() + "\n\n"
if "delta_train" in dr_modes:
delta_op = self.get_plots(
data=_data,
mode="delta_train",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
)
res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n"
delta_op = self.get_plots(
data=_data,
mode="delta_train",
metric=metric,
estimators=estimators,
conf=conf,
)
res += f"![plot_delta]({delta_op.relative_to(env.OUT_DIR).as_posix()})\n"
if stdev:
if "stdev_train" in dr_modes:
delta_stdev_op = self.get_plots(
data=_data,
mode="stdev_train",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
)
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(env.OUT_DIR).as_posix()})\n"
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n"
######################## avg on test ########################
res += "### avg on test\n"
avg_on_test_tbl = _data.groupby(level=0).mean()
avg_on_test_tbl.loc["avg", :] = _data.mean()
if "test_table" in dr_modes:
avg_on_test_tbl = _data.groupby(level=0).mean()
avg_on_test_tbl.loc["avg", :] = _data.mean()
res += avg_on_test_tbl.to_html() + "\n\n"
res += avg_on_test_tbl.to_html() + "\n\n"
if "delta_test" in dr_modes:
delta_op = self.get_plots(
data=_data,
mode="delta_test",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
)
res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n"
delta_op = self.get_plots(
data=_data,
mode="delta_test",
metric=metric,
estimators=estimators,
conf=conf,
)
res += f"![plot_delta]({delta_op.relative_to(env.OUT_DIR).as_posix()})\n"
if stdev:
if "stdev_test" in dr_modes:
delta_stdev_op = self.get_plots(
data=_data,
mode="stdev_test",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
)
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(env.OUT_DIR).as_posix()})\n"
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n"
######################## avg shift ########################
res += "### avg dataset shift\n"
shift_op = self.get_plots(
data=_shift_data,
mode="shift",
metric=metric,
estimators=estimators,
conf=conf,
)
res += f"![plot_shift]({shift_op.relative_to(env.OUT_DIR).as_posix()})\n"
if "shift_table" in dr_modes:
shift_on_train_tbl = _shift_data.groupby(level=0).mean()
shift_on_train_tbl.loc["avg", :] = _shift_data.mean()
res += shift_on_train_tbl.to_html() + "\n\n"
if "shift" in dr_modes:
shift_op = self.get_plots(
data=_shift_data,
mode="shift",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
)
res += f"![plot_shift]({shift_op.relative_to(shift_op.parents[1]).as_posix()})\n"
return res