2022-05-25 19:14:33 +02:00
|
|
|
import unittest
|
|
|
|
import numpy as np
|
2022-06-01 18:28:59 +02:00
|
|
|
from quapy.data import LabelledCollection
|
2023-02-14 11:14:38 +01:00
|
|
|
from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol
|
2022-05-25 19:14:33 +02:00
|
|
|
|
|
|
|
|
|
|
|
def mock_labelled_collection(prefix=''):
    """Build a small, balanced 4-class LabelledCollection used as protocol test data.

    Labels 0, 1, 2 and 3 each occupy 250 consecutive positions (1000 items in total).
    The item at position ``i`` with label ``y`` is the string ``prefix + f'{i}-{y}'``,
    so every instance is unique and encodes its own label.

    :param prefix: string prepended to every instance; lets tests build distinguishable
        collections (e.g. two different domains) with the same label layout.
    :return: a ``LabelledCollection`` whose ``classes`` are the sorted unique labels.
    """
    labels = [label for label in range(4) for _ in range(250)]
    instances = [prefix + f'{idx}-{label}' for idx, label in enumerate(labels)]
    class_names = sorted(np.unique(labels))
    return LabelledCollection(instances, labels, classes=class_names)
|
2022-05-25 19:14:33 +02:00
|
|
|
|
|
|
|
|
|
|
|
def samples_to_str(protocol):
    """Serialise every sample produced by ``protocol`` into a single string.

    ``protocol`` is called with no arguments and must return an iterable of
    ``(instances, prev)`` pairs. Each pair becomes one line ``'{instances}\\t{prev}\\n'``.
    This makes two runs of a protocol comparable with a plain string equality check.

    :param protocol: zero-argument callable yielding ``(instances, prev)`` pairs.
    :return: the concatenation of one tab-separated line per sample ('' if no samples).
    """
    # str.join builds the result in a single pass; repeated `+=` on a str may copy
    # the accumulated text on every iteration.
    return ''.join(f'{instances}\t{prev}\n' for instances, prev in protocol())
|
|
|
|
|
|
|
|
|
|
|
|
class TestProtocols(unittest.TestCase):
    """Tests for the sampling protocols exported by ``quapy.protocol``.

    ``*_replicate`` tests check that a protocol built with a fixed ``random_state``
    (explicit, or left at its default) yields exactly the same samples each time it
    is called. ``*_not_replicate`` tests check that ``random_state=None`` yields
    different samples on successive calls, and that two different seeds yield
    different samples. Samples are compared through their ``samples_to_str`` form.
    """

    # -- helpers ---------------------------------------------------------------

    def _assert_replicable(self, protocol):
        """Assert that two successive passes over ``protocol`` produce identical samples."""
        samples1 = samples_to_str(protocol)
        samples2 = samples_to_str(protocol)
        self.assertEqual(samples1, samples2)

    def _assert_not_replicable(self, protocol):
        """Assert that two successive passes over ``protocol`` produce different samples."""
        samples1 = samples_to_str(protocol)
        samples2 = samples_to_str(protocol)
        self.assertNotEqual(samples1, samples2)

    def _assert_seeds_differ(self, make_protocol, seed1=42, seed2=0):
        """Assert that protocols built with two different seeds produce different samples.

        :param make_protocol: callable mapping a ``random_state`` value to a protocol.
        :param seed1: first seed to try.
        :param seed2: second seed to try; must differ from ``seed1``.
        """
        samples1 = samples_to_str(make_protocol(seed1))
        samples2 = samples_to_str(make_protocol(seed2))
        self.assertNotEqual(samples1, samples2)

    # -- APP -------------------------------------------------------------------

    def test_app_replicate(self):
        data = mock_labelled_collection()

        p = APP(data, sample_size=5, n_prevalences=11, random_state=42)
        self._assert_replicable(p)

        p = APP(data, sample_size=5, n_prevalences=11)  # <- random_state is by default set to 0
        self._assert_replicable(p)

    def test_app_not_replicate(self):
        data = mock_labelled_collection()

        p = APP(data, sample_size=5, n_prevalences=11, random_state=None)
        self._assert_not_replicable(p)

        self._assert_seeds_differ(
            lambda seed: APP(data, sample_size=5, n_prevalences=11, random_state=seed)
        )

    def test_app_number(self):
        data = mock_labelled_collection()
        p = APP(data, sample_size=100, n_prevalences=10, repeats=1)

        # Surprisingly enough, for some values of n_prevalences this test fails even though
        # everything is correct. The problem is that in APP.prevalence_grid() a rounding
        # error sometimes accumulates and the sum surpasses 1.0 by a tiny amount
        # (0.0000000000002 or the like), so those tuples are mistakenly removed. np.isclose
        # and other workarounds were tried, but eventually some negative probability shows
        # up in the sampling function...
        count = sum(1 for _ in p())
        self.assertEqual(count, p.total())

    # -- NPP -------------------------------------------------------------------

    def test_npp_replicate(self):
        data = mock_labelled_collection()

        p = NPP(data, sample_size=5, repeats=5, random_state=42)
        self._assert_replicable(p)

        p = NPP(data, sample_size=5, repeats=5)  # <- random_state is by default set to 0
        self._assert_replicable(p)

    def test_npp_not_replicate(self):
        data = mock_labelled_collection()

        p = NPP(data, sample_size=5, repeats=5, random_state=None)
        self._assert_not_replicable(p)

        self._assert_seeds_differ(
            lambda seed: NPP(data, sample_size=5, repeats=5, random_state=seed)
        )

    # -- UPP (Kraemer sampling) ------------------------------------------------

    def test_kraemer_replicate(self):
        data = mock_labelled_collection()

        p = UPP(data, sample_size=5, repeats=10, random_state=42)
        self._assert_replicable(p)

        p = UPP(data, sample_size=5, repeats=10)  # <- random_state is by default set to 0
        self._assert_replicable(p)

    def test_kraemer_not_replicate(self):
        data = mock_labelled_collection()

        p = UPP(data, sample_size=5, repeats=10, random_state=None)
        self._assert_not_replicable(p)

        # same cross-seed check as for APP and NPP
        self._assert_seeds_differ(
            lambda seed: UPP(data, sample_size=5, repeats=10, random_state=seed)
        )

    # -- DomainMixer -----------------------------------------------------------

    def test_covariate_shift_replicate(self):
        dataA = mock_labelled_collection('domA')
        dataB = mock_labelled_collection('domB')

        p = DomainMixer(dataA, dataB, sample_size=10, mixture_points=11, random_state=1)
        self._assert_replicable(p)

        p = DomainMixer(dataA, dataB, sample_size=10, mixture_points=11)  # <- random_state is by default set to 0
        self._assert_replicable(p)

    def test_covariate_shift_not_replicate(self):
        dataA = mock_labelled_collection('domA')
        dataB = mock_labelled_collection('domB')

        p = DomainMixer(dataA, dataB, sample_size=10, mixture_points=11, random_state=None)
        self._assert_not_replicable(p)

        # same cross-seed check as for APP and NPP
        self._assert_seeds_differ(
            lambda seed: DomainMixer(dataA, dataB, sample_size=10, mixture_points=11, random_state=seed)
        )

    # -- misuse of AbstractStochasticSeededProtocol ----------------------------

    def test_no_seed_init(self):
        class NoSeedInit(AbstractStochasticSeededProtocol):
            # deliberately does NOT call super().__init__(...): that omission is what this test checks
            def __init__(self):
                self.data = mock_labelled_collection()

            def samples_parameters(self):
                # return a matrix containing sampling indexes in the rows
                return np.random.randint(0, len(self.data), 10*10).reshape(10, 10)

            def sample(self, params):
                index = np.unique(params)
                return self.data.sampling_from_index(index)

        p = NoSeedInit()

        # this should raise a ValueError, since the class is declared an AbstractStochasticSeededProtocol
        # but its random seed has never been passed to super(NoSeedInit, self).__init__(...)
        with self.assertRaises(ValueError):
            for _ in p():
                pass
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test cases defined in this module when it is executed directly as a script.
    unittest.main()
|