2023-07-27 03:16:41 +02:00
|
|
|
import pytest
|
|
|
|
import numpy as np
|
|
|
|
import scipy.sparse as sp
|
|
|
|
from sklearn.linear_model import LogisticRegression
|
|
|
|
|
|
|
|
from quacc.estimator import BinaryQuantifierAccuracyEstimator
|
2023-07-26 00:38:23 +02:00
|
|
|
|
2023-07-27 03:16:41 +02:00
|
|
|
|
|
|
|
class TestBinaryQuantifierAccuracyEstimator:
    """Unit tests for ``BinaryQuantifierAccuracyEstimator.estimate``."""

    @pytest.mark.parametrize(
        "instances,preds0,preds1,result",
        [
            pytest.param(
                np.asarray(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                np.asarray([0.3, 0.7]),
                np.asarray([0.4, 0.6]),
                np.asarray([0.15, 0.2, 0.35, 0.3]),
                id="dense-rows-0123",
            ),
            pytest.param(
                sp.csr_matrix(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                np.asarray([0.3, 0.7]),
                np.asarray([0.4, 0.6]),
                np.asarray([0.15, 0.2, 0.35, 0.3]),
                id="sparse-rows-0123",
            ),
            pytest.param(
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([0.3, 0.7]),
                np.asarray([0.4, 0.6]),
                np.asarray([0.0, 0.4, 0.0, 0.6]),
                id="dense-rows-02",
            ),
            pytest.param(
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([0.3, 0.7]),
                np.asarray([0.4, 0.6]),
                np.asarray([0.0, 0.4, 0.0, 0.6]),
                id="sparse-rows-02",
            ),
            pytest.param(
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0.3, 0.7]),
                np.asarray([0.4, 0.6]),
                np.asarray([0.3, 0.0, 0.7, 0.0]),
                id="dense-rows-13",
            ),
            pytest.param(
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0.3, 0.7]),
                np.asarray([0.4, 0.6]),
                np.asarray([0.3, 0.0, 0.7, 0.0]),
                id="sparse-rows-13",
            ),
        ],
    )
    def test_estimate_ndarray(self, mocker, instances, preds0, preds1, result):
        """``estimate(..., ext=True)`` combines the two quantifiers' outputs.

        Both inner quantifiers are mocked so only the combination logic of
        ``estimate`` is exercised. Each case runs with a dense ``ndarray`` and
        an equivalent ``csr_matrix``.

        In these fixtures, each expected vector interleaves the two mocked
        outputs: ``preds0`` fills the even indices and ``preds1`` the odd
        ones. Each is scaled by the fraction of input rows that goes to that
        model:
        - all four rows: both scaled by 0.5;
        - rows 0 and 2: only ``preds1`` appears;
        - rows 1 and 3: only ``preds0`` appears.
        NOTE(review): the row-to-model routing rule lives in the estimator
        and is not visible here.
        """
        estimator = BinaryQuantifierAccuracyEstimator(LogisticRegression())
        estimator.n_classes = 4

        # The mocker fixture undoes these patches at test teardown. pytest-mock
        # warns when its mocks are used as context managers, so patch directly.
        mocker.patch.object(estimator.q_model_0, "quantify", return_value=preds0)
        mocker.patch.object(estimator.q_model_1, "quantify", return_value=preds1)

        # The expected values are float products (e.g. 0.3 * 0.5), so compare
        # with a tolerance. assert_allclose also reports the mismatching
        # elements and any shape mismatch.
        np.testing.assert_allclose(estimator.estimate(instances, ext=True), result)
|