bugs fixed

This commit is contained in:
Lorenzo Volpi 2023-10-31 14:53:31 +01:00
parent c452c555c0
commit d635c83161
1 changed files with 8 additions and 8 deletions

View File

@ -39,14 +39,13 @@ class Dataset:
self._target = target
self.prevs = None
self.n_prevs = n_prevalences
if prevs is not None:
prevs = np.unique([p for p in prevs if p > 0.0 and p < 1.0])
if prevs.shape[0] > 0:
self.prevs = np.sort(prevs)
self.n_prevs = self.prevs.shape[0]
self.n_prevs = n_prevalences
def __spambase(self):
return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
@ -88,7 +87,7 @@ class Dataset:
return DatasetSample(train, val, test)
def get(self) -> List[DatasetSample]:
all_train, test = {
(all_train, test) = {
"spambase": self.__spambase,
"imdb": self.__imdb,
"rcv1": self.__rcv1,
@ -108,7 +107,7 @@ class Dataset:
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs)
datasets = []
for p in prevs:
for p in 1.0 - prevs:
all_train_sampled = all_train.sampling(at_size, p, random_state=0)
train, validation = all_train_sampled.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=0
@ -122,10 +121,11 @@ class Dataset:
@property
def name(self):
if self._name == "rcv1":
return f"{self._name}_{self._target}"
else:
return self._name
return (
f"{self._name}_{self._target}_{self.n_prevs}prevs"
if self._name == "rcv1"
else f"{self._name}_{self.n_prevs}prevs"
)
# >>> fetch_rcv1().target_names