From ce04bf9420a9676adb362f403f0ebd0cf839b604 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Wed, 13 Sep 2023 00:11:20 +0200 Subject: [PATCH] kfcv baseline implemented --- .coverage | Bin 53248 -> 53248 bytes quacc/baseline.py | 14 ++++++++++++++ quacc/dataset.py | 7 +++++-- quacc/main.py | 6 +++--- tests/test_baseline.py | 12 ++++++++++++ tests/test_dataset.py | 31 +------------------------------ 6 files changed, 35 insertions(+), 35 deletions(-) create mode 100644 quacc/baseline.py create mode 100644 tests/test_baseline.py diff --git a/.coverage b/.coverage index 9f20b3c767b3278133b1ef59663705c543bce9fe..7ce4189d0f7865664c1e56671a5955c6966ec45e 100644 GIT binary patch delta 410 zcmYk&KQDt(6bA5n-}|1n7rjpnY7H0NHrCSegYlt2VgT; z(5bV55J5WBA{7G$l=zc~fmQ?QlWi_mP zx!S;^pCVGQq+WAcb;7CPKE3MnaIfBV`ngBnJ7JpyOX;%Hrd_w1x^2Pcb=>XHqwWAj ztly`B$%&=ux!Hy3&c$?u|J{3YVPwg0h%A~hs`k{KN~xHVMlSnxY`m9bkC^914;gu9 zyq~;bk>G~wGI~?IkLYLr1I7pKUXiSnTmO4jL?mS;%)Q0Ys0ew64?N)xXUH393uAR; zbJYk6vc0HzqCow?^Y>-i`2~)5y5KJDedB9x+|!m?NBgAOlqwzeFQNN(7u6% zynxU_!NE=HAVr4?{s9FQ#KkWWMW@R-_i)ay(ZWUx+e}ZWJO^}2N3=poe$S8jKF|3> z@f4S{s&`nJjN6-~b<3=5O>Ng_q6n5ZXrw&bP+0__9*E)cuDMw?-HIJ#G75qnW_pIk z7(y)&LuJRUSf*?5h{zlwIDtzg(>0y48&u0b)rCS-&Pu_w_f{ammj^l{k5>5ujq^4? z;CaE9=AO1Tuc#Tcq;SnPqDpiT>!JFO6q!Z)JJd5&MOnaFsQxjrEIOh4hY5en1(PCY zv8n}A{@Xew2N-pLKIn$7s3UlMT23$~X=(O;K3!Bn);Hpy1LzSHF2~r2Ebf%N4~?1P z5&%2{5bVL2A~8+lQcPHk`aqxbN>9`mm6zo8nu`C>98D%v>IScOdZ9p`bzyJUM diff --git a/quacc/baseline.py b/quacc/baseline.py new file mode 100644 index 0000000..e456fcb --- /dev/null +++ b/quacc/baseline.py @@ -0,0 +1,14 @@ + +from statistics import mean +from typing import Dict +from sklearn.base import BaseEstimator +from sklearn.model_selection import cross_validate +from quapy.data import LabelledCollection + + +def kfcv(c_model: BaseEstimator, train: LabelledCollection) -> Dict: + scoring = ["f1_macro"] + scores = cross_validate(c_model, train.X, train.y, scoring=scoring) + return { + "f1_score": mean(scores["test_f1_macro"]) + } diff --git a/quacc/dataset.py b/quacc/dataset.py index eed7384..36d1485 100644 --- a/quacc/dataset.py +++ b/quacc/dataset.py @@ -1,4 +1,7 @@ import quapy as qp -def getImdbTrainTest(): - return qp.datasets.fetch_reviews("imdb", tfidf=True).train_test \ No newline at end of file +def get_imdb_traintest(): + return qp.datasets.fetch_reviews("imdb", tfidf=True).train_test + +def get_spambase_traintest(): + return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test \ No newline at end of file diff --git a/quacc/main.py b/quacc/main.py index d28b9ae..ca63d46 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -9,7 +9,7 @@ from quacc.estimator import ( MulticlassAccuracyEstimator, ) -from quacc.dataset import getImdbTrainTest +from quacc.dataset import get_imdb_traintest qp.environ["SAMPLE_SIZE"] = 100 @@ -20,7 +20,7 @@ dataset_name = "imdb" def estimate_multiclass(): print(dataset_name) - train, test = getImdbTrainTest(dataset_name) + train, test = get_imdb_traintest(dataset_name) model = LogisticRegression() @@ -59,7 +59,7 @@ def estimate_multiclass(): def estimate_binary(): print(dataset_name) - train, test = getImdbTrainTest(dataset_name) + train, test = get_imdb_traintest(dataset_name) model = LogisticRegression() diff --git a/tests/test_baseline.py b/tests/test_baseline.py new file mode 100644 index 0000000..82b0218 --- /dev/null +++ b/tests/test_baseline.py @@ -0,0 +1,12 @@ + +from sklearn.linear_model import LogisticRegression +from quacc.baseline import kfcv +from quacc.dataset import get_spambase_traintest + + +class TestBaseline: + + def test_kfcv(self): + train, _ = get_spambase_traintest() + c_model = LogisticRegression() + assert "f1_score" in kfcv(c_model, train) \ No newline at end of file diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4a77368..b3ffda5 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,32 +1,3 @@ -import pytest -from quacc.dataset import Rcv1Helper - - -@pytest.fixture -def rcv1_helper() -> Rcv1Helper: - return Rcv1Helper() - class TestDataset: - def test_rcv1_binary_datasets(self, rcv1_helper): - count = 0 - for X, Y, name in rcv1_helper.rcv1_binary_datasets(): - count += 1 - print(X.shape) - assert X.shape == (517978, 47236) - assert Y.shape == (517978,) - - assert count == 37 - - @pytest.mark.parametrize("label", ["CCAT", "GCAT", "M11"]) - def test_rcv1_binary_dataset_by_label(self, rcv1_helper, label): - train, test = rcv1_helper.rcv1_binary_dataset_by_label(label) - assert train.X.shape == (23149, 47236) - assert train.y.shape == (23149,) - assert test.X.shape == (781265, 47236) - assert test.y.shape == (781265,) - - assert ( - dict(rcv1_helper.documents_per_class_rcv1())[label] - == train.y.sum() + test.y.sum() - ) + pass \ No newline at end of file