diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index ba28d6d..531eb3b 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -54,6 +54,15 @@ class EvaluationReport: class CompReport: + _default_modes = [ + "delta", + "delta_stdev", + "diagonal", + "shift", + "table", + "shift_table", + ] + def __init__( self, reports: List[EvaluationReport], @@ -233,6 +242,8 @@ class CompReport: res += fmt_line_md(f"train: {str(self.train_prev)}") res += fmt_line_md(f"validation: {str(self.valid_prev)}") for k, v in self.times.items(): + if estimators is not None and k not in estimators: + continue res += fmt_line_md(f"{k}: {v:.3f}s") res += "\n" if "table" in modes: @@ -261,6 +272,18 @@ class CompReport: class DatasetReport: + _default_dr_modes = [ + "delta_train", + "stdev_train", + "delta_test", + "stdev_test", + "shift", + "train_table", + "test_table", + "shift_table", + ] + _default_cr_modes = CompReport._default_modes + def __init__(self, name, crs=None): self.name = name self.crs: List[CompReport] = [] if crs is None else crs @@ -421,28 +444,18 @@ class DatasetReport: conf="default", metric="acc", estimators=[], - dr_modes=[ - "delta_train", - "stdev_train", - "delta_test", - "stdev_test", - "shift", - "train_table", - "test_table", - "shift_table", - ], - cr_modes=[ - "delta", - "delta_stdev", - "diagonal", - "shift", - "table", - "shift_table", - ], + dr_modes=_default_dr_modes, + cr_modes=_default_cr_modes, + cr_prevs: List[str] = None, plot_path=None, ): res = f"# {self.name}\n\n" for cr in self.crs: + if ( + cr_prevs is not None + and str(round(cr.train_prev[1] * 100)) not in cr_prevs + ): + continue res += f"{cr.to_md(conf, metric=metric, estimators=estimators, modes=cr_modes, plot_path=plot_path)}\n\n" _data = self.data(metric=metric, estimators=estimators) @@ -535,6 +548,8 @@ class DatasetReport: with open(pickle_path, "wb") as f: pickle.dump(self, f) + return self + @classmethod def unpickle(cls, pickle_path: Path): with open(pickle_path, "rb") as f: