rcv1 updated
This commit is contained in:
parent
37392c6545
commit
2235fd35c0
|
@ -13,23 +13,32 @@ def get_imdb() -> Tuple[LabelledCollection]:
|
||||||
return train, validation, test
|
return train, validation, test
|
||||||
|
|
||||||
|
|
||||||
def get_spambase() -> Tuple[LabelledCollection, LabelledCollection, LabelledCollection]:
    """Load the UCI "spambase" dataset as train, validation and test collections.

    The dataset's own train/test partition is kept as is; the training part is
    then split in a stratified way, with ``TRAIN_VAL_PROP`` as the proportion
    passed to ``split_stratified``, to carve out a validation set.

    Returns:
        A ``(train, validation, test)`` tuple.
    """
    dataset = qp.datasets.fetch_UCIDataset("spambase", verbose=False)
    all_train, test = dataset.train_test
    train, validation = all_train.split_stratified(train_prop=TRAIN_VAL_PROP)
    return train, validation, test
|
||||||
|
|
||||||
|
|
||||||
def get_rcv1(sample_size=100):
    """Build binary train/validation/test splits for the RCV1 target categories.

    Each column of ``dataset.target`` is treated as an independent binary
    problem (classes ``0`` and ``1``). A category is kept only if the test
    portion contains at least ``sample_size`` positive documents.

    Args:
        sample_size: Minimum number of positive labels a category must have in
            the test portion (rows from ``n_train`` onwards) to be included.

    Returns:
        A dict mapping each retained target name to its
        ``(train, validation, test)`` tuple of ``LabelledCollection`` objects.
    """
    # Rows [0, n_train) are used for training and the rest for testing.
    # NOTE(review): 23149 matches the size of the official LYRL2004 training
    # split as documented for scikit-learn's fetch_rcv1 -- confirm that
    # fetch_rcv1 here is the scikit-learn loader.
    n_train = 23149
    dataset = fetch_rcv1()

    def dataset_split(
        data, labels, classes=None
    ) -> Tuple[LabelledCollection, LabelledCollection, LabelledCollection]:
        """Split one binary problem into train, validation and test sets."""
        # Avoid a mutable default argument shared across calls.
        if classes is None:
            classes = [0, 1]
        all_train_d, test_d = data[:n_train, :], data[n_train:, :]
        all_train_l, test_l = labels[:n_train], labels[n_train:]
        all_train = LabelledCollection(all_train_d, all_train_l, classes=classes)
        test = LabelledCollection(test_d, test_l, classes=classes)
        train, validation = all_train.split_stratified(train_prop=TRAIN_VAL_PROP)
        return train, validation, test

    # The previous filter(lambda _, labels: ...) raised TypeError, because
    # filter() passes each (target, labels) pair as ONE argument. Unpacking in
    # a plain loop fixes that and also avoids keeping every category's dense
    # label vector alive before filtering.
    splits = {}
    for ind, target in enumerate(dataset.target_names):
        labels = dataset.target[:, ind].toarray().flatten()
        if np.sum(labels[n_train:]) >= sample_size:
            splits[target] = dataset_split(dataset.data, labels, classes=[0, 1])
    return splits
|
||||||
|
|
Loading…
Reference in New Issue