refactored random_state

This commit is contained in:
Lorenzo Volpi 2023-11-26 16:33:11 +01:00
parent 4dbabacb0d
commit e05dfd4a16
1 changed file with 9 additions and 6 deletions

View File

@ -2,7 +2,7 @@ import math
import os
import pickle
import tarfile
from typing import List
from typing import List, Tuple
import numpy as np
import quapy as qp
@ -11,6 +11,7 @@ from sklearn.conftest import fetch_rcv1
from sklearn.utils import Bunch
from quacc import utils
from quacc.environment import env
TRAIN_VAL_PROP = 0.5
@ -190,7 +191,7 @@ class Dataset:
return all_train, test
def __train_test(self):
def __train_test(self) -> Tuple[LabelledCollection, LabelledCollection]:
all_train, test = {
"spambase": self.__spambase,
"imdb": self.__imdb,
@ -205,7 +206,7 @@ class Dataset:
all_train, test = self.__train_test()
train, val = all_train.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=0
train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED
)
return DatasetSample(train, val, test)
@ -216,7 +217,9 @@ class Dataset:
# resample all_train set to have (0.5, 0.5) prevalence
at_positives = np.sum(all_train.y)
all_train = all_train.sampling(
min(at_positives, len(all_train) - at_positives) * 2, 0.5, random_state=0
min(at_positives, len(all_train) - at_positives) * 2,
0.5,
random_state=env._R_SEED,
)
# sample prevalences
@ -228,9 +231,9 @@ class Dataset:
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs)
datasets = []
for p in 1.0 - prevs:
all_train_sampled = all_train.sampling(at_size, p, random_state=0)
all_train_sampled = all_train.sampling(at_size, p, random_state=env._R_SEED)
train, validation = all_train_sampled.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=0
train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED
)
datasets.append(DatasetSample(train, validation, test))