bugs fixed
This commit is contained in:
parent
c452c555c0
commit
d635c83161
|
@ -39,14 +39,13 @@ class Dataset:
|
|||
self._target = target
|
||||
|
||||
self.prevs = None
|
||||
self.n_prevs = n_prevalences
|
||||
if prevs is not None:
|
||||
prevs = np.unique([p for p in prevs if p > 0.0 and p < 1.0])
|
||||
if prevs.shape[0] > 0:
|
||||
self.prevs = np.sort(prevs)
|
||||
self.n_prevs = self.prevs.shape[0]
|
||||
|
||||
self.n_prevs = n_prevalences
|
||||
|
||||
def __spambase(self):
|
||||
return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
|
||||
|
||||
|
@ -88,7 +87,7 @@ class Dataset:
|
|||
return DatasetSample(train, val, test)
|
||||
|
||||
def get(self) -> List[DatasetSample]:
|
||||
all_train, test = {
|
||||
(all_train, test) = {
|
||||
"spambase": self.__spambase,
|
||||
"imdb": self.__imdb,
|
||||
"rcv1": self.__rcv1,
|
||||
|
@ -108,7 +107,7 @@ class Dataset:
|
|||
|
||||
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs)
|
||||
datasets = []
|
||||
for p in prevs:
|
||||
for p in 1.0 - prevs:
|
||||
all_train_sampled = all_train.sampling(at_size, p, random_state=0)
|
||||
train, validation = all_train_sampled.split_stratified(
|
||||
train_prop=TRAIN_VAL_PROP, random_state=0
|
||||
|
@ -122,10 +121,11 @@ class Dataset:
|
|||
|
||||
@property
|
||||
def name(self):
|
||||
if self._name == "rcv1":
|
||||
return f"{self._name}_{self._target}"
|
||||
else:
|
||||
return self._name
|
||||
return (
|
||||
f"{self._name}_{self._target}_{self.n_prevs}prevs"
|
||||
if self._name == "rcv1"
|
||||
else f"{self._name}_{self.n_prevs}prevs"
|
||||
)
|
||||
|
||||
|
||||
# >>> fetch_rcv1().target_names
|
||||
|
|
Loading…
Reference in New Issue