tests added, passing

This commit is contained in:
Lorenzo Volpi 2023-12-21 16:31:06 +01:00
parent b239b2e38a
commit db5064dbaf
11 changed files with 1207 additions and 248 deletions

View File

@ -1,130 +1,33 @@
from unittest import mock
import numpy as np
import pytest
import scipy.sparse as sp
from quacc.data import (
ExtBinPrev,
ExtendedCollection,
ExtendedData,
ExtendedLabels,
ExtendedPrev,
ExtensionPolicy,
ExtMulPrev,
_split_index_by_pred,
)
@pytest.mark.ext
@pytest.mark.extpol
class TestExtendedPolicy:
@pytest.mark.parametrize(
"extpol,nbcl,result",
[
(ExtensionPolicy(), 2, np.array([0, 1, 2, 3])),
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2])),
(ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
(ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3])),
],
)
def test_qclasses(self, extpol, nbcl, result):
assert (result == extpol.qclasses(nbcl)).all()
@pytest.fixture
def nd_1():
return np.arange(12).reshape((4, 3))
@pytest.mark.parametrize(
"extpol,nbcl,result",
[
(ExtensionPolicy(), 2, np.array([0, 1, 2, 3])),
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2, 3])),
(ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
(
ExtensionPolicy(collapse_false=True),
3,
np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]),
),
],
)
def test_eclasses(self, extpol, nbcl, result):
assert (result == extpol.eclasses(nbcl)).all()
@pytest.mark.parametrize(
"extpol,nbcl,result",
[
(
ExtensionPolicy(),
2,
(
np.array([0, 0, 1, 1]),
np.array([0, 1, 0, 1]),
),
),
(
ExtensionPolicy(collapse_false=True),
2,
(
np.array([0, 1, 0]),
np.array([0, 1, 1]),
),
),
(
ExtensionPolicy(),
3,
(
np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]),
np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]),
),
),
(
ExtensionPolicy(collapse_false=True),
3,
(
np.array([0, 1, 2, 0]),
np.array([0, 1, 2, 1]),
),
),
],
)
def test_matrix_idx(self, extpol, nbcl, result):
_midx = extpol.matrix_idx(nbcl)
assert len(_midx) == len(result)
assert all((idx == r).all() for idx, r in zip(_midx, result))
@pytest.mark.parametrize(
"extpol,nbcl,true,pred,result",
[
(
ExtensionPolicy(),
2,
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
np.array([3, 0, 2, 3, 1, 0]),
),
(
ExtensionPolicy(collapse_false=True),
2,
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
np.array([1, 0, 2, 1, 2, 0]),
),
(
ExtensionPolicy(),
3,
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
np.array([4, 6, 0, 3, 1, 7, 2, 5, 8]),
),
(
ExtensionPolicy(collapse_false=True),
3,
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
np.array([1, 3, 0, 3, 3, 3, 3, 3, 2]),
),
],
)
def test_ext_lbl(self, extpol, nbcl, true, pred, result):
vfun = extpol.ext_lbl(nbcl)
assert (vfun(true, pred) == result).all()
@pytest.fixture
def csr_1(nd_1):
return sp.csr_matrix(nd_1)
@pytest.mark.ext
@pytest.mark.extd
class TestExtendedData:
class TestData:
@pytest.mark.parametrize(
"pred_proba,result",
[
@ -153,17 +56,219 @@ class TestExtendedData:
),
],
)
def test__split_index_by_pred(self, monkeypatch, pred_proba, result):
def mockinit(self, pred_proba):
self.pred_proba_ = pred_proba
monkeypatch.setattr(ExtendedData, "__init__", mockinit)
ed = ExtendedData(pred_proba)
_split_index = ed._ExtendedData__split_index_by_pred()
def test_split_index_by_pred(self, pred_proba, result):
_split_index = _split_index_by_pred(pred_proba)
assert len(_split_index) == len(result)
assert all((a == b).all() for (a, b) in zip(_split_index, result))
@pytest.mark.ext
@pytest.mark.extpol
class TestExtendedPolicy:
# fmt: off
@pytest.mark.parametrize(
"extpol,nbcl,result",
[
(ExtensionPolicy(), 2, np.array([0, 1, 2, 3])),
(ExtensionPolicy(group_false=True), 2, np.array([0, 1, 2, 3])),
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2])),
(ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
(ExtensionPolicy(group_false=True), 3, np.array([0, 1, 2, 3, 4, 5])),
(ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3])),
],
)
def test_qclasses(self, extpol, nbcl, result):
assert (result == extpol.qclasses(nbcl)).all()
@pytest.mark.parametrize(
"extpol,nbcl,result",
[
(ExtensionPolicy(), 2, np.array([0, 1, 2, 3])),
(ExtensionPolicy(group_false=True), 2, np.array([0, 1, 2, 3])),
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2, 3])),
(ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
(ExtensionPolicy(group_false=True), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
(ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
],
)
def test_eclasses(self, extpol, nbcl, result):
assert (result == extpol.eclasses(nbcl)).all()
@pytest.mark.parametrize(
"extpol,nbcl,result",
[
(ExtensionPolicy(), 2, np.array([0, 1])),
(ExtensionPolicy(group_false=True), 2, np.array([0, 1])),
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1])),
(ExtensionPolicy(), 3, np.array([0, 1, 2])),
(ExtensionPolicy(group_false=True), 3, np.array([0, 1])),
(ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2])),
],
)
def test_tfp_classes(self, extpol, nbcl, result):
assert (result == extpol.tfp_classes(nbcl)).all()
@pytest.mark.parametrize(
"extpol,nbcl,result",
[
(
ExtensionPolicy(), 2,
(np.array([0, 0, 1, 1]), np.array([0, 1, 0, 1])),
),
(
ExtensionPolicy(group_false=True), 2,
(np.array([0, 1, 1, 0]), np.array([0, 1, 0, 1])),
),
(
ExtensionPolicy(collapse_false=True), 2,
(np.array([0, 1, 0]), np.array([0, 1, 1])),
),
(
ExtensionPolicy(), 3,
(np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]), np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])),
),
(
ExtensionPolicy(group_false=True), 3,
(np.array([0, 1, 2, 1, 2, 0]), np.array([0, 1, 2, 0, 1, 2])),
),
(
ExtensionPolicy(collapse_false=True), 3,
(np.array([0, 1, 2, 0]), np.array([0, 1, 2, 1])),
),
],
)
def test_matrix_idx(self, extpol, nbcl, result):
_midx = extpol.matrix_idx(nbcl)
assert len(_midx) == len(result)
assert all((idx == r).all() for idx, r in zip(_midx, result))
@pytest.mark.parametrize(
"extpol,nbcl,true,pred,result",
[
(
ExtensionPolicy(), 2,
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
np.array([3, 0, 2, 3, 1, 0]),
),
(
ExtensionPolicy(group_false=True), 2,
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
np.array([1, 0, 2, 1, 3, 0]),
),
(
ExtensionPolicy(collapse_false=True), 2,
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
np.array([1, 0, 2, 1, 2, 0]),
),
(
ExtensionPolicy(), 3,
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
np.array([4, 6, 0, 3, 1, 7, 2, 5, 8]),
),
(
ExtensionPolicy(group_false=True), 3,
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
np.array([1, 3, 0, 3, 4, 4, 5, 5, 2]),
),
(
ExtensionPolicy(collapse_false=True), 3,
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
np.array([1, 3, 0, 3, 3, 3, 3, 3, 2]),
),
],
)
def test_ext_lbl(self, extpol, nbcl, true, pred, result):
vfun = extpol.ext_lbl(nbcl)
assert (vfun(true, pred) == result).all()
@pytest.mark.parametrize(
"extpol,nbcl,true,pred,result",
[
(
ExtensionPolicy(), 2,
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
np.array([1, 0, 1, 1, 0, 0]),
),
(
ExtensionPolicy(group_false=True), 2,
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
np.array([0, 0, 1, 0, 1, 0]),
),
(
ExtensionPolicy(collapse_false=True), 2,
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
np.array([1, 0, 1, 1, 0, 0]),
),
(
ExtensionPolicy(), 3,
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
),
(
ExtensionPolicy(group_false=True), 3,
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
np.array([0, 1, 0, 1, 1, 1, 1, 1, 0]),
),
(
ExtensionPolicy(collapse_false=True), 3,
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
),
],
)
def test_true_lbl_from_pred(self, extpol, nbcl, true, pred, result):
vfun = extpol.true_lbl_from_pred(nbcl)
assert (vfun(true, pred) == result).all()
# fmt: on
@pytest.mark.ext
@pytest.mark.extd
class TestExtendedData:
@pytest.mark.parametrize(
"instances_name,indexes,result",
[
(
"nd_1",
[np.array([0, 2]), np.array([1, 3])],
[
np.array([[0, 1, 2], [6, 7, 8]]),
np.array([[3, 4, 5], [9, 10, 11]]),
],
),
(
"nd_1",
[np.array([0]), np.array([1, 3]), np.array([2])],
[
np.array([[0, 1, 2]]),
np.array([[3, 4, 5], [9, 10, 11]]),
np.array([[6, 7, 8]]),
],
),
],
)
def test_split_by_pred(self, instances_name, indexes, result, monkeypatch, request):
def mockinit(self):
self.instances = request.getfixturevalue(instances_name)
monkeypatch.setattr(ExtendedData, "__init__", mockinit)
d = ExtendedData()
split = d.split_by_pred(indexes)
assert all([(s == r).all() for s, r in zip(split, result)])
@pytest.mark.ext
@pytest.mark.extl
class TestExtendedLabels:
@ -177,6 +282,13 @@ class TestExtendedLabels:
ExtensionPolicy(),
np.array([3, 1, 0, 2, 3]),
),
(
np.array([1, 0, 0, 1, 1]),
np.array([1, 1, 0, 0, 1]),
2,
ExtensionPolicy(group_false=True),
np.array([1, 3, 0, 2, 1]),
),
(
np.array([1, 0, 0, 1, 1]),
np.array([1, 1, 0, 0, 1]),
@ -184,92 +296,128 @@ class TestExtendedLabels:
ExtensionPolicy(collapse_false=True),
np.array([1, 2, 0, 2, 1]),
),
(
np.array([1, 0, 0, 1, 0, 1, 2, 2, 2]),
np.array([1, 1, 0, 0, 2, 2, 2, 0, 1]),
3,
ExtensionPolicy(),
np.array([4, 1, 0, 3, 2, 5, 8, 6, 7]),
),
(
np.array([1, 0, 0, 1, 0, 1, 2, 2, 2]),
np.array([1, 1, 0, 0, 2, 2, 2, 0, 1]),
3,
ExtensionPolicy(group_false=True),
np.array([1, 4, 0, 3, 5, 5, 2, 3, 4]),
),
(
np.array([1, 0, 0, 1, 0, 1, 2, 2, 2]),
np.array([1, 1, 0, 0, 2, 2, 2, 0, 1]),
3,
ExtensionPolicy(collapse_false=True),
np.array([1, 3, 0, 3, 3, 3, 2, 3, 3]),
),
],
)
def test_y(self, true, pred, nbcl, extpol, result):
el = ExtendedLabels(true, pred, nbcl, extpol)
assert (el.y == result).all()
@pytest.mark.parametrize(
"extpol,nbcl,indexes,true,pred,result,rcls",
[
(
ExtensionPolicy(),
2,
[np.array([1, 2, 5]), np.array([0, 3, 4])],
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
[np.array([0, 1, 0]), np.array([1, 1, 0])],
np.array([0, 1]),
),
(
ExtensionPolicy(group_false=True),
2,
[np.array([1, 2, 5]), np.array([0, 3, 4])],
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
[np.array([0, 1, 0]), np.array([0, 0, 1])],
np.array([0, 1]),
),
(
ExtensionPolicy(collapse_false=True),
2,
[np.array([1, 2, 5]), np.array([0, 3, 4])],
np.array([1, 0, 1, 1, 0, 0]),
np.array([1, 0, 0, 1, 1, 0]),
[np.array([0, 1, 0]), np.array([1, 1, 0])],
np.array([0, 1]),
),
(
ExtensionPolicy(),
3,
[np.array([1, 2, 3]), np.array([0, 4, 5]), np.array([6, 7, 8])],
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
[np.array([2, 0, 1]), np.array([1, 0, 2]), np.array([0, 1, 2])],
np.array([0, 1, 2]),
),
(
ExtensionPolicy(group_false=True),
3,
[np.array([1, 2, 3]), np.array([0, 4, 5]), np.array([6, 7, 8])],
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
[np.array([1, 0, 1]), np.array([0, 1, 1]), np.array([1, 1, 0])],
np.array([0, 1]),
),
(
ExtensionPolicy(collapse_false=True),
3,
[np.array([1, 2, 3]), np.array([0, 4, 5]), np.array([6, 7, 8])],
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
[np.array([2, 0, 1]), np.array([1, 0, 2]), np.array([0, 1, 2])],
np.array([0, 1, 2]),
),
],
)
def test_split_by_pred(self, extpol, nbcl, indexes, true, pred, result, rcls):
el = ExtendedLabels(true, pred, nbcl, extpol)
labels, cls = el.split_by_pred(indexes)
assert (cls == rcls).all()
assert all([(lbl == r).all() for lbl, r in zip(labels, result)])
@pytest.mark.ext
@pytest.mark.extp
class TestExtendedPrev:
@pytest.mark.parametrize(
"flat,nbcl,extpol,q_classes,result",
[
(
np.array([0.2, 0, 0.8, 0]),
2,
ExtensionPolicy(),
[0, 1, 2, 3],
np.array([0.2, 0, 0.8, 0]),
),
(
np.array([0.2, 0.8]),
2,
ExtensionPolicy(),
[0, 3],
np.array([0.2, 0, 0, 0.8]),
),
(
np.array([0.2, 0.8]),
2,
ExtensionPolicy(collapse_false=True),
[0, 2],
np.array([0.2, 0, 0.8]),
),
(
np.array([0.1, 0.1, 0.6, 0.2]),
3,
ExtensionPolicy(),
[0, 1, 3, 5],
np.array([0.1, 0.1, 0, 0.6, 0, 0.2, 0, 0, 0]),
),
(
np.array([0.1, 0.1, 0.6]),
3,
ExtensionPolicy(collapse_false=True),
[0, 1, 2],
np.array([0.1, 0.1, 0.6, 0]),
),
],
)
def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result):
def mockinit(self, flat, nbcl, extpol):
self.flat = flat
self.nbcl = nbcl
self.extpol = extpol
monkeypatch.setattr(ExtendedPrev, "__init__", mockinit)
ep = ExtendedPrev(flat, nbcl, extpol)
ep._ExtendedPrev__check_q_classes(q_classes)
assert (ep.flat == result).all()
# fmt: off
@pytest.mark.parametrize(
"flat,nbcl,extpol,result",
[
(
np.array([0.05, 0.1, 0.6, 0.25]),
2,
ExtensionPolicy(),
np.array([0.05, 0.1, 0.6, 0.25]), 2, ExtensionPolicy(),
np.array([[0.05, 0.1], [0.6, 0.25]]),
),
(
np.array([0.05, 0.1, 0.85]),
2,
ExtensionPolicy(collapse_false=True),
np.array([0.05, 0.1, 0.6, 0.25]), 2, ExtensionPolicy(group_false=True),
np.array([[0.05, 0.25], [0.6, 0.1]]),
),
(
np.array([0.05, 0.1, 0.85]), 2, ExtensionPolicy(collapse_false=True),
np.array([[0.05, 0.85], [0, 0.1]]),
),
(
np.array([0.05, 0.1, 0.2, 0.15, 0.04, 0.06, 0.15, 0.14, 0.1]),
3,
ExtensionPolicy(),
np.array([0.05, 0.1, 0.2, 0.15, 0.04, 0.06, 0.15, 0.14, 0.1]), 3, ExtensionPolicy(),
np.array([[0.05, 0.1, 0.2], [0.15, 0.04, 0.06], [0.15, 0.14, 0.1]]),
),
(
np.array([0.05, 0.2, 0.65, 0.1]),
3,
ExtensionPolicy(collapse_false=True),
np.array([0.15, 0.2, 0.15, 0.1, 0.15, 0.25]), 3, ExtensionPolicy(group_false=True),
np.array([[0.15, 0.0, 0.25], [0.1, 0.2, 0.0], [0.0, 0.15, 0.15]]),
),
(
np.array([0.05, 0.2, 0.65, 0.1]), 3, ExtensionPolicy(collapse_false=True),
np.array([[0.05, 0.1, 0], [0, 0.2, 0], [0, 0, 0.65]]),
),
],
@ -285,3 +433,130 @@ class TestExtendedPrev:
_matrix = ep._ExtendedPrev__build_matrix()
assert _matrix.shape == result.shape
assert (_matrix == result).all()
# fmt: on
@pytest.mark.ext
@pytest.mark.extp
class TestExtMulPrev:
# fmt: off
@pytest.mark.parametrize(
"flat,nbcl,extpol,q_classes,result",
[
(np.array([0.2, 0, 0.8, 0]), 2, ExtensionPolicy(), [0, 1, 2, 3], np.array([0.2, 0, 0.8, 0])),
(np.array([0.2, 0.8]), 2, ExtensionPolicy(), [0, 3], np.array([0.2, 0, 0, 0.8])),
(np.array([0.2, 0.8]), 2, ExtensionPolicy(group_false=True), [0, 3], np.array([0.2, 0, 0, 0.8])),
(np.array([0.2, 0.8]), 2, ExtensionPolicy(collapse_false=True), [0, 2], np.array([0.2, 0, 0.8])),
(np.array([0.1, 0.1, 0.6, 0.2]), 3, ExtensionPolicy(), [0, 1, 3, 5], np.array([0.1, 0.1, 0, 0.6, 0, 0.2, 0, 0, 0])),
(np.array([0.1, 0.1, 0.6, 0.2]), 3, ExtensionPolicy(group_false=True), [0, 1, 3, 5], np.array([0.1, 0.1, 0, 0.6, 0, 0.2])),
(np.array([0.1, 0.1, 0.6]), 3, ExtensionPolicy(collapse_false=True), [0, 1, 2], np.array([0.1, 0.1, 0.6, 0])),
],
)
def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result):
def mockinit(self, nbcl, extpol):
self.nbcl = nbcl
self.extpol = extpol
monkeypatch.setattr(ExtMulPrev, "__init__", mockinit)
ep = ExtMulPrev(nbcl, extpol)
_flat = ep._ExtMulPrev__check_q_classes(q_classes, flat)
assert (_flat == result).all()
# fmt: on
@pytest.mark.ext
@pytest.mark.extp
class TestExtBinPrev:
# fmt: off
@pytest.mark.parametrize(
"flat,nbcl,extpol,q_classes,result",
[
([np.array([0.2, 0]), np.array([0.8, 0])], 2, ExtensionPolicy(), [[0, 1], [0, 1]], np.array([[0.2, 0], [0.8, 0]])),
([np.array([0.2]), np.array([0.8])], 2, ExtensionPolicy(), [[0], [1]], np.array([[0.2, 0], [0, 0.8]])),
([np.array([0.2]), np.array([0.8])], 2, ExtensionPolicy(group_false=True), [[0], [1]], np.array([[0.2, 0], [0, 0.8]])),
([np.array([0.2]), np.array([0.8])], 2, ExtensionPolicy(collapse_false=True), [[0], [1]], np.array([[0.2, 0], [0, 0.8]])),
([np.array([0.1, 0.1]), np.array([0.6]), np.array([0.2])], 3, ExtensionPolicy(), [[0, 1], [0], [2]], np.array([[0.1, 0.1, 0], [0.6, 0, 0], [0, 0, 0.2]])),
([np.array([0.1, 0.1]), np.array([0.6]), np.array([0.2])], 3, ExtensionPolicy(group_false=True), [[0, 1], [0], [1]], np.array([[0.1, 0.1], [0.6, 0], [0, 0.2]])),
([np.array([0.1, 0.1]), np.array([0.6]), np.array([0.2])], 3, ExtensionPolicy(collapse_false=True), [[0, 1], [0], [2]], np.array([[0.1, 0.1, 0], [0.6, 0, 0], [0, 0, 0.2]])),
],
)
def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result):
def mockinit(self, nbcl, extpol):
self.nbcl = nbcl
self.extpol = extpol
monkeypatch.setattr(ExtBinPrev, "__init__", mockinit)
ep = ExtBinPrev(nbcl, extpol)
_flat = ep._ExtBinPrev__check_q_classes(q_classes, flat)
assert (_flat == result).all()
@pytest.mark.parametrize(
"flat,result",
[
(np.array([[0.2, 0], [0.8, 0]]), np.array([0.2, 0.8, 0, 0])),
(np.array([[0.2, 0], [0, 0.8]]), np.array([0.2, 0, 0, 0.8])),
(np.array([[0.1, 0.1, 0], [0.6, 0, 0], [0, 0, 0.2]]), np.array([0.1, 0.6, 0, 0.1, 0, 0, 0, 0, 0.2])),
(np.array([[0.1, 0.1], [0.6, 0], [0, 0.2]]), np.array([0.1, 0.6, 0, 0.1, 0, 0.2])),
],
)
def test__build_flat(self, monkeypatch, flat, result):
def mockinit(self):
pass
monkeypatch.setattr(ExtBinPrev, "__init__", mockinit)
ep = ExtBinPrev()
_flat = ep._ExtBinPrev__build_flat(flat)
assert (_flat == result).all()
# fmt: on
@pytest.mark.ext
@pytest.mark.extc
class TestExtendedCollection:
@pytest.mark.parametrize(
"instances_name,labels,pred_proba,extpol,result",
[
(
"nd_1",
np.array([0, 1, 1, 0]),
np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]),
ExtensionPolicy(),
np.array([0, 0.5, 0.25, 0.25]),
),
(
"nd_1",
np.array([0, 1, 1, 0]),
np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]),
ExtensionPolicy(collapse_false=True),
np.array([0, 0.25, 0.75]),
),
(
"csr_1",
np.array([0, 1, 1, 0]),
np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]),
ExtensionPolicy(),
np.array([0, 0.5, 0.25, 0.25]),
),
(
"csr_1",
np.array([0, 1, 1, 0]),
np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]),
ExtensionPolicy(collapse_false=True),
np.array([0, 0.25, 0.75]),
),
],
)
def test_prevalence(
self, instances_name, labels, pred_proba, extpol, result, request
):
instances = request.getfixturevalue(instances_name)
ec = ExtendedCollection(
instances=instances,
labels=labels,
pred_proba=pred_proba,
ext=pred_proba,
extpol=extpol,
)
assert (ec.prevalence() == result).all()

View File

@ -1,3 +1,135 @@
import os
from contextlib import redirect_stderr
import numpy as np
import pytest
from quacc.dataset import Dataset
@pytest.mark.dataset
class TestDataset:
pass
@pytest.mark.slow
@pytest.mark.parametrize(
"name,target,prevalence",
[
("spambase", None, [0.5, 0.5]),
("imdb", None, [0.5, 0.5]),
("rcv1", "CCAT", [0.5, 0.5]),
("cifar10", "dog", [0.5, 0.5]),
("twitter_gasp", None, [0.33, 0.33, 0.33]),
],
)
def test__resample_all_train(self, name, target, prevalence, monkeypatch):
def mockinit(self):
self._name = name
self._target = target
self.all_train, self.test = self.alltrain_test(self._name, self._target)
monkeypatch.setattr(Dataset, "__init__", mockinit)
with open(os.devnull, "w") as dn:
with redirect_stderr(dn):
d = Dataset()
d._Dataset__resample_all_train()
assert (
np.around(d.all_train.prevalence(), decimals=2).tolist()
== prevalence
)
@pytest.mark.parametrize(
"ncl, prevs,result",
[
(2, None, None),
(2, [], None),
(2, [[0.2, 0.1], [0.3, 0.2]], None),
(2, [[0.2, 0.8], [0.3, 0.7]], [[0.2, 0.8], [0.3, 0.7]]),
(2, [1.0, 2.0, 3.0], None),
(2, [1, 2, 3], None),
(2, [[1, 2], [2, 3], [3, 4]], None),
(2, ["abc", "def"], None),
(3, [[0.2, 0.3], [0.4, 0.1], [0.5, 0.2]], None),
(3, [[0.2, 0.3, 0.2], [0.4, 0.1], [0.5, 0.6]], None),
(2, [[0.2, 0.3, 0.1], [0.1, 0.5, 0.3]], None),
(3, [[0.2, 0.3, 0.1], [0.1, 0.5, 0.3]], None),
(3, [[0.2, 0.8], [0.1, 0.5]], None),
(2, [[0.2, 0.9], [0.1, 0.5]], None),
(2, 10, None),
(2, [[0.2, 0.8], [0.5, 0.5]], [[0.2, 0.8], [0.5, 0.5]]),
(3, [[0.2, 0.6], [0.3, 0.5]], None),
],
)
def test__check_prevs(self, ncl, prevs, result, monkeypatch):
class MockLabelledCollection:
def __init__(self):
self.n_classes = ncl
def mockinit(self):
self.all_train = MockLabelledCollection()
self.prevs = None
monkeypatch.setattr(Dataset, "__init__", mockinit)
d = Dataset()
d._Dataset__check_prevs(prevs)
_prevs = d.prevs if d.prevs is None else d.prevs.tolist()
assert _prevs == result
# fmt: off
@pytest.mark.parametrize(
"ncl,nprevs,built,result",
[
(2, 3, None, [[0.25, 0.75], [0.5, 0.5], [0.75, 0.25]]),
(2, 3, np.array([[0.8, 0.2], [0.6, 0.4], [0.4, 0.6]]), [[0.8, 0.2], [0.6, 0.4], [0.4, 0.6]]),
(2, 3, np.array([[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]), [[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]),
(3, 3, None, [[0.25, 0.25, 0.5], [0.25, 0.5, 0.25], [0.5, 0.25, 0.25]]),
(
3, 4, None,
[[0.2, 0.2, 0.6], [0.2, 0.4, 0.4], [0.2, 0.6, 0.2], [0.4, 0.2, 0.4], [0.4, 0.4, 0.2], [0.6, 0.2, 0.2]],
),
],
)
def test__build_prevs(self, ncl, nprevs, built, result, monkeypatch):
class MockLabelledCollection:
def __init__(self):
self.n_classes = ncl
def mockinit(self):
self.all_train = MockLabelledCollection()
self.prevs = built
self._n_prevs = nprevs
monkeypatch.setattr(Dataset, "__init__", mockinit)
d = Dataset()
_prevs = d._Dataset__build_prevs().tolist()
assert _prevs == result
# fmt: on
@pytest.mark.parametrize(
"ncl,prevs,atsize",
[
(2, np.array([[0.2, 0.8], [0.9, 0.1]]), 55),
(3, np.array([[0.2, 0.7, 0.1], [0.9, 0.05, 0.05]]), 37),
],
)
def test_get(self, ncl, prevs, atsize, monkeypatch):
class MockLabelledCollection:
def __init__(self):
self.n_classes = ncl
def __len__(self):
return 100
def mockinit(self):
self.prevs = prevs
self.all_train = MockLabelledCollection()
def mock_build_sample(self, p, at_size):
return at_size
monkeypatch.setattr(Dataset, "__init__", mockinit)
monkeypatch.setattr(Dataset, "_Dataset__build_sample", mock_build_sample)
d = Dataset()
_get = d.get()
assert all(s == atsize for s in _get)

95
tests/test_error.py Normal file
View File

@ -0,0 +1,95 @@
import numpy as np
import pytest
from quacc import error
from quacc.data import ExtendedPrev, ExtensionPolicy
@pytest.mark.err
class TestError:
@pytest.mark.parametrize(
"prev,result",
[
(np.array([[1, 4], [4, 4]]), 0.5),
(np.array([[6, 2, 4], [2, 4, 2], [4, 2, 6]]), 0.5),
],
)
def test_f1(self, prev, result):
ep = ExtendedPrev(prev.flatten(), prev.shape[0], extpol=ExtensionPolicy())
assert error.f1(prev) == result
assert error.f1(ep) == result
@pytest.mark.parametrize(
"prev,result",
[
(np.array([[4, 4], [4, 4]]), 0.5),
(np.array([[2, 4, 2], [2, 2, 4], [4, 2, 2]]), 0.25),
],
)
def test_acc(self, prev, result):
ep = ExtendedPrev(prev.flatten(), prev.shape[0], extpol=ExtensionPolicy())
assert error.acc(prev) == result
assert error.acc(ep) == result
@pytest.mark.parametrize(
"true_prev,estim_prev,nbcl,extpol,result",
[
(
[
np.array([0.2, 0.4, 0.1, 0.3]),
np.array([0.1, 0.5, 0.1, 0.3]),
],
[
np.array([0.3, 0.4, 0.2, 0.1]),
np.array([0.5, 0.3, 0.1, 0.1]),
],
2,
ExtensionPolicy(),
np.array([0.1, 0.2]),
),
(
[
np.array([0.2, 0.4, 0.4]),
np.array([0.1, 0.5, 0.4]),
],
[
np.array([0.3, 0.4, 0.3]),
np.array([0.5, 0.3, 0.2]),
],
2,
ExtensionPolicy(collapse_false=True),
np.array([0.1, 0.2]),
),
(
[
np.array([0.02, 0.04, 0.16, 0.38, 0.1, 0.05, 0.15, 0.08, 0.02]),
np.array([0.04, 0.02, 0.14, 0.40, 0.1, 0.03, 0.17, 0.07, 0.03]),
],
[
np.array([0.02, 0.04, 0.16, 0.48, 0.0, 0.05, 0.15, 0.08, 0.02]),
np.array([0.14, 0.02, 0.04, 0.30, 0.2, 0.03, 0.17, 0.07, 0.03]),
],
3,
ExtensionPolicy(),
np.array([0.1, 0.2]),
),
(
[
np.array([0.2, 0.4, 0.2, 0.2]),
np.array([0.1, 0.3, 0.2, 0.4]),
],
[
np.array([0.3, 0.3, 0.1, 0.3]),
np.array([0.5, 0.2, 0.1, 0.2]),
],
3,
ExtensionPolicy(collapse_false=True),
np.array([0.1, 0.2]),
),
],
)
def test_accd(self, true_prev, estim_prev, nbcl, extpol, result):
true_prev = [ExtendedPrev(tp, nbcl, extpol=extpol) for tp in true_prev]
estim_prev = [ExtendedPrev(ep, nbcl, extpol=extpol) for ep in estim_prev]
_err = error.accd(true_prev, estim_prev)
assert (np.abs(_err - result) < 1e-15).all()

View File

@ -0,0 +1,425 @@
import numpy as np
import pytest
from quacc.evaluation.report import (
CompReport,
DatasetReport,
EvaluationReport,
_get_shift,
)
@pytest.fixture
def empty_er():
return EvaluationReport("empty")
@pytest.fixture
def er_list():
er1 = EvaluationReport("er1")
er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1))
er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4))
er1.append_row(np.array([0.3, 0.7]), **dict(acc=0.7, acc_score=0.3))
er2 = EvaluationReport("er2")
er2.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1))
er2.append_row(
np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4, f1=0.9, f1_score=0.6)
)
er2.append_row(np.array([0.4, 0.6]), **dict(acc=0.7, acc_score=0.3))
return [er1, er2]
@pytest.fixture
def er_list2():
er1 = EvaluationReport("er12")
er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1))
er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4))
er1.append_row(np.array([0.3, 0.7]), **dict(acc=0.7, acc_score=0.3))
er2 = EvaluationReport("er2")
er2.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1))
er2.append_row(
np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4, f1=0.9, f1_score=0.6)
)
er2.append_row(np.array([0.4, 0.6]), **dict(acc=0.8, acc_score=0.3))
return [er1, er2]
@pytest.fixture
def er_list3():
er1 = EvaluationReport("er31")
er1.append_row(np.array([0.2, 0.5, 0.3]), **dict(acc=0.9, acc_score=0.1))
er1.append_row(np.array([0.2, 0.4, 0.4]), **dict(acc=0.6, acc_score=0.4))
er1.append_row(np.array([0.3, 0.6, 0.1]), **dict(acc=0.7, acc_score=0.3))
er2 = EvaluationReport("er32")
er2.append_row(np.array([0.2, 0.5, 0.3]), **dict(acc=0.9, acc_score=0.1))
er2.append_row(
np.array([0.2, 0.5, 0.3]), **dict(acc=0.6, acc_score=0.4, f1=0.9, f1_score=0.6)
)
er2.append_row(np.array([0.3, 0.3, 0.4]), **dict(acc=0.8, acc_score=0.3))
return [er1, er2]
@pytest.fixture
def cr_1(er_list):
return CompReport(
er_list,
"cr_1",
train_prev=np.array([0.2, 0.8]),
valid_prev=np.array([0.25, 0.75]),
g_time=0.0,
)
@pytest.fixture
def cr_2(er_list2):
return CompReport(
er_list2,
"cr_2",
train_prev=np.array([0.3, 0.7]),
valid_prev=np.array([0.35, 0.65]),
g_time=0.0,
)
@pytest.fixture
def cr_3(er_list3):
return CompReport(
er_list3,
"cr_3",
train_prev=np.array([0.4, 0.1, 0.5]),
valid_prev=np.array([0.45, 0.25, 0.2]),
g_time=0.0,
)
@pytest.fixture
def cr_4(er_list3):
return CompReport(
er_list3,
"cr_4",
train_prev=np.array([0.5, 0.1, 0.4]),
valid_prev=np.array([0.45, 0.25, 0.2]),
g_time=0.0,
)
@pytest.fixture
def dr_1(cr_1, cr_2):
return DatasetReport("dr_1", [cr_1, cr_2])
@pytest.fixture
def dr_2(cr_3, cr_4):
return DatasetReport("dr_2", [cr_3, cr_4])
@pytest.mark.rep
@pytest.mark.mrep
class TestReport:
@pytest.mark.parametrize(
"cr_name,train_prev,shift",
[
(
"cr_1",
np.array([0.2, 0.8]),
np.array([0.2, 0.1, 0.0, 0.0]),
),
(
"cr_3",
np.array([0.2, 0.5, 0.3]),
np.array([0.2, 0.2, 0.0, 0.0, 0.1]),
),
],
)
def test_get_shift(self, cr_name, train_prev, shift, request):
cr = request.getfixturevalue(cr_name)
assert (
_get_shift(cr._data.index.get_level_values(0), train_prev) == shift
).all()
@pytest.mark.rep
@pytest.mark.erep
class TestEvaluationReport:
def test_init(self, empty_er):
assert empty_er.data is None
@pytest.mark.parametrize(
"rows,index,columns,data",
[
(
[
(np.array([0.2, 0.8]), dict(acc=0.9, acc_score=0.1)),
(np.array([0.2, 0.8]), dict(acc=0.6, acc_score=0.4)),
(np.array([0.3, 0.7]), dict(acc=0.7, acc_score=0.3)),
],
[((0.2, 0.8), 0), ((0.2, 0.8), 1), ((0.3, 0.7), 0)],
["acc", "acc_score"],
np.array([[0.9, 0.1], [0.6, 0.4], [0.7, 0.3]]),
),
],
)
def test_append_row(self, empty_er, rows, index, columns, data):
er: EvaluationReport = empty_er
for prev, r in rows:
er.append_row(prev, **r)
assert er.data.index.to_list() == index
assert er.data.columns.to_list() == columns
assert (er.data.to_numpy() == data).all()
@pytest.mark.rep
@pytest.mark.crep
class TestCompReport:
@pytest.mark.parametrize(
"train_prev,valid_prev,index,columns",
[
(
np.array([0.2, 0.8]),
np.array([0.25, 0.75]),
[
((0.4, 0.6), 0),
((0.3, 0.7), 0),
((0.2, 0.8), 0),
((0.2, 0.8), 1),
],
[
("acc", "er1"),
("acc", "er2"),
("acc_score", "er1"),
("acc_score", "er2"),
("f1", "er2"),
("f1_score", "er2"),
],
)
],
)
def test_init(self, er_list, train_prev, valid_prev, index, columns):
cr = CompReport(er_list, "cr", train_prev, valid_prev, g_time=0.0)
assert cr._data.index.to_list() == index
assert cr._data.columns.to_list() == columns
assert (cr.train_prev == train_prev).all()
assert (cr.valid_prev == valid_prev).all()
@pytest.mark.parametrize(
"cr_name,prev",
[
("cr_1", [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_2", [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
(
"cr_3",
[(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)],
),
(
"cr_4",
[(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)],
),
],
)
def test_prevs(self, cr_name, prev, request):
cr = request.getfixturevalue(cr_name)
assert cr.prevs.tolist() == prev
def test_join(self, er_list, er_list2):
tp = np.array([0.2, 0.8])
vp = np.array([0.25, 0.75])
cr1 = CompReport(er_list, "cr1", train_prev=tp, valid_prev=vp)
cr2 = CompReport(er_list2, "cr2", train_prev=tp, valid_prev=vp)
crj = cr1.join(cr2)
_loc = crj._data.loc[((0.4, 0.6), 0), ("acc", "er2")].to_numpy()
assert (_loc == np.array([0.8])).all()
@pytest.mark.parametrize(
"cr_name,metric,estimators,columns",
[
("cr_1", "acc", None, ["er1", "er2"]),
("cr_1", "acc", ["er1"], ["er1"]),
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"]),
("cr_1", "f1", None, ["er2"]),
("cr_1", "f1", ["er2"], ["er2"]),
("cr_3", "acc", None, ["er31", "er32"]),
("cr_3", "acc", ["er31"], ["er31"]),
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"]),
("cr_3", "f1", None, ["er32"]),
("cr_3", "f1", ["er32"], ["er32"]),
],
)
def test_data(self, cr_name, metric, estimators, columns, request):
cr = request.getfixturevalue(cr_name)
_data = cr.data(metric=metric, estimators=estimators)
assert _data.columns.to_list() == columns
assert all(_data.index == cr._data.index)
# fmt: off
@pytest.mark.parametrize(
"cr_name,metric,estimators,columns,index",
[
("cr_1", "acc", None, ["er1", "er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]),
("cr_1", "acc", ["er1"], ["er1"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]),
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]),
("cr_1", "f1", None, ["er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]),
("cr_1", "f1", ["er2"], ["er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]),
("cr_3", "acc", None, ["er31", "er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]),
("cr_3", "acc", ["er31"], ["er31"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]),
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]),
("cr_3", "f1", None, ["er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]),
("cr_3", "f1", ["er32"], ["er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]),
],
)
def test_shift_data(self, cr_name, metric, estimators, columns, index, request):
cr = request.getfixturevalue(cr_name)
_data = cr.shift_data(metric=metric, estimators=estimators)
assert _data.columns.to_list() == columns
assert _data.index.to_list() == index
@pytest.mark.parametrize(
"cr_name,metric,estimators,columns,index",
[
("cr_1", "acc", None, ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_1", "acc", ["er1"], ["er1"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_1", "f1", None, ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_1", "f1", ["er2"], ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_3", "acc", None, ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
("cr_3", "acc", ["er31"], ["er31"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
("cr_3", "f1", None, ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
("cr_3", "f1", ["er32"], ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
],
)
def test_avg_by_prevs(self, cr_name, metric, estimators, columns, index, request):
cr = request.getfixturevalue(cr_name)
_data = cr.avg_by_prevs(metric=metric, estimators=estimators)
assert _data.columns.to_list() == columns
assert _data.index.to_list() == index
@pytest.mark.parametrize(
"cr_name,metric,estimators,columns,index",
[
("cr_1", "acc", None, ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_1", "acc", ["er1"], ["er1"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_1", "f1", None, ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_1", "f1", ["er2"], ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
("cr_3", "acc", None, ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
("cr_3", "acc", ["er31"], ["er31"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
("cr_3", "f1", None, ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
("cr_3", "f1", ["er32"], ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
],
)
def test_stdev_by_prevs(self, cr_name, metric, estimators, columns, index, request):
cr = request.getfixturevalue(cr_name)
_data = cr.stdev_by_prevs(metric=metric, estimators=estimators)
assert _data.columns.to_list() == columns
assert _data.index.to_list() == index
@pytest.mark.parametrize(
"cr_name,metric,estimators,columns,index",
[
("cr_1", "acc", None, ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]),
("cr_1", "acc", ["er1"], ["er1"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]),
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]),
("cr_1", "f1", None, ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]),
("cr_1", "f1", ["er2"], ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]),
("cr_3", "acc", None, ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]),
("cr_3", "acc", ["er31"], ["er31"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]),
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]),
("cr_3", "f1", None, ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]),
("cr_3", "f1", ["er32"], ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]),
],
)
def test_train_table(self, cr_name, metric, estimators, columns, index, request):
cr = request.getfixturevalue(cr_name)
_data = cr.train_table(metric=metric, estimators=estimators)
assert _data.columns.to_list() == columns
assert _data.index.to_list() == index
@pytest.mark.parametrize(
"cr_name,metric,estimators,columns,index",
[
("cr_1", "acc", None, ["er1", "er2"], [0.0, 0.1, 0.2, "mean"]),
("cr_1", "acc", ["er1"], ["er1"], [0.0, 0.1, 0.2, "mean"]),
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [0.0, 0.1, 0.2, "mean"]),
("cr_1", "f1", None, ["er2"], [0.0, 0.1, 0.2, "mean"]),
("cr_1", "f1", ["er2"], ["er2"], [0.0, 0.1, 0.2, "mean"]),
("cr_3", "acc", None, ["er31", "er32"], [0.2, 0.3, 0.4, 0.5, "mean"]),
("cr_3", "acc", ["er31"], ["er31"], [0.2, 0.3, 0.4, 0.5, "mean"]),
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [0.2, 0.3, 0.4, 0.5, "mean"]),
("cr_3", "f1", None, ["er32"], [0.2, 0.3, 0.4, 0.5, "mean"]),
("cr_3", "f1", ["er32"], ["er32"], [0.2, 0.3, 0.4, 0.5, "mean"]),
],
)
def test_shift_table(self, cr_name, metric, estimators, columns, index, request):
cr = request.getfixturevalue(cr_name)
_data = cr.shift_table(metric=metric, estimators=estimators)
assert _data.columns.to_list() == columns
assert _data.index.to_list() == index
# fmt: on
@pytest.mark.rep
@pytest.mark.drep
class TestDatasetReport:
# fmt: off
@pytest.mark.parametrize(
"dr_name,metric,estimators,columns,index",
[
(
"dr_1", "acc", None, ["er1", "er2", "er12"],
[
((0.3, 0.7), (0.4, 0.6), 0),
((0.3, 0.7), (0.3, 0.7), 0),
((0.3, 0.7), (0.2, 0.8), 0),
((0.3, 0.7), (0.2, 0.8), 1),
((0.2, 0.8), (0.4, 0.6), 0),
((0.2, 0.8), (0.3, 0.7), 0),
((0.2, 0.8), (0.2, 0.8), 0),
((0.2, 0.8), (0.2, 0.8), 1),
],
),
(
"dr_2", "acc", None, ["er31", "er32"],
[
((0.5, 0.1, 0.4), (0.3, 0.6, 0.1), 0),
((0.5, 0.1, 0.4), (0.3, 0.3, 0.4), 0),
((0.5, 0.1, 0.4), (0.2, 0.5, 0.3), 0),
((0.5, 0.1, 0.4), (0.2, 0.5, 0.3), 1),
((0.5, 0.1, 0.4), (0.2, 0.4, 0.4), 0),
((0.4, 0.1, 0.5), (0.3, 0.6, 0.1), 0),
((0.4, 0.1, 0.5), (0.3, 0.3, 0.4), 0),
((0.4, 0.1, 0.5), (0.2, 0.5, 0.3), 0),
((0.4, 0.1, 0.5), (0.2, 0.5, 0.3), 1),
((0.4, 0.1, 0.5), (0.2, 0.4, 0.4), 0),
],
),
],
)
def test_data(self, dr_name, metric, estimators, columns, index, request):
dr = request.getfixturevalue(dr_name)
_data = dr.data(metric=metric, estimators=estimators)
assert _data.columns.to_list() == columns
assert _data.index.to_list() == index
@pytest.mark.parametrize(
"dr_name,metric,estimators,columns,index",
[
(
"dr_1", "acc", None, ["er1", "er2", "er12"],
[(0.0, 0), (0.0, 1), (0.0, 2), (0.1, 0),
(0.1, 1), (0.1, 2), (0.1, 3), (0.2, 0)],
),
(
"dr_2", "acc", None, ["er31", "er32"],
[(0.2, 0), (0.2, 1), (0.3, 0), (0.3, 1), (0.4, 0),
(0.4, 1), (0.4, 2), (0.4, 3), (0.5, 0), (0.5, 1)],
),
],
)
def test_shift_data(self, dr_name, metric, estimators, columns, index, request):
dr = request.getfixturevalue(dr_name)
_data = dr.shift_data(metric=metric, estimators=estimators)
print(_data.index.tolist())
assert _data.columns.to_list() == columns
assert _data.index.to_list() == index
# fmt: off

View File

@ -0,0 +1,100 @@
import numpy as np
import pytest
import scipy.sparse as sp
from quacc.data import ExtendedData, ExtensionPolicy
from quacc.method.base import MultiClassAccuracyEstimator
@pytest.mark.mcae
class TestMultiClassAccuracyEstimator:
@pytest.mark.parametrize(
"instances,pred_proba,extpol,result",
[
(
np.arange(12).reshape((4, 3)),
np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]),
ExtensionPolicy(),
np.array([0.21, 0.39, 0.1, 0.4]),
),
(
np.arange(12).reshape((4, 3)),
np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]),
ExtensionPolicy(collapse_false=True),
np.array([0.21, 0.39, 0.5]),
),
(
sp.csr_matrix(np.arange(12).reshape((4, 3))),
np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]),
ExtensionPolicy(),
np.array([0.21, 0.39, 0.1, 0.4]),
),
(
np.arange(12).reshape((4, 3)),
np.array(
[
[0.3, 0.2, 0.5],
[0.13, 0.67, 0.2],
[0.21, 0.09, 0.8],
[0.19, 0.1, 0.71],
]
),
ExtensionPolicy(),
np.array([0.21, 0.09, 0.1, 0.04, 0.06, 0.11, 0.11, 0.18, 0.1]),
),
(
np.arange(12).reshape((4, 3)),
np.array(
[
[0.3, 0.2, 0.5],
[0.13, 0.67, 0.2],
[0.21, 0.09, 0.8],
[0.19, 0.1, 0.71],
]
),
ExtensionPolicy(collapse_false=True),
np.array([0.21, 0.09, 0.1, 0.7]),
),
(
sp.csr_matrix(np.arange(12).reshape((4, 3))),
np.array(
[
[0.3, 0.2, 0.5],
[0.13, 0.67, 0.2],
[0.21, 0.09, 0.8],
[0.19, 0.1, 0.71],
]
),
ExtensionPolicy(),
np.array([0.21, 0.09, 0.1, 0.04, 0.06, 0.11, 0.11, 0.18, 0.1]),
),
],
)
def test_estimate(self, monkeypatch, instances, pred_proba, extpol, result):
ed = ExtendedData(instances, pred_proba, pred_proba, extpol)
class MockQuantifier:
def __init__(self):
self.classes_ = np.arange(result.shape[0])
def quantify(self, X):
return result
def mockinit(self):
self.extpol = extpol
self.quantifier = MockQuantifier()
def mock_extend_instances(self, instances):
return ed
monkeypatch.setattr(MultiClassAccuracyEstimator, "__init__", mockinit)
monkeypatch.setattr(
MultiClassAccuracyEstimator, "_extend_instances", mock_extend_instances
)
mcae = MultiClassAccuracyEstimator()
ep1 = mcae.estimate(instances)
ep2 = mcae.estimate(ed)
assert (ep1.flat == ep2.flat).all()
assert (ep1.flat == result).all()

View File

@ -1,66 +0,0 @@
import numpy as np
import pytest
import scipy.sparse as sp
from sklearn.linear_model import LogisticRegression
from quacc.method.base import BinaryQuantifierAccuracyEstimator
class TestBQAE:
@pytest.mark.parametrize(
"instances,preds0,preds1,result",
[
(
np.asarray(
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
),
np.asarray([0.3, 0.7]),
np.asarray([0.4, 0.6]),
np.asarray([0.15, 0.2, 0.35, 0.3]),
),
(
sp.csr_matrix(
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
),
np.asarray([0.3, 0.7]),
np.asarray([0.4, 0.6]),
np.asarray([0.15, 0.2, 0.35, 0.3]),
),
(
np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
np.asarray([0.3, 0.7]),
np.asarray([0.4, 0.6]),
np.asarray([0.0, 0.4, 0.0, 0.6]),
),
(
sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
np.asarray([0.3, 0.7]),
np.asarray([0.4, 0.6]),
np.asarray([0.0, 0.4, 0.0, 0.6]),
),
(
np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
np.asarray([0.3, 0.7]),
np.asarray([0.4, 0.6]),
np.asarray([0.3, 0.0, 0.7, 0.0]),
),
(
sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
np.asarray([0.3, 0.7]),
np.asarray([0.4, 0.6]),
np.asarray([0.3, 0.0, 0.7, 0.0]),
),
],
)
def test_estimate_ndarray(self, mocker, instances, preds0, preds1, result):
estimator = BinaryQuantifierAccuracyEstimator(LogisticRegression())
estimator.n_classes = 4
with mocker.patch.object(estimator.q_model_0, "quantify"), mocker.patch.object(
estimator.q_model_1, "quantify"
):
estimator.q_model_0.quantify.return_value = preds0
estimator.q_model_1.quantify.return_value = preds1
assert np.array_equal(
estimator.estimate(instances, ext=True),
result,
)

View File

@ -1,2 +0,0 @@
class TestMCAE:
pass