diff --git a/quacc/data.py b/quacc/data.py index 7606605..2e58936 100644 --- a/quacc/data.py +++ b/quacc/data.py @@ -45,9 +45,9 @@ class ExtendedData: pred_proba: np.ndarray, ext: np.ndarray = None, ) -> np.ndarray | sp.csr_matrix: - to_append = pred_proba - if ext is not None: - to_append = np.concatenate([ext, pred_proba], axis=1) + to_append = ext + if ext is None: + to_append = pred_proba if isinstance(instances, sp.csr_matrix): _to_append = sp.csr_matrix(to_append)