2023-09-13 00:11:20 +02:00
|
|
|
|
|
|
|
from sklearn.linear_model import LogisticRegression
|
2023-09-16 01:59:49 +02:00
|
|
|
from quacc.baseline import kfcv, trust_score
|
2023-09-14 01:52:19 +02:00
|
|
|
from quacc.dataset import get_spambase
|
2023-09-13 00:11:20 +02:00
|
|
|
|
|
|
|
|
|
|
|
class TestBaseline:
    """Smoke tests for the baseline estimators ``kfcv`` and ``trust_score``."""

    @staticmethod
    def _fit_classifier(train):
        """Return a default ``LogisticRegression`` fitted on ``train.X`` / ``train.y``.

        Shared setup for the tests below; the leading underscore keeps pytest
        from collecting it as a test.
        """
        c_model = LogisticRegression()
        c_model.fit(train.X, train.y)
        return c_model

    def test_kfcv(self):
        """``kfcv`` on the validation split returns a result containing ``"f1_score"``."""
        train, validation, _ = get_spambase()
        c_model = self._fit_classifier(train)
        assert "f1_score" in kfcv(c_model, validation)

    def test_trust_score(self):
        """``trust_score`` returns one value per label of the test split."""
        # The validation split is not needed here, so it is discarded.
        train, _, test = get_spambase()
        c_model = self._fit_classifier(train)
        trustscore = trust_score(c_model, train, test)
        assert len(trustscore) == len(test.y)
|