kfcv baseline implemented

This commit is contained in:
Lorenzo Volpi 2023-09-13 00:11:20 +02:00
parent 1297902895
commit ce04bf9420
6 changed files with 35 additions and 35 deletions

BIN
.coverage

Binary file not shown.

14
quacc/baseline.py Normal file
View File

@ -0,0 +1,14 @@
from statistics import mean
from typing import Dict
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_validate
from quapy.data import LabelledCollection
def kfcv(c_model: BaseEstimator, train: LabelledCollection) -> Dict:
scoring = ["f1_macro"]
scores = cross_validate(c_model, train.X, train.y, scoring=scoring)
return {
"f1_score": mean(scores["test_f1_macro"])
}

View File

@ -1,4 +1,7 @@
import quapy as qp
def getImdbTrainTest():
return qp.datasets.fetch_reviews("imdb", tfidf=True).train_test
def get_imdb_traintest():
return qp.datasets.fetch_reviews("imdb", tfidf=True).train_test
def get_spambase_traintest():
return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test

View File

@ -9,7 +9,7 @@ from quacc.estimator import (
MulticlassAccuracyEstimator,
)
from quacc.dataset import getImdbTrainTest
from quacc.dataset import get_imdb_traintest
qp.environ["SAMPLE_SIZE"] = 100
@ -20,7 +20,7 @@ dataset_name = "imdb"
def estimate_multiclass():
print(dataset_name)
train, test = getImdbTrainTest(dataset_name)
train, test = get_imdb_traintest(dataset_name)
model = LogisticRegression()
@ -59,7 +59,7 @@ def estimate_multiclass():
def estimate_binary():
print(dataset_name)
train, test = getImdbTrainTest(dataset_name)
train, test = get_imdb_traintest(dataset_name)
model = LogisticRegression()

12
tests/test_baseline.py Normal file
View File

@ -0,0 +1,12 @@
from sklearn.linear_model import LogisticRegression
from quacc.baseline import kfcv
from quacc.dataset import get_spambase_traintest
class TestBaseline:
def test_kfcv(self):
train, _ = get_spambase_traintest()
c_model = LogisticRegression()
assert "f1_score" in kfcv(c_model, train)

View File

@ -1,32 +1,3 @@
import pytest
from quacc.dataset import Rcv1Helper
@pytest.fixture
def rcv1_helper() -> Rcv1Helper:
return Rcv1Helper()
class TestDataset:
def test_rcv1_binary_datasets(self, rcv1_helper):
count = 0
for X, Y, name in rcv1_helper.rcv1_binary_datasets():
count += 1
print(X.shape)
assert X.shape == (517978, 47236)
assert Y.shape == (517978,)
assert count == 37
@pytest.mark.parametrize("label", ["CCAT", "GCAT", "M11"])
def test_rcv1_binary_dataset_by_label(self, rcv1_helper, label):
train, test = rcv1_helper.rcv1_binary_dataset_by_label(label)
assert train.X.shape == (23149, 47236)
assert train.y.shape == (23149,)
assert test.X.shape == (781265, 47236)
assert test.y.shape == (781265,)
assert (
dict(rcv1_helper.documents_per_class_rcv1())[label]
== train.y.sum() + test.y.sum()
)
pass