added tests for data.py
This commit is contained in:
parent
1d5507889b
commit
de48da638a
|
@ -2,184 +2,286 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import scipy.sparse as sp
|
import scipy.sparse as sp
|
||||||
|
|
||||||
from quacc.data import ExtendedCollection
|
from quacc.data import (
|
||||||
|
ExtendedCollection,
|
||||||
|
ExtendedData,
|
||||||
|
ExtendedLabels,
|
||||||
|
ExtendedPrev,
|
||||||
|
ExtensionPolicy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestExtendedCollection:
|
@pytest.mark.ext
|
||||||
|
@pytest.mark.extpol
|
||||||
|
class TestExtendedPolicy:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"instances,result",
|
"extpol,nbcl,result",
|
||||||
[
|
[
|
||||||
(
|
(ExtensionPolicy(), 2, np.array([0, 1, 2, 3])),
|
||||||
np.asarray(
|
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2])),
|
||||||
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
|
(ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
|
||||||
),
|
(ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3])),
|
||||||
[np.asarray([1, 3]), np.asarray([0, 2])],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
sp.csr_matrix(
|
|
||||||
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
|
|
||||||
),
|
|
||||||
[np.asarray([1, 3]), np.asarray([0, 2])],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
[np.asarray([], dtype=int), np.asarray([0, 1])],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
[np.asarray([], dtype=int), np.asarray([0, 1])],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
[np.asarray([0, 1]), np.asarray([], dtype=int)],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
[np.asarray([0, 1]), np.asarray([], dtype=int)],
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test__split_index_by_pred(self, instances, result):
|
def test_qclasses(self, extpol, nbcl, result):
|
||||||
ncl = 2
|
assert (result == extpol.qclasses(nbcl)).all()
|
||||||
assert all(
|
|
||||||
np.array_equal(a, b)
|
|
||||||
for (a, b) in zip(
|
|
||||||
ExtendedCollection._split_index_by_pred(ncl, instances),
|
|
||||||
result,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"instances,s_inst,norms",
|
"extpol,nbcl,result",
|
||||||
[
|
[
|
||||||
|
(ExtensionPolicy(), 2, np.array([0, 1, 2, 3])),
|
||||||
|
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2, 3])),
|
||||||
|
(ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
|
||||||
(
|
(
|
||||||
np.asarray(
|
ExtensionPolicy(collapse_false=True),
|
||||||
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
|
3,
|
||||||
),
|
np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]),
|
||||||
[
|
|
||||||
np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
],
|
|
||||||
[0.5, 0.5],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
sp.csr_matrix(
|
|
||||||
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
|
|
||||||
),
|
|
||||||
[
|
|
||||||
sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
],
|
|
||||||
[0.5, 0.5],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
[
|
|
||||||
np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
np.asarray([], dtype=int),
|
|
||||||
],
|
|
||||||
[1.0, 0.0],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
[
|
|
||||||
sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
sp.csr_matrix([], dtype=int),
|
|
||||||
],
|
|
||||||
[1.0, 0.0],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
[
|
|
||||||
np.asarray([], dtype=int),
|
|
||||||
np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
],
|
|
||||||
[0.0, 1.0],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
[
|
|
||||||
sp.csr_matrix([], dtype=int),
|
|
||||||
sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
],
|
|
||||||
[0.0, 1.0],
|
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_split_inst_by_pred(self, instances, s_inst, norms):
|
def test_eclasses(self, extpol, nbcl, result):
|
||||||
ncl = 2
|
assert (result == extpol.eclasses(nbcl)).all()
|
||||||
_s_inst, _norms = ExtendedCollection.split_inst_by_pred(ncl, instances)
|
|
||||||
if isinstance(s_inst, np.ndarray):
|
|
||||||
assert all(np.array_equal(a, b) for (a, b) in zip(_s_inst, s_inst))
|
|
||||||
if isinstance(s_inst, sp.csr_matrix):
|
|
||||||
assert all((a != b).nnz == 0 for (a, b) in zip(_s_inst, s_inst))
|
|
||||||
assert all(a == b for (a, b) in zip(_norms, norms))
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"instances,labels,inst0,lbl0,inst1,lbl1",
|
"extpol,nbcl,result",
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
np.asarray(
|
ExtensionPolicy(),
|
||||||
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
|
2,
|
||||||
|
(
|
||||||
|
np.array([0, 0, 1, 1]),
|
||||||
|
np.array([0, 1, 0, 1]),
|
||||||
),
|
),
|
||||||
np.asarray([3, 0, 1, 2]),
|
|
||||||
np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
np.asarray([0, 1]),
|
|
||||||
np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
np.asarray([1, 0]),
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
sp.csr_matrix(
|
ExtensionPolicy(collapse_false=True),
|
||||||
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
|
2,
|
||||||
|
(
|
||||||
|
np.array([0, 1, 0]),
|
||||||
|
np.array([0, 1, 1]),
|
||||||
),
|
),
|
||||||
np.asarray([3, 0, 1, 2]),
|
|
||||||
sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
np.asarray([0, 1]),
|
|
||||||
sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
np.asarray([1, 0]),
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
ExtensionPolicy(),
|
||||||
np.asarray([3, 1]),
|
3,
|
||||||
np.asarray([], dtype=int),
|
(
|
||||||
np.asarray([], dtype=int),
|
np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]),
|
||||||
np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]),
|
||||||
np.asarray([1, 0]),
|
),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
ExtensionPolicy(collapse_false=True),
|
||||||
np.asarray([3, 1]),
|
3,
|
||||||
sp.csr_matrix(np.empty((0, 0), dtype=int)),
|
|
||||||
np.asarray([], dtype=int),
|
|
||||||
sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
|
||||||
np.asarray([1, 0]),
|
|
||||||
),
|
|
||||||
(
|
(
|
||||||
np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
np.array([0, 1, 2, 0]),
|
||||||
np.asarray([0, 2]),
|
np.array([0, 1, 2, 1]),
|
||||||
np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
np.asarray([0, 1]),
|
|
||||||
np.asarray([], dtype=int),
|
|
||||||
np.asarray([], dtype=int),
|
|
||||||
),
|
),
|
||||||
(
|
|
||||||
sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
np.asarray([0, 2]),
|
|
||||||
sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
|
||||||
np.asarray([0, 1]),
|
|
||||||
sp.csr_matrix(np.empty((0, 0), dtype=int)),
|
|
||||||
np.asarray([], dtype=int),
|
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_split_by_pred(self, instances, labels, inst0, lbl0, inst1, lbl1):
|
def test_matrix_idx(self, extpol, nbcl, result):
|
||||||
ec = ExtendedCollection(instances, labels, classes=range(0, 4))
|
_midx = extpol.matrix_idx(nbcl)
|
||||||
[ec0, ec1] = ec.split_by_pred()
|
assert len(_midx) == len(result)
|
||||||
if isinstance(instances, np.ndarray):
|
assert all((idx == r).all() for idx, r in zip(_midx, result))
|
||||||
assert np.array_equal(ec0.X, inst0)
|
|
||||||
assert np.array_equal(ec1.X, inst1)
|
@pytest.mark.parametrize(
|
||||||
if isinstance(instances, sp.csr_matrix):
|
"extpol,nbcl,true,pred,result",
|
||||||
assert (ec0.X != inst0).nnz == 0
|
[
|
||||||
assert (ec1.X != inst1).nnz == 0
|
(
|
||||||
assert np.array_equal(ec0.y, lbl0)
|
ExtensionPolicy(),
|
||||||
assert np.array_equal(ec1.y, lbl1)
|
2,
|
||||||
|
np.array([1, 0, 1, 1, 0, 0]),
|
||||||
|
np.array([1, 0, 0, 1, 1, 0]),
|
||||||
|
np.array([3, 0, 2, 3, 1, 0]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
ExtensionPolicy(collapse_false=True),
|
||||||
|
2,
|
||||||
|
np.array([1, 0, 1, 1, 0, 0]),
|
||||||
|
np.array([1, 0, 0, 1, 1, 0]),
|
||||||
|
np.array([1, 0, 2, 1, 2, 0]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
ExtensionPolicy(),
|
||||||
|
3,
|
||||||
|
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||||
|
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||||
|
np.array([4, 6, 0, 3, 1, 7, 2, 5, 8]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
ExtensionPolicy(collapse_false=True),
|
||||||
|
3,
|
||||||
|
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||||
|
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||||
|
np.array([1, 3, 0, 3, 3, 3, 3, 3, 2]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_ext_lbl(self, extpol, nbcl, true, pred, result):
|
||||||
|
vfun = extpol.ext_lbl(nbcl)
|
||||||
|
assert (vfun(true, pred) == result).all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.ext
|
||||||
|
@pytest.mark.extd
|
||||||
|
class TestExtendedData:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"pred_proba,result",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
np.array([[0.3, 0.7], [0.54, 0.46], [0.28, 0.72], [0.6, 0.4]]),
|
||||||
|
[np.array([1, 3]), np.array([0, 2])],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([[0.3, 0.7], [0.28, 0.72]]),
|
||||||
|
[np.array([]), np.array([0, 1])],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([[0.54, 0.46], [0.6, 0.4]]),
|
||||||
|
[np.array([0, 1]), np.array([])],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
[0.25, 0.4, 0.35],
|
||||||
|
[0.24, 0.3, 0.46],
|
||||||
|
[0.61, 0.28, 0.11],
|
||||||
|
[0.4, 0.1, 0.5],
|
||||||
|
]
|
||||||
|
),
|
||||||
|
[np.array([2]), np.array([0]), np.array([1, 3])],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test__split_index_by_pred(self, monkeypatch, pred_proba, result):
|
||||||
|
def mockinit(self, pred_proba):
|
||||||
|
self.pred_proba_ = pred_proba
|
||||||
|
|
||||||
|
monkeypatch.setattr(ExtendedData, "__init__", mockinit)
|
||||||
|
ed = ExtendedData(pred_proba)
|
||||||
|
_split_index = ed._ExtendedData__split_index_by_pred()
|
||||||
|
assert len(_split_index) == len(result)
|
||||||
|
assert all((a == b).all() for (a, b) in zip(_split_index, result))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.ext
|
||||||
|
@pytest.mark.extl
|
||||||
|
class TestExtendedLabels:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"true,pred,nbcl,extpol,result",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
np.array([1, 0, 0, 1, 1]),
|
||||||
|
np.array([1, 1, 0, 0, 1]),
|
||||||
|
2,
|
||||||
|
ExtensionPolicy(),
|
||||||
|
np.array([3, 1, 0, 2, 3]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([1, 0, 0, 1, 1]),
|
||||||
|
np.array([1, 1, 0, 0, 1]),
|
||||||
|
2,
|
||||||
|
ExtensionPolicy(collapse_false=True),
|
||||||
|
np.array([1, 2, 0, 2, 1]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_y(self, true, pred, nbcl, extpol, result):
|
||||||
|
el = ExtendedLabels(true, pred, nbcl, extpol)
|
||||||
|
assert (el.y == result).all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.ext
|
||||||
|
@pytest.mark.extp
|
||||||
|
class TestExtendedPrev:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"flat,nbcl,extpol,q_classes,result",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
np.array([0.2, 0, 0.8, 0]),
|
||||||
|
2,
|
||||||
|
ExtensionPolicy(),
|
||||||
|
[0, 1, 2, 3],
|
||||||
|
np.array([0.2, 0, 0.8, 0]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([0.2, 0.8]),
|
||||||
|
2,
|
||||||
|
ExtensionPolicy(),
|
||||||
|
[0, 3],
|
||||||
|
np.array([0.2, 0, 0, 0.8]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([0.2, 0.8]),
|
||||||
|
2,
|
||||||
|
ExtensionPolicy(collapse_false=True),
|
||||||
|
[0, 2],
|
||||||
|
np.array([0.2, 0, 0.8]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([0.1, 0.1, 0.6, 0.2]),
|
||||||
|
3,
|
||||||
|
ExtensionPolicy(),
|
||||||
|
[0, 1, 3, 5],
|
||||||
|
np.array([0.1, 0.1, 0, 0.6, 0, 0.2, 0, 0, 0]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([0.1, 0.1, 0.6]),
|
||||||
|
3,
|
||||||
|
ExtensionPolicy(collapse_false=True),
|
||||||
|
[0, 1, 2],
|
||||||
|
np.array([0.1, 0.1, 0.6, 0]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result):
|
||||||
|
def mockinit(self, flat, nbcl, extpol):
|
||||||
|
self.flat = flat
|
||||||
|
self.nbcl = nbcl
|
||||||
|
self.extpol = extpol
|
||||||
|
|
||||||
|
monkeypatch.setattr(ExtendedPrev, "__init__", mockinit)
|
||||||
|
ep = ExtendedPrev(flat, nbcl, extpol)
|
||||||
|
ep._ExtendedPrev__check_q_classes(q_classes)
|
||||||
|
assert (ep.flat == result).all()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"flat,nbcl,extpol,result",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
np.array([0.05, 0.1, 0.6, 0.25]),
|
||||||
|
2,
|
||||||
|
ExtensionPolicy(),
|
||||||
|
np.array([[0.05, 0.1], [0.6, 0.25]]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([0.05, 0.1, 0.85]),
|
||||||
|
2,
|
||||||
|
ExtensionPolicy(collapse_false=True),
|
||||||
|
np.array([[0.05, 0.85], [0, 0.1]]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([0.05, 0.1, 0.2, 0.15, 0.04, 0.06, 0.15, 0.14, 0.1]),
|
||||||
|
3,
|
||||||
|
ExtensionPolicy(),
|
||||||
|
np.array([[0.05, 0.1, 0.2], [0.15, 0.04, 0.06], [0.15, 0.14, 0.1]]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([0.05, 0.2, 0.65, 0.1]),
|
||||||
|
3,
|
||||||
|
ExtensionPolicy(collapse_false=True),
|
||||||
|
np.array([[0.05, 0.1, 0], [0, 0.2, 0], [0, 0, 0.65]]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test__build_matrix(self, monkeypatch, flat, nbcl, extpol, result):
|
||||||
|
def mockinit(self, flat, nbcl, extpol):
|
||||||
|
self.flat = flat
|
||||||
|
self.nbcl = nbcl
|
||||||
|
self.extpol = extpol
|
||||||
|
|
||||||
|
monkeypatch.setattr(ExtendedPrev, "__init__", mockinit)
|
||||||
|
ep = ExtendedPrev(flat, nbcl, extpol)
|
||||||
|
_matrix = ep._ExtendedPrev__build_matrix()
|
||||||
|
assert _matrix.shape == result.shape
|
||||||
|
assert (_matrix == result).all()
|
||||||
|
|
Loading…
Reference in New Issue