2023-07-26 00:38:23 +02:00
|
|
|
import pytest
|
|
|
|
from quacc.data import ExClassManager as ECM, ExtendedCollection
|
|
|
|
import numpy as np
|
2023-07-27 03:16:41 +02:00
|
|
|
import scipy.sparse as sp
|
2023-07-26 00:38:23 +02:00
|
|
|
|
|
|
|
|
|
|
|
class TestExClassManager:
    """Parametrized checks for the ``ExClassManager`` class-index helpers.

    Every test uses two base classes and compares the helper's return value
    with the expected value listed in its parameter table.
    """

    @pytest.mark.parametrize(
        "true_class,pred_class,result",
        [(0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 3)],
    )
    def test_get_ex(self, true_class, pred_class, result):
        """``get_ex`` maps each (true, predicted) pair to the listed index."""
        n_classes = 2
        extended = ECM.get_ex(n_classes, true_class, pred_class)
        assert extended == result

    @pytest.mark.parametrize(
        "ex_class,result",
        [(0, 0), (1, 1), (2, 0), (3, 1)],
    )
    def test_get_pred(self, ex_class, result):
        """``get_pred`` returns the listed predicted class for each index."""
        n_classes = 2
        predicted = ECM.get_pred(n_classes, ex_class)
        assert predicted == result

    @pytest.mark.parametrize(
        "ex_class,result",
        [(0, 0), (1, 0), (2, 1), (3, 1)],
    )
    def test_get_true(self, ex_class, result):
        """``get_true`` returns the listed true class for each index."""
        n_classes = 2
        actual = ECM.get_true(n_classes, ex_class)
        assert actual == result
|
|
|
|
|
|
|
|
|
|
|
|
class TestExtendedCollection:
    """Parametrized checks for the ``ExtendedCollection`` splitting helpers.

    Every scenario is run twice: once with a dense ``numpy`` array and once
    with the same data wrapped in a ``scipy.sparse.csr_matrix``.
    """

    @staticmethod
    def _same_matrix(actual, expected):
        """Return True when ``actual`` matches ``expected`` in shape and content.

        Sparse expectations are compared with ``!=`` / ``nnz``; dense ones with
        ``np.array_equal``. The shape is checked first for sparse matrices
        because scipy returns a plain ``True`` (which has no ``nnz``) when the
        shapes differ, and that would surface as an ``AttributeError`` instead
        of a failed assertion.
        """
        if sp.issparse(expected):
            return (
                sp.issparse(actual)
                and actual.shape == expected.shape
                and (actual != expected).nnz == 0
            )
        return np.array_equal(actual, expected)

    @pytest.mark.parametrize(
        "instances,result",
        [
            (
                np.asarray(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                [np.asarray([1, 3]), np.asarray([0, 2])],
            ),
            (
                sp.csr_matrix(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                [np.asarray([1, 3]), np.asarray([0, 2])],
            ),
            (
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                [np.asarray([], dtype=int), np.asarray([0, 1])],
            ),
            (
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                [np.asarray([], dtype=int), np.asarray([0, 1])],
            ),
            (
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                [np.asarray([0, 1]), np.asarray([], dtype=int)],
            ),
            (
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                [np.asarray([0, 1]), np.asarray([], dtype=int)],
            ),
        ],
    )
    def test__split_index_by_pred(self, instances, result):
        """``_split_index_by_pred`` returns the expected row-index groups."""
        ncl = 2
        groups = list(ExtendedCollection._split_index_by_pred(ncl, instances))
        # zip() stops at the shorter input, so pin the number of groups first;
        # otherwise a result with too few groups would pass vacuously.
        assert len(groups) == len(result)
        assert all(np.array_equal(a, b) for (a, b) in zip(groups, result))

    @pytest.mark.parametrize(
        "instances,s_inst,norms",
        [
            (
                np.asarray(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                [
                    np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                    np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                ],
                [0.5, 0.5],
            ),
            (
                sp.csr_matrix(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                [
                    sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                    sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                ],
                [0.5, 0.5],
            ),
            (
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                [
                    np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                    np.asarray([], dtype=int),
                ],
                [1.0, 0.0],
            ),
            (
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                [
                    sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                    # Same empty-sparse placeholder as in test_split_by_pred.
                    sp.csr_matrix(np.empty((0, 0), dtype=int)),
                ],
                [1.0, 0.0],
            ),
            (
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                [
                    np.asarray([], dtype=int),
                    np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                ],
                [0.0, 1.0],
            ),
            (
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                [
                    # Same empty-sparse placeholder as in test_split_by_pred.
                    sp.csr_matrix(np.empty((0, 0), dtype=int)),
                    sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                ],
                [0.0, 1.0],
            ),
        ],
    )
    def test_split_inst_by_pred(self, instances, s_inst, norms):
        """``split_inst_by_pred`` returns the expected instance groups and norms.

        ``s_inst`` is a list of expected groups, so the comparison is chosen
        per element (dense vs. sparse) rather than from the type of the list
        itself; a type check on the list never matches and would leave the
        instances unverified.
        """
        # NOTE(review): the empty-group placeholders assume split_inst_by_pred
        # returns the same empty values as split_by_pred -- confirm against the
        # implementation.
        ncl = 2
        _s_inst, _norms = ExtendedCollection.split_inst_by_pred(ncl, instances)
        # Guard against zip() truncation hiding a missing group.
        assert len(_s_inst) == len(s_inst)
        assert all(self._same_matrix(a, b) for (a, b) in zip(_s_inst, s_inst))
        # approx() also fails when the two sequences differ in length.
        assert list(_norms) == pytest.approx(norms)

    @pytest.mark.parametrize(
        "instances,labels,inst0,lbl0,inst1,lbl1",
        [
            (
                np.asarray(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                np.asarray([3, 0, 1, 2]),
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 1]),
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([1, 0]),
            ),
            (
                sp.csr_matrix(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                np.asarray([3, 0, 1, 2]),
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 1]),
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([1, 0]),
            ),
            (
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([3, 1]),
                np.asarray([], dtype=int),
                np.asarray([], dtype=int),
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([1, 0]),
            ),
            (
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([3, 1]),
                sp.csr_matrix(np.empty((0, 0), dtype=int)),
                np.asarray([], dtype=int),
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([1, 0]),
            ),
            (
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 2]),
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 1]),
                np.asarray([], dtype=int),
                np.asarray([], dtype=int),
            ),
            (
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 2]),
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 1]),
                sp.csr_matrix(np.empty((0, 0), dtype=int)),
                np.asarray([], dtype=int),
            ),
        ],
    )
    def test_split_by_pred(self, instances, labels, inst0, lbl0, inst1, lbl1):
        """``split_by_pred`` yields two collections with the expected X and y."""
        ec = ExtendedCollection(instances, labels, classes=range(0, 4))
        [ec0, ec1] = ec.split_by_pred()
        # The expected matrices have the same kind (dense/sparse) as the
        # input, so the helper picks the matching comparison.
        assert self._same_matrix(ec0.X, inst0)
        assert self._same_matrix(ec1.X, inst1)
        assert np.array_equal(ec0.y, lbl0)
        assert np.array_equal(ec1.y, lbl1)
|