From 6fbe825399a7a21eae5aa14b07c6e238436376bc Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Sun, 26 Nov 2023 16:31:40 +0100 Subject: [PATCH] times refactored --- quacc/evaluation/report.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index 717d8e6..d7ebbc4 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -75,6 +75,7 @@ class CompReport: train_prev: np.ndarray = None, valid_prev: np.ndarray = None, times=None, + g_time=None, ): if isinstance(datas, pd.DataFrame): self._data: pd.DataFrame = datas @@ -90,9 +91,14 @@ class CompReport: .sort_index(axis=0, level=0) ) + if times is None: + self.times = {er.name: er.time for er in datas} + else: + self.times = times + + self.times["tot"] = g_time self.train_prev = train_prev self.valid_prev = valid_prev - self.times = times @property def prevs(self) -> np.ndarray: @@ -130,9 +136,10 @@ class CompReport: df = CompReport( _join, self.name if hasattr(self, "name") else "default", - self.train_prev, - self.valid_prev, - self.times | other.times, + train_prev=self.train_prev, + valid_prev=self.valid_prev, + times=self.times | other.times, + g_time=self.times["tot"] + other.times["tot"], ) return df