From c01ac0915cf092cd7b8f809a25a03d544c75932d Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Tue, 12 Sep 2023 17:38:49 +0200 Subject: [PATCH] Dataset updated, evaluation updated, tests updated --- .coverage | Bin 53248 -> 53248 bytes quacc/data.py | 15 ------ quacc/dataset.py | 4 ++ quacc/evaluation.py | 2 +- rcv1_hierarchy | 104 ++++++++++++++++++++++++++++++++++++++++++ tests/test_dataset.py | 32 +++++++++++++ 6 files changed, 141 insertions(+), 16 deletions(-) create mode 100644 quacc/dataset.py create mode 100644 rcv1_hierarchy create mode 100644 tests/test_dataset.py diff --git a/.coverage b/.coverage index fc13d34ba7bf763d12fd54a9ce2ef06ddec35d23..9f20b3c767b3278133b1ef59663705c543bce9fe 100644 GIT binary patch delta 410 zcmY+7!Ab&A6o$`u?u;{1_h?A%!f2%n39__SqnlPat_vy=7eY4CXr@B=214@y`34E| z0)mTz7VW|og0(QHgn|&FP1muFtHXcJ_n)uP!bS_*Ob@C&33N+Wv_~QS#&39yC!K&l z%sp|{uBQ{si%_&L(_wp%R+)lOorHz_dG@%REfvg@=xH@3BbY9eF*L>yDw7b%TctuV zTQbiCvWWyv^%8CFBuq*A5c08pERGb_Qv$?A^wy zn`Yh02AWeSW2mycg%vl=$XOh%Df2IPnuq1k9HTDK4?WWpb%oWT-5_I<7GdN2wVVnb z{m2hGfUcm}a`Y{F#Eg1s&{)sy0Kg{z!38YK64Nv;`NWFR5a^e_X+%Tu@{zoX=tzt? hqwA$mP^I4FJf<)DpjUdKKJ^6FAsJPKdAt9x`UgM{awPx& delta 271 zcmWm6!Ab&A7{&4L`ps<`i+fChT7~5*R~AYoEry`&oo#J2FlZHm5TOLk6SSE70Qv~l zEZb+&q9+)PHnwo6AcShuZTml*U)blc&jZKLyQ)JNBEk_i)I{B>PP~+BD!Jn_U6^aV zX Rcv1Helper: + return Rcv1Helper() + + +class TestDataset: + def test_rcv1_binary_datasets(self, rcv1_helper): + count = 0 + for X, Y, name in rcv1_helper.rcv1_binary_datasets(): + count += 1 + print(X.shape) + assert X.shape == (517978, 47236) + assert Y.shape == (517978,) + + assert count == 37 + + @pytest.mark.parametrize("label", ["CCAT", "GCAT", "M11"]) + def test_rcv1_binary_dataset_by_label(self, rcv1_helper, label): + train, test = rcv1_helper.rcv1_binary_dataset_by_label(label) + assert train.X.shape == (23149, 47236) + assert train.y.shape == (23149,) + assert test.X.shape == (781265, 47236) + assert test.y.shape == (781265,) + + assert ( + dict(rcv1_helper.documents_per_class_rcv1())[label] + == train.y.sum() + test.y.sum() + )