From e05dfd4a16d54707fee6fc0d998259d6c9ef3d3b Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Sun, 26 Nov 2023 16:33:11 +0100 Subject: [PATCH] refatored random_state --- quacc/dataset.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/quacc/dataset.py b/quacc/dataset.py index ce97aec..35a20db 100644 --- a/quacc/dataset.py +++ b/quacc/dataset.py @@ -2,7 +2,7 @@ import math import os import pickle import tarfile -from typing import List +from typing import List, Tuple import numpy as np import quapy as qp @@ -11,6 +11,7 @@ from sklearn.conftest import fetch_rcv1 from sklearn.utils import Bunch from quacc import utils +from quacc.environment import env TRAIN_VAL_PROP = 0.5 @@ -190,7 +191,7 @@ class Dataset: return all_train, test - def __train_test(self): + def __train_test(self) -> Tuple[LabelledCollection, LabelledCollection]: all_train, test = { "spambase": self.__spambase, "imdb": self.__imdb, @@ -205,7 +206,7 @@ class Dataset: all_train, test = self.__train_test() train, val = all_train.split_stratified( - train_prop=TRAIN_VAL_PROP, random_state=0 + train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED ) return DatasetSample(train, val, test) @@ -216,7 +217,9 @@ class Dataset: # resample all_train set to have (0.5, 0.5) prevalence at_positives = np.sum(all_train.y) all_train = all_train.sampling( - min(at_positives, len(all_train) - at_positives) * 2, 0.5, random_state=0 + min(at_positives, len(all_train) - at_positives) * 2, + 0.5, + random_state=env._R_SEED, ) # sample prevalences @@ -228,9 +231,9 @@ class Dataset: at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs) datasets = [] for p in 1.0 - prevs: - all_train_sampled = all_train.sampling(at_size, p, random_state=0) + all_train_sampled = all_train.sampling(at_size, p, random_state=env._R_SEED) train, validation = all_train_sampled.split_stratified( - train_prop=TRAIN_VAL_PROP, random_state=0 + train_prop=TRAIN_VAL_PROP, random_state=env._R_SEED ) datasets.append(DatasetSample(train, validation, test))