Delta plot added
|
@ -12,4 +12,3 @@ elsahar19_rca/__pycache__/*
|
||||||
*.coverage
|
*.coverage
|
||||||
.coverage
|
.coverage
|
||||||
scp_sync.py
|
scp_sync.py
|
||||||
out/*
|
|
After Width: | Height: | Size: 188 KiB |
After Width: | Height: | Size: 198 KiB |
After Width: | Height: | Size: 225 KiB |
After Width: | Height: | Size: 244 KiB |
After Width: | Height: | Size: 266 KiB |
After Width: | Height: | Size: 231 KiB |
After Width: | Height: | Size: 200 KiB |
After Width: | Height: | Size: 192 KiB |
After Width: | Height: | Size: 175 KiB |
|
@ -69,6 +69,12 @@ class EvaluationReport:
|
||||||
columns=g_dict.keys(),
|
columns=g_dict.keys(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_plot(self, mode="delta", metric="acc"):
|
||||||
|
g_prevs, g_dict = self.groupby_prevs(metric=metric)
|
||||||
|
t_prev = int(round(self.train_prevs["train"][0] * 100))
|
||||||
|
title = f"{self.name}_{t_prev}_{metric}"
|
||||||
|
plot.plot_delta(g_prevs, g_dict, metric, title)
|
||||||
|
|
||||||
def to_md(self, *metrics):
|
def to_md(self, *metrics):
|
||||||
res = ""
|
res = ""
|
||||||
for k, v in self.train_prevs.items():
|
for k, v in self.train_prevs.items():
|
||||||
|
@ -78,6 +84,7 @@ class EvaluationReport:
|
||||||
res += "\n"
|
res += "\n"
|
||||||
for m in metrics:
|
for m in metrics:
|
||||||
res += self.get_dataframe(metric=m).to_html() + "\n\n"
|
res += self.get_dataframe(metric=m).to_html() + "\n\n"
|
||||||
|
self.get_plot(metric=m)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from quacc.environ import env
|
||||||
|
|
||||||
|
|
||||||
|
def plot_delta(base_prevs, dict_vals, metric, title):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
base_prevs = [f for f, p in base_prevs]
|
||||||
|
for method, deltas in dict_vals.items():
|
||||||
|
ax.plot(
|
||||||
|
base_prevs,
|
||||||
|
deltas,
|
||||||
|
label=method,
|
||||||
|
linestyle="-",
|
||||||
|
marker="o",
|
||||||
|
markersize=3,
|
||||||
|
zorder=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set(xlabel="test prevalence", ylabel=metric, title=title)
|
||||||
|
# ax.set_ylim(0, 1)
|
||||||
|
# ax.set_xlim(0, 1)
|
||||||
|
ax.legend()
|
||||||
|
output_path = env.PLOT_OUT_DIR / f"{title}.png"
|
||||||
|
plt.savefig(output_path)
|