forked from moreo/QuaPy
unit test for replicability based on qp.util.temp_seed
This commit is contained in:
parent
1742b75504
commit
ecd0ad7ec7
|
@ -438,6 +438,7 @@ class PACC(AggregativeProbabilisticQuantifier):
|
|||
validation data, or as an integer, indicating that the misclassification rates should be estimated via
|
||||
`k`-fold cross validation (this integer stands for the number of folds `k`), or as a
|
||||
:class:`quapy.data.base.LabelledCollection` (the split itself).
|
||||
:param n_jobs: number of parallel workers
|
||||
"""
|
||||
|
||||
def __init__(self, learner: BaseEstimator, val_split=0.4, n_jobs=None):
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
import unittest
|
||||
import quapy as qp
|
||||
from quapy.functional import strprev
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
from method.aggregative import PACC
|
||||
|
||||
|
||||
class MyTestCase(unittest.TestCase):
|
||||
def test_replicability(self):
|
||||
|
||||
dataset = qp.datasets.fetch_UCIDataset('yeast')
|
||||
|
||||
with qp.util.temp_seed(0):
|
||||
lr = LogisticRegression(random_state=0, max_iter=10000)
|
||||
pacc = PACC(lr)
|
||||
prev = pacc.fit(dataset.training).quantify(dataset.test.X)
|
||||
str_prev1 = strprev(prev, prec=5)
|
||||
|
||||
with qp.util.temp_seed(0):
|
||||
lr = LogisticRegression(random_state=0, max_iter=10000)
|
||||
pacc = PACC(lr)
|
||||
prev2 = pacc.fit(dataset.training).quantify(dataset.test.X)
|
||||
str_prev2 = strprev(prev2, prec=5)
|
||||
|
||||
self.assertEqual(str_prev1, str_prev2) # add assertion here
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue