1
0
Fork 0

unit test for replicability based on qp.util.temp_seed

This commit is contained in:
Alejandro Moreo Fernandez 2022-07-11 14:00:25 +02:00
parent 1742b75504
commit ecd0ad7ec7
2 changed files with 31 additions and 0 deletions

View File

@ -438,6 +438,7 @@ class PACC(AggregativeProbabilisticQuantifier):
validation data, or as an integer, indicating that the misclassification rates should be estimated via validation data, or as an integer, indicating that the misclassification rates should be estimated via
`k`-fold cross validation (this integer stands for the number of folds `k`), or as a `k`-fold cross validation (this integer stands for the number of folds `k`), or as a
:class:`quapy.data.base.LabelledCollection` (the split itself). :class:`quapy.data.base.LabelledCollection` (the split itself).
:param n_jobs: number of parallel workers
""" """
def __init__(self, learner: BaseEstimator, val_split=0.4, n_jobs=None): def __init__(self, learner: BaseEstimator, val_split=0.4, n_jobs=None):

View File

@ -0,0 +1,30 @@
import unittest
import quapy as qp
from quapy.functional import strprev
from sklearn.linear_model import LogisticRegression
from method.aggregative import PACC
class MyTestCase(unittest.TestCase):
def test_replicability(self):
dataset = qp.datasets.fetch_UCIDataset('yeast')
with qp.util.temp_seed(0):
lr = LogisticRegression(random_state=0, max_iter=10000)
pacc = PACC(lr)
prev = pacc.fit(dataset.training).quantify(dataset.test.X)
str_prev1 = strprev(prev, prec=5)
with qp.util.temp_seed(0):
lr = LogisticRegression(random_state=0, max_iter=10000)
pacc = PACC(lr)
prev2 = pacc.fit(dataset.training).quantify(dataset.test.X)
str_prev2 = strprev(prev2, prec=5)
self.assertEqual(str_prev1, str_prev2) # add assertion here
if __name__ == '__main__':
unittest.main()