From ecd0ad7ec7c40db811c045aea0b7ee3c71e97594 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Mon, 11 Jul 2022 14:00:25 +0200 Subject: [PATCH] unit test for replicability based on qp.util.temp_seed --- quapy/method/aggregative.py | 1 + quapy/tests/test_replicability.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) create mode 100644 quapy/tests/test_replicability.py diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index 759a853..e40e96c 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -438,6 +438,7 @@ class PACC(AggregativeProbabilisticQuantifier): validation data, or as an integer, indicating that the misclassification rates should be estimated via `k`-fold cross validation (this integer stands for the number of folds `k`), or as a :class:`quapy.data.base.LabelledCollection` (the split itself). + :param n_jobs: number of parallel workers """ def __init__(self, learner: BaseEstimator, val_split=0.4, n_jobs=None): diff --git a/quapy/tests/test_replicability.py b/quapy/tests/test_replicability.py new file mode 100644 index 0000000..329ac32 --- /dev/null +++ b/quapy/tests/test_replicability.py @@ -0,0 +1,30 @@ +import unittest +import quapy as qp +from quapy.functional import strprev +from sklearn.linear_model import LogisticRegression + +from method.aggregative import PACC + + +class MyTestCase(unittest.TestCase): + def test_replicability(self): + + dataset = qp.datasets.fetch_UCIDataset('yeast') + + with qp.util.temp_seed(0): + lr = LogisticRegression(random_state=0, max_iter=10000) + pacc = PACC(lr) + prev = pacc.fit(dataset.training).quantify(dataset.test.X) + str_prev1 = strprev(prev, prec=5) + + with qp.util.temp_seed(0): + lr = LogisticRegression(random_state=0, max_iter=10000) + pacc = PACC(lr) + prev2 = pacc.fit(dataset.training).quantify(dataset.test.X) + str_prev2 = strprev(prev2, prec=5) + + self.assertEqual(str_prev1, str_prev2) # add assertion here + + +if __name__ == '__main__': + unittest.main()