Dataset.reduce() allows to fix the random_state to have reproducible unit tests. This is required to ensure that the expected hyper-parameters are always chosen, independent of randomness

This commit is contained in:
Mirko Bunse 2024-04-17 14:46:37 +02:00
parent 72b43bd2f8
commit a64620c377
2 changed files with 14 additions and 6 deletions

View File

@ -549,7 +549,7 @@ class Dataset:
yield Dataset(train, test, name=f'fold {(i % nfolds) + 1}/{nfolds} (round={(i // nfolds) + 1})') yield Dataset(train, test, name=f'fold {(i % nfolds) + 1}/{nfolds} (round={(i // nfolds) + 1})')
def reduce(self, n_train=100, n_test=100): def reduce(self, n_train=100, n_test=100, random_state=None):
""" """
Reduce the number of instances in place for quick experiments. Preserves the prevalence of each set. Reduce the number of instances in place for quick experiments. Preserves the prevalence of each set.
@ -557,6 +557,14 @@ class Dataset:
:param n_test: number of test documents to keep (default 100) :param n_test: number of test documents to keep (default 100)
:return: self :return: self
""" """
self.training = self.training.sampling(n_train, *self.training.prevalence()) self.training = self.training.sampling(
self.test = self.test.sampling(n_test, *self.test.prevalence()) n_train,
*self.training.prevalence(),
random_state = random_state
)
self.test = self.test.sampling(
n_test,
*self.test.prevalence(),
random_state = random_state
)
return self return self

View File

@ -19,7 +19,7 @@ class ModselTestCase(unittest.TestCase):
q = PACC(LogisticRegression(random_state=1, max_iter=5000)) q = PACC(LogisticRegression(random_state=1, max_iter=5000))
data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=10).reduce() data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=10).reduce(random_state=1)
training, validation = data.training.split_stratified(0.7, random_state=1) training, validation = data.training.split_stratified(0.7, random_state=1)
param_grid = {'classifier__C': [0.000001, 10.]} param_grid = {'classifier__C': [0.000001, 10.]}
@ -41,7 +41,7 @@ class ModselTestCase(unittest.TestCase):
q = PACC(LogisticRegression(random_state=1, max_iter=5000)) q = PACC(LogisticRegression(random_state=1, max_iter=5000))
data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=10).reduce(n_train=500) data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=10).reduce(n_train=500, random_state=1)
training, validation = data.training.split_stratified(0.7, random_state=1) training, validation = data.training.split_stratified(0.7, random_state=1)
param_grid = {'classifier__C': np.logspace(-3,3,7)} param_grid = {'classifier__C': np.logspace(-3,3,7)}
@ -79,7 +79,7 @@ class ModselTestCase(unittest.TestCase):
q = PACC(SlowLR()) q = PACC(SlowLR())
data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=10).reduce() data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=10).reduce(random_state=1)
training, validation = data.training.split_stratified(0.7, random_state=1) training, validation = data.training.split_stratified(0.7, random_state=1)
param_grid = {'classifier__C': np.logspace(-1,1,3)} param_grid = {'classifier__C': np.logspace(-1,1,3)}