bugs fixed
This commit is contained in:
parent
c452c555c0
commit
d635c83161
|
@ -39,14 +39,13 @@ class Dataset:
|
||||||
self._target = target
|
self._target = target
|
||||||
|
|
||||||
self.prevs = None
|
self.prevs = None
|
||||||
|
self.n_prevs = n_prevalences
|
||||||
if prevs is not None:
|
if prevs is not None:
|
||||||
prevs = np.unique([p for p in prevs if p > 0.0 and p < 1.0])
|
prevs = np.unique([p for p in prevs if p > 0.0 and p < 1.0])
|
||||||
if prevs.shape[0] > 0:
|
if prevs.shape[0] > 0:
|
||||||
self.prevs = np.sort(prevs)
|
self.prevs = np.sort(prevs)
|
||||||
self.n_prevs = self.prevs.shape[0]
|
self.n_prevs = self.prevs.shape[0]
|
||||||
|
|
||||||
self.n_prevs = n_prevalences
|
|
||||||
|
|
||||||
def __spambase(self):
|
def __spambase(self):
|
||||||
return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
|
return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
|
||||||
|
|
||||||
|
@ -88,7 +87,7 @@ class Dataset:
|
||||||
return DatasetSample(train, val, test)
|
return DatasetSample(train, val, test)
|
||||||
|
|
||||||
def get(self) -> List[DatasetSample]:
|
def get(self) -> List[DatasetSample]:
|
||||||
all_train, test = {
|
(all_train, test) = {
|
||||||
"spambase": self.__spambase,
|
"spambase": self.__spambase,
|
||||||
"imdb": self.__imdb,
|
"imdb": self.__imdb,
|
||||||
"rcv1": self.__rcv1,
|
"rcv1": self.__rcv1,
|
||||||
|
@ -108,7 +107,7 @@ class Dataset:
|
||||||
|
|
||||||
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs)
|
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs)
|
||||||
datasets = []
|
datasets = []
|
||||||
for p in prevs:
|
for p in 1.0 - prevs:
|
||||||
all_train_sampled = all_train.sampling(at_size, p, random_state=0)
|
all_train_sampled = all_train.sampling(at_size, p, random_state=0)
|
||||||
train, validation = all_train_sampled.split_stratified(
|
train, validation = all_train_sampled.split_stratified(
|
||||||
train_prop=TRAIN_VAL_PROP, random_state=0
|
train_prop=TRAIN_VAL_PROP, random_state=0
|
||||||
|
@ -122,10 +121,11 @@ class Dataset:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
if self._name == "rcv1":
|
return (
|
||||||
return f"{self._name}_{self._target}"
|
f"{self._name}_{self._target}_{self.n_prevs}prevs"
|
||||||
else:
|
if self._name == "rcv1"
|
||||||
return self._name
|
else f"{self._name}_{self.n_prevs}prevs"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# >>> fetch_rcv1().target_names
|
# >>> fetch_rcv1().target_names
|
||||||
|
|
Loading…
Reference in New Issue