QuaPy/quapy/tests/test_labelcollection.py

66 lines
2.3 KiB
Python
Raw Normal View History

2023-02-13 19:27:48 +01:00
import unittest
import numpy as np
2023-02-14 17:00:50 +01:00
from scipy.sparse import csr_matrix
2023-02-13 19:27:48 +01:00
import quapy as qp
class LabelCollectionTestCase(unittest.TestCase):
    """Unit tests for ``qp.data.LabelledCollection`` splitting and joining."""

    def test_split(self):
        """A random 70/30 split preserves sizes and overall prevalence.

        The size-weighted average of the train and test prevalence vectors
        must match the prevalence of the original collection, and adding the
        two parts back together (``tr + te``) must recover the original size.
        """
        rng = np.random.default_rng(0)  # fixed seed so failures are reproducible
        x = np.arange(100)
        y = rng.integers(0, 5, 100)
        data = qp.data.LabelledCollection(x, y)
        tr, te = data.split_random(0.7)
        # Weights 0.7 / 0.3 are the train / test fractions of the 100 items.
        check_prev = tr.prevalence() * 0.7 + te.prevalence() * 0.3
        self.assertEqual(len(tr), 70)
        self.assertEqual(len(te), 30)
        # Same tolerances as np.allclose's defaults, but with a diagnostic
        # message on failure instead of a bare "False is not True".
        np.testing.assert_allclose(check_prev, data.prevalence(), rtol=1e-5, atol=1e-8)
        self.assertEqual(len(tr + te), len(data))

    def test_join(self):
        """``LabelledCollection.join`` concatenates compatible collections.

        Checks that joining 1-D collections with different label sets yields
        the union of their classes, that mixing instance dimensionalities,
        feature counts, or dense/sparse representations raises, and that
        joining compatible dense or sparse collections succeeds.
        """
        rng = np.random.default_rng(0)  # fixed seed so failures are reproducible

        # 1-D collections whose label sets only partially overlap:
        # {2, 3, 4}, {0, 1, 2} and {0, ..., 5}.
        x = np.arange(50)
        y = rng.integers(2, 5, 50)
        data1 = qp.data.LabelledCollection(x, y)
        x = np.arange(200)
        y = rng.integers(0, 3, 200)
        data2 = qp.data.LabelledCollection(x, y)
        x = np.arange(100)
        y = rng.integers(0, 6, 100)
        data3 = qp.data.LabelledCollection(x, y)
        combined = qp.data.LabelledCollection.join(data1, data2, data3)
        self.assertEqual(len(combined), len(data1) + len(data2) + len(data3))
        # assert_array_equal fails cleanly on a length mismatch; an elementwise
        # ``==`` between arrays of different shapes may raise instead.
        np.testing.assert_array_equal(combined.classes_, np.arange(6))

        # 2-D dense instances with binary labels. The upper bound of
        # ``integers`` is exclusive, so (0, 2) draws from {0, 1}.
        x = rng.random((10, 3))
        y = rng.integers(0, 2, 10)
        data4 = qp.data.LabelledCollection(x, y)
        # Joining 1-D instances with 2-D instances is expected to fail.
        with self.assertRaises(Exception):
            qp.data.LabelledCollection.join(data1, data2, data3, data4)

        x = rng.random((20, 3))
        y = rng.integers(0, 2, 20)
        data5 = qp.data.LabelledCollection(x, y)
        combined = qp.data.LabelledCollection.join(data4, data5)
        self.assertEqual(len(combined), len(data4) + len(data5))

        # Same rank but a different number of features (4 vs 3) is expected to fail.
        x = rng.random((10, 4))
        y = rng.integers(0, 2, 10)
        data6 = qp.data.LabelledCollection(x, y)
        with self.assertRaises(Exception):
            qp.data.LabelledCollection.join(data4, data5, data6)

        # Mixing sparse and dense instances is expected to fail...
        data4.instances = csr_matrix(data4.instances)
        with self.assertRaises(Exception):
            qp.data.LabelledCollection.join(data4, data5)
        # ...while joining two sparse collections succeeds.
        data5.instances = csr_matrix(data5.instances)
        combined = qp.data.LabelledCollection.join(data4, data5)
        self.assertEqual(len(combined), len(data4) + len(data5))
2023-02-13 19:27:48 +01:00
if '__main__' == __name__:
    # Allow running this module directly, e.g. ``python test_labelcollection.py``.
    unittest.main()