# Forked from moreo/QuaPy (66 lines, 2.3 KiB, Python).
import unittest
|
|
import numpy as np
|
|
from scipy.sparse import csr_matrix
|
|
|
|
import quapy as qp
|
|
|
|
|
|
class LabelCollectionTestCase(unittest.TestCase):
    """Tests for splitting and joining ``qp.data.LabelledCollection`` objects."""

    def setUp(self):
        # A fixed seed makes the generated labels/instances reproducible, so a
        # failing run can be replayed exactly.
        self.rng = np.random.default_rng(0)

    def test_split(self):
        """A random split partitions the data and preserves the overall prevalence.

        The two parts must have the requested sizes, together contain every
        item, and their size-weighted prevalences must average to the
        prevalence of the whole collection.
        """
        x = np.arange(100)
        y = self.rng.integers(0, 5, 100)
        data = qp.data.LabelledCollection(x, y)

        tr, te = data.split_random(0.7)

        self.assertEqual(len(tr), 70)
        self.assertEqual(len(te), 30)

        # Weight each part by its actual share of the data rather than
        # hard-coding 0.7 / 0.3, so the check stays valid if the split changes.
        check_prev = (tr.prevalence() * len(tr) + te.prevalence() * len(te)) / len(data)
        self.assertTrue(np.allclose(check_prev, data.prevalence()))
        self.assertEqual(len(tr + te), len(data))

    def test_join(self):
        """Joining concatenates instances and unions the classes.

        Joining must fail when the instance arrays have incompatible shapes or
        when dense and sparse collections are mixed, and succeed once the
        instance representations agree.
        """
        # 1-D dense instances with overlapping label ranges; the joined
        # classes must be the union {0, ..., 5}.
        data1 = qp.data.LabelledCollection(np.arange(50), self.rng.integers(2, 5, 50))
        data2 = qp.data.LabelledCollection(np.arange(200), self.rng.integers(0, 3, 200))
        data3 = qp.data.LabelledCollection(np.arange(100), self.rng.integers(0, 6, 100))

        combined = qp.data.LabelledCollection.join(data1, data2, data3)
        self.assertEqual(len(combined), len(data1) + len(data2) + len(data3))
        # assert_array_equal also reports a shape mismatch cleanly, unlike
        # all(a == b), which can raise or mis-evaluate when shapes differ.
        np.testing.assert_array_equal(combined.classes_, np.arange(6))

        # Labels use high=2: the upper bound is exclusive, so high=1 would
        # always produce the single label 0.
        # 2-D instances cannot be joined with the 1-D collections above.
        data4 = qp.data.LabelledCollection(self.rng.random((10, 3)), self.rng.integers(0, 2, 10))
        # NOTE(review): the exact exception type comes from
        # LabelledCollection.join, which is not visible here, so this stays broad.
        with self.assertRaises(Exception):
            qp.data.LabelledCollection.join(data1, data2, data3, data4)

        # 2-D instances with the same number of columns can be joined...
        data5 = qp.data.LabelledCollection(self.rng.random((20, 3)), self.rng.integers(0, 2, 20))
        combined = qp.data.LabelledCollection.join(data4, data5)
        self.assertEqual(len(combined), len(data4) + len(data5))

        # ...but not with a collection that has a different number of columns.
        data6 = qp.data.LabelledCollection(self.rng.random((10, 4)), self.rng.integers(0, 2, 10))
        with self.assertRaises(Exception):
            qp.data.LabelledCollection.join(data4, data5, data6)

        # Mixing a sparse and a dense instance container is rejected...
        data4.instances = csr_matrix(data4.instances)
        with self.assertRaises(Exception):
            qp.data.LabelledCollection.join(data4, data5)

        # ...while two sparse collections can be joined.
        data5.instances = csr_matrix(data5.instances)
        combined = qp.data.LabelledCollection.join(data4, data5)
        self.assertEqual(len(combined), len(data4) + len(data5))
|
|
|
|
|
|
if __name__ == '__main__':
    # Only start the test runner when the file is executed directly, not
    # when it is imported (e.g. by a test discovery tool).
    unittest.main()
|