diff --git a/quacc/dataset.py b/quacc/dataset.py index 3f6e179..c11b0c9 100644 --- a/quacc/dataset.py +++ b/quacc/dataset.py @@ -39,14 +39,13 @@ class Dataset: self._target = target self.prevs = None + self.n_prevs = n_prevalences if prevs is not None: prevs = np.unique([p for p in prevs if p > 0.0 and p < 1.0]) if prevs.shape[0] > 0: self.prevs = np.sort(prevs) self.n_prevs = self.prevs.shape[0] - self.n_prevs = n_prevalences - def __spambase(self): return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test @@ -88,7 +87,7 @@ class Dataset: return DatasetSample(train, val, test) def get(self) -> List[DatasetSample]: - all_train, test = { + (all_train, test) = { "spambase": self.__spambase, "imdb": self.__imdb, "rcv1": self.__rcv1, @@ -108,7 +107,7 @@ class Dataset: at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs) datasets = [] - for p in prevs: + for p in 1.0 - prevs: all_train_sampled = all_train.sampling(at_size, p, random_state=0) train, validation = all_train_sampled.split_stratified( train_prop=TRAIN_VAL_PROP, random_state=0 @@ -122,10 +121,11 @@ class Dataset: @property def name(self): - if self._name == "rcv1": - return f"{self._name}_{self._target}" - else: - return self._name + return ( + f"{self._name}_{self._target}_{self.n_prevs}prevs" + if self._name == "rcv1" + else f"{self._name}_{self.n_prevs}prevs" + ) # >>> fetch_rcv1().target_names