QuAcc/tests/test_baseline.py

20 lines
622 B
Python
Raw Normal View History

2023-09-13 00:11:20 +02:00
from sklearn.linear_model import LogisticRegression
2023-09-16 01:59:49 +02:00
from quacc.baseline import kfcv, trust_score
2023-09-14 01:52:19 +02:00
from quacc.dataset import get_spambase
2023-09-13 00:11:20 +02:00
class TestBaseline:
    """Smoke tests for the baseline estimators in ``quacc.baseline``."""

    def test_kfcv(self):
        """``kfcv`` on a fitted classifier reports an ``f1_score`` entry."""
        train, validation, _ = get_spambase()
        c_model = LogisticRegression()
        c_model.fit(train.X, train.y)
        assert "f1_score" in kfcv(c_model, validation)

    def test_trust_score(self):
        """``trust_score`` yields one score per example of the test split."""
        # The validation split is not used here; discard it, mirroring how
        # test_kfcv discards the test split.
        train, _, test = get_spambase()
        c_model = LogisticRegression()
        c_model.fit(train.X, train.y)
        trustscore = trust_score(c_model, train, test)
        assert len(trustscore) == len(test.y)