bugs fixed

This commit is contained in:
Lorenzo Volpi 2023-10-31 14:53:31 +01:00
parent c452c555c0
commit d635c83161
1 changed file with 8 additions and 8 deletions

View File

@ -39,14 +39,13 @@ class Dataset:
self._target = target self._target = target
self.prevs = None self.prevs = None
self.n_prevs = n_prevalences
if prevs is not None: if prevs is not None:
prevs = np.unique([p for p in prevs if p > 0.0 and p < 1.0]) prevs = np.unique([p for p in prevs if p > 0.0 and p < 1.0])
if prevs.shape[0] > 0: if prevs.shape[0] > 0:
self.prevs = np.sort(prevs) self.prevs = np.sort(prevs)
self.n_prevs = self.prevs.shape[0] self.n_prevs = self.prevs.shape[0]
self.n_prevs = n_prevalences
def __spambase(self): def __spambase(self):
return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
@ -88,7 +87,7 @@ class Dataset:
return DatasetSample(train, val, test) return DatasetSample(train, val, test)
def get(self) -> List[DatasetSample]: def get(self) -> List[DatasetSample]:
all_train, test = { (all_train, test) = {
"spambase": self.__spambase, "spambase": self.__spambase,
"imdb": self.__imdb, "imdb": self.__imdb,
"rcv1": self.__rcv1, "rcv1": self.__rcv1,
@ -108,7 +107,7 @@ class Dataset:
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs) at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs)
datasets = [] datasets = []
for p in prevs: for p in 1.0 - prevs:
all_train_sampled = all_train.sampling(at_size, p, random_state=0) all_train_sampled = all_train.sampling(at_size, p, random_state=0)
train, validation = all_train_sampled.split_stratified( train, validation = all_train_sampled.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=0 train_prop=TRAIN_VAL_PROP, random_state=0
@ -122,10 +121,11 @@ class Dataset:
@property @property
def name(self): def name(self):
if self._name == "rcv1": return (
return f"{self._name}_{self._target}" f"{self._name}_{self._target}_{self.n_prevs}prevs"
else: if self._name == "rcv1"
return self._name else f"{self._name}_{self.n_prevs}prevs"
)
# >>> fetch_rcv1().target_names # >>> fetch_rcv1().target_names