# Patch summary (text before the "diff --git" header is not part of the patch).
#
# Target: quacc/data.py, two hunks.
#
# Hunk 1 (inside class ExtendedData, per the hunk header):
#   * __split_index_by_pred: the argmax over self.pred_proba_ moves from
#     axis=0 to axis=1, and the list of index arrays is built once per entry
#     of self.pred_proba_.shape[1] instead of shape[0].
#     Presumably pred_proba_ is laid out as (n_samples, n_classes) -- TODO
#     confirm against the code that fills pred_proba_.
#   * split_by_pred: the two near-identical list comprehensions (one per
#     supported type of self.instances) are merged into a single
#     comprehension; the placeholder used for an empty index group now comes
#     from the nested helper _empty_matrix().
#     NOTE(review): _empty_matrix() has no fallback branch, so when
#     self.instances is neither np.ndarray nor sp.csr_matrix it returns None
#     for an empty group. The removed code left _instances unassigned in that
#     situation instead. Confirm which behaviour is wanted.
#
# Hunk 2 (inside class ExtendedCollection, per the hunk header):
#   * split_by_pred: _ncl is now self.pred_proba.shape[1] instead of
#     len(self.pred_proba).
#
# NOTE(review): this patch was recovered from a copy with collapsed line
# breaks. Indentation and blank context lines were restored so that each hunk
# matches the old/new line counts in its "@@" header; verify against the
# target file before applying.
diff --git a/quacc/data.py b/quacc/data.py
index a422636..7606605 100644
--- a/quacc/data.py
+++ b/quacc/data.py
@@ -64,27 +64,25 @@ class ExtendedData:
         return self.instances
 
     def __split_index_by_pred(self) -> List[np.ndarray]:
-        _pred_label = np.argmax(self.pred_proba_, axis=0)
+        _pred_label = np.argmax(self.pred_proba_, axis=1)
 
         return [
             (_pred_label == cl).nonzero()[0]
-            for cl in np.arange(self.pred_proba_.shape[0])
+            for cl in np.arange(self.pred_proba_.shape[1])
         ]
 
     def split_by_pred(self, return_indexes=False):
+        def _empty_matrix():
+            if isinstance(self.instances, np.ndarray):
+                return np.asarray([], dtype=int)
+            elif isinstance(self.instances, sp.csr_matrix):
+                return sp.csr_matrix(np.empty((0, 0), dtype=int))
+
         _indexes = self.__split_index_by_pred()
-        if isinstance(self.instances, np.ndarray):
-            _instances = [
-                self.instances[ind] if ind.shape[0] > 0 else np.asarray([], dtype=int)
-                for ind in _indexes
-            ]
-        elif isinstance(self.instances, sp.csr_matrix):
-            _instances = [
-                self.instances[ind]
-                if ind.shape[0] > 0
-                else sp.csr_matrix(np.empty((0, 0), dtype=int))
-                for ind in _indexes
-            ]
+        _instances = [
+            self.instances[ind] if ind.shape[0] > 0 else _empty_matrix()
+            for ind in _indexes
+        ]
 
         if return_indexes:
             return _instances, _indexes
@@ -182,7 +180,7 @@ class ExtendedCollection(LabelledCollection):
         return _counts
 
     def split_by_pred(self):
-        _ncl = len(self.pred_proba)
+        _ncl = self.pred_proba.shape[1]
         _instances, _indexes = self.e_data_.split_by_pred(return_indexes=True)
         _labels = [self.ey[ind] for ind in _indexes]
         return [