From 7720b0d20e85586da04305d8847830aed1494e26 Mon Sep 17 00:00:00 2001
From: Lorenzo Volpi
Date: Tue, 26 Sep 2023 07:58:40 +0200
Subject: [PATCH] baseline performance test updated

---
 elsahar19_rca/__pycache__/rca.cpython-311.pyc |  Bin 966 -> 0 bytes
 .../__pycache__/ATC_helper.cpython-311.pyc    |  Bin 1816 -> 0 bytes
 .../__pycache__/doc.cpython-311.pyc           |  Bin 429 -> 0 bytes
 .../__pycache__/trustscore.cpython-311.pyc    |  Bin 7636 -> 0 bytes
 .../__pycache__/labelshift.cpython-311.pyc    |  Bin 4410 -> 0 bytes
 quacc/baseline.py                             |    8 +--
 quacc/dataset.py                              |   42 +++++++----
 quacc/error.py                                |   28 +++++---
 quacc/evaluation.py                           |   66 ++++++++++--------
 quacc/main.py                                 |   33 ++++-----
 quacc/utils.py                                |   31 ++++++++
 11 files changed, 136 insertions(+), 72 deletions(-)
 delete mode 100644 elsahar19_rca/__pycache__/rca.cpython-311.pyc
 delete mode 100644 garg22_ATC/__pycache__/ATC_helper.cpython-311.pyc
 delete mode 100644 guillory21_doc/__pycache__/doc.cpython-311.pyc
 delete mode 100644 jiang18_trustscore/__pycache__/trustscore.cpython-311.pyc
 delete mode 100644 lipton_bbse/__pycache__/labelshift.cpython-311.pyc
 create mode 100644 quacc/utils.py

diff --git a/elsahar19_rca/__pycache__/rca.cpython-311.pyc b/elsahar19_rca/__pycache__/rca.cpython-311.pyc
deleted file mode 100644
index aeb753d2e367711efe91b1b9be946dfd95cda37e..0000000000000000000000000000000000000000
Binary files a/elsahar19_rca/__pycache__/rca.cpython-311.pyc and /dev/null differ
diff --git a/garg22_ATC/__pycache__/ATC_helper.cpython-311.pyc b/garg22_ATC/__pycache__/ATC_helper.cpython-311.pyc
deleted file mode 100644
Binary files a/garg22_ATC/__pycache__/ATC_helper.cpython-311.pyc and /dev/null differ
diff --git a/guillory21_doc/__pycache__/doc.cpython-311.pyc b/guillory21_doc/__pycache__/doc.cpython-311.pyc
deleted file mode 100644
index ea09f92757f3ecd6563b5e3040a66b69d5f78501..0000000000000000000000000000000000000000
Binary files a/guillory21_doc/__pycache__/doc.cpython-311.pyc and /dev/null differ
diff --git a/jiang18_trustscore/__pycache__/trustscore.cpython-311.pyc b/jiang18_trustscore/__pycache__/trustscore.cpython-311.pyc
deleted file mode 100644
index e787657323176157ab9d54e060a83e6186006ced..0000000000000000000000000000000000000000
Binary files a/jiang18_trustscore/__pycache__/trustscore.cpython-311.pyc and /dev/null differ
diff --git a/lipton_bbse/__pycache__/labelshift.cpython-311.pyc b/lipton_bbse/__pycache__/labelshift.cpython-311.pyc
deleted file mode 100644
index 10decf8e17c94b70d637db0d1e4410e07c9317a4..0000000000000000000000000000000000000000
Binary files a/lipton_bbse/__pycache__/labelshift.cpython-311.pyc and /dev/null differ
diff --git a/quacc/baseline.py b/quacc/baseline.py
index b96f1bc..9a53db7 100644
--- a/quacc/baseline.py
+++ b/quacc/baseline.py
@@ -201,13 +201,13 @@ def rca_score(
     ]
     results = []
     for test in protocol():
-        [f_prev, t_prev] = test.prevalence()
-        try:
+        try:
+            [f_prev, t_prev] = test.prevalence()
             test_pred = c_model_predict(test.X)
             c_model2 = rca.clone_fit(c_model, test.X, test_pred)
             c_model2_predict = getattr(c_model2, predict_method)
             val_pred2 = c_model2_predict(validation.X)
-            rca_score = 1.0 - rca.get_score(val_pred1, val_pred2, validation.y)
+            rca_score = rca.get_score(val_pred1, val_pred2, validation.y)
             results.append({k: v for k, v in zip(cols, [f_prev, t_prev, rca_score])})
         except ValueError:
             results.append({k: v for k, v in zip(cols, [f_prev, t_prev, float("nan")])})
@@ -248,7 +248,7 @@ def rca_star_score(
         c_model2 = rca.clone_fit(c_model, test.X, test_pred)
         c_model2_predict = getattr(c_model2, predict_method)
         val2_pred2 = c_model2_predict(validation2.X)
-        rca_star_score = 1.0 - rca.get_score(val2_pred1, val2_pred2, validation2.y)
+        rca_star_score = rca.get_score(val2_pred1, val2_pred2, validation2.y)
         results.append(
             {k: v for k, v in zip(cols, [f_prev, t_prev, rca_star_score])}
         )
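Note (not part of the patch): the hunk above stops inverting rca.get_score and reads the test prevalence inside the guarded block, so a sample on which RCA cannot be refit still produces a row, just with a NaN score. A condensed sketch of that behaviour, with hypothetical stand-ins samples, compute_score and plain string keys in place of the module's cols:

    import math

    def guarded_scores(samples, compute_score):
        # One row per test sample: the sample's prevalence plus the raw
        # (uninverted) score, or NaN when the scorer raises ValueError.
        rows = []
        for sample in samples:
            f_prev, t_prev = sample.prevalence()
            try:
                score = compute_score(sample)
            except ValueError:
                score = math.nan
            rows.append({"base_F": f_prev, "base_T": t_prev, "score": score})
        return rows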
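Note (not part of the patch): with the new signature the caller picks one RCV1 target by name, and qp.environ["SAMPLE_SIZE"] must be set beforehand, as quacc/main.py does. A minimal usage sketch under those assumptions:

    import quapy as qp
    from quacc.dataset import get_rcv1

    qp.environ["SAMPLE_SIZE"] = 100               # read inside get_rcv1
    train, validation, test = get_rcv1("CCAT")    # any name listed in the comment above
    # raises ValueError for an unknown target, or one with too few positive test documents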
diff --git a/quacc/error.py b/quacc/error.py
index 90e5701..dfd19bd 100644
--- a/quacc/error.py
+++ b/quacc/error.py
@@ -8,18 +8,28 @@ def from_name(err_name):
     else:
         return qp.error.from_name(err_name)
 
+# def f1(prev):
+#     # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
+#     if prev[0] == 0 and prev[1] == 0 and prev[2] == 0:
+#         return 1.0
+#     elif prev[0] == 0 and prev[1] > 0 and prev[2] == 0:
+#         return 0.0
+#     elif prev[0] == 0 and prev[1] == 0 and prev[2] > 0:
+#         return float('NaN')
+#     else:
+#         recall = prev[0] / (prev[0] + prev[1])
+#         precision = prev[0] / (prev[0] + prev[2])
+#         return 2 * (precision * recall) / (precision + recall)
+
 def f1(prev):
-    # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
-    if prev[0] == 0 and prev[1] == 0 and prev[2] == 0:
+    den = (2*prev[3]) + prev[1] + prev[2]
+    if den == 0:
         return 1.0
-    elif prev[0] == 0 and prev[1] > 0 and prev[2] == 0:
-        return 0.0
-    elif prev[0] == 0 and prev[1] == 0 and prev[2] > 0:
-        return float('NaN')
     else:
-        recall = prev[0] / (prev[0] + prev[1])
-        precision = prev[0] / (prev[0] + prev[2])
-        return 2 * (precision * recall) / (precision + recall)
+        return (2*prev[3])/den
 
 def f1e(prev):
     return 1 - f1(prev)
+
+def mae(prev):
+    return (prev[1] + prev[2]) / sum(prev)
\ No newline at end of file
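Note (not part of the patch): a small worked check of the rewritten metrics, assuming the prevalence vector is ordered [TN, FP, FN, TP] as in the ("TN", "FP", "FN", "TP") columns built elsewhere in this patch, so that f1 computes 2*TP / (2*TP + FP + FN) and mae the misclassified mass:

    from quacc.error import f1, f1e, mae

    prev = [50, 10, 5, 35]                                        # TN, FP, FN, TP
    assert abs(f1(prev) - 2 * 35 / (2 * 35 + 10 + 5)) < 1e-12     # ~0.8235
    assert f1e(prev) == 1 - f1(prev)
    assert mae(prev) == (10 + 5) / 100                            # 0.15
    assert f1([10, 0, 0, 0]) == 1.0                               # degenerate case handled by den == 0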
diff --git a/quacc/evaluation.py b/quacc/evaluation.py
index b07b4f2..a58f86c 100644
--- a/quacc/evaluation.py
+++ b/quacc/evaluation.py
@@ -80,55 +80,63 @@ def evaluation_report(
     protocol: AbstractStochasticSeededProtocol,
     error_metrics: Iterable[Union[str, Callable]] = "all",
     aggregate: bool = True,
+    prevalence: bool = True,
 ):
     def _report_columns(err_names):
         base_cols = list(itertools.product(["base"], ["F", "T"]))
         prev_cols = list(itertools.product(["true", "estim"], ["TN", "FP", "FN", "TP"]))
         err_cols = list(itertools.product(["errors"], err_names))
-        return base_cols + prev_cols, err_cols
+        return base_cols, prev_cols, err_cols
 
     base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)
 
     if error_metrics == "all":
-        error_metrics = ["ae", "f1"]
+        error_metrics = ["mae", "f1"]
 
     error_funcs = [
         error.from_name(e) if isinstance(e, str) else e for e in error_metrics
    ]
     assert all(hasattr(e, "__call__") for e in error_funcs), "invalid error function"
     error_names = [e.__name__ for e in error_funcs]
-    error_cols = error_names.copy()
-    if "f1" in error_cols:
-        error_cols.remove("f1")
-        error_cols.extend(["f1_true", "f1_estim"])
-    if "f1e" in error_cols:
-        error_cols.remove("f1e")
-        error_cols.extend(["f1e_true", "f1e_estim"])
+    error_cols = []
+    for err in error_names:
+        if err == "mae":
+            error_cols.extend(["mae_estim", "mae_true"])
+        elif err == "f1":
+            error_cols.extend(["f1_estim", "f1_true"])
+        elif err == "f1e":
+            error_cols.extend(["f1e_estim", "f1e_true"])
+        else:
+            error_cols.append(err)
 
     # df_cols = ["base_prev", "true_prev", "estim_prev"] + error_names
-    prev_cols, err_cols = _report_columns(error_cols)
+    base_cols, prev_cols, err_cols = _report_columns(error_cols)
 
     lst = []
     for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
-        series = {
-            k: v
-            for (k, v) in zip(
-                prev_cols, np.concatenate((base_prev, true_prev, estim_prev), axis=0)
-            )
-        }
-        for error_name, error_metric in zip(error_names, error_funcs):
-            if error_name == "f1e":
-                series[("errors", "f1e_true")] = error_metric(true_prev)
-                series[("errors", "f1e_estim")] = error_metric(estim_prev)
-                continue
-            if error_name == "f1":
-                f1_true, f1_estim = error_metric(true_prev), error_metric(estim_prev)
-                series[("errors", "f1_true")] = f1_true
-                series[("errors", "f1_estim")] = f1_estim
-                continue
+        if prevalence:
+            series = {
+                k: v
+                for (k, v) in zip(
+                    base_cols + prev_cols,
+                    np.concatenate((base_prev, true_prev, estim_prev), axis=0),
+                )
+            }
+            df_cols = base_cols + prev_cols + err_cols
+        else:
+            series = {k: v for (k, v) in zip(base_cols, base_prev)}
+            df_cols = base_cols + err_cols
 
-            score = error_metric(true_prev, estim_prev)
-            series[("errors", error_name)] = score
+        for err in error_cols:
+            error_funcs = {
+                "mae_true": lambda: error.mae(true_prev),
+                "mae_estim": lambda: error.mae(estim_prev),
+                "f1_true": lambda: error.f1(true_prev),
+                "f1_estim": lambda: error.f1(estim_prev),
+                "f1e_true": lambda: error.f1e(true_prev),
+                "f1e_estim": lambda: error.f1e(estim_prev),
+            }
+            series[("errors", err)] = error_funcs[err]()
 
         lst.append(series)
 
@@ -136,6 +144,6 @@ def evaluation_report(
 
     df = pd.DataFrame(
         lst,
-        columns=pd.MultiIndex.from_tuples(prev_cols + err_cols),
+        columns=pd.MultiIndex.from_tuples(df_cols),
     )
     return df
diff --git a/quacc/main.py b/quacc/main.py
index 879b3a5..d58a65e 100644
--- a/quacc/main.py
+++ b/quacc/main.py
@@ -2,6 +2,7 @@ import pandas as pd
 import quapy as qp
 from quapy.protocol import APP
 from sklearn.linear_model import LogisticRegression
+from quacc import utils
 
 import quacc.evaluation as eval
 import quacc.baseline as baseline
@@ -10,7 +11,7 @@ from quacc.estimator import (
     MulticlassAccuracyEstimator,
 )
 
-from quacc.dataset import get_imdb, get_spambase
+from quacc.dataset import get_imdb, get_rcv1, get_spambase
 
 qp.environ["SAMPLE_SIZE"] = 100
 
@@ -109,25 +110,21 @@ def estimate_comparison():
     estimator = BinaryQuantifierAccuracyEstimator(model)
     estimator.fit(validation)
-    df = eval.evaluation_report(estimator, protocol)
+    df = eval.evaluation_report(estimator, protocol, prevalence=False)
 
-    df_index = [("base", "F"), ("base", "T")]
+    df = utils.combine_dataframes(
+        baseline.atc_mc(model, validation, protocol),
+        baseline.atc_ne(model, validation, protocol),
+        baseline.doc_feat(model, validation, protocol),
+        baseline.rca_score(model, validation, protocol),
+        baseline.rca_star_score(model, validation, protocol),
+        baseline.bbse_score(model, validation, protocol),
+        df,
+        df_index=[("base", "F"), ("base", "T")]
+    )
 
-    atc_mc_df = baseline.atc_mc(model, validation, protocol)
-    atc_ne_df = baseline.atc_ne(model, validation, protocol)
-    doc_feat_df = baseline.doc_feat(model, validation, protocol)
-    rca_df = baseline.rca_score(model, validation, protocol)
-    rca_star_df = baseline.rca_star_score(model, validation, protocol)
-    bbse_df = baseline.bbse_score(model, validation, protocol)
-
-    df = df.join(atc_mc_df.set_index(df_index), on=df_index)
-    df = df.join(atc_ne_df.set_index(df_index), on=df_index)
-    df = df.join(doc_feat_df.set_index(df_index), on=df_index)
-    df = df.join(rca_df.set_index(df_index), on=df_index)
-    df = df.join(rca_star_df.set_index(df_index), on=df_index)
-    df = df.join(bbse_df.set_index(df_index), on=df_index)
-
-    print(df.to_string())
+    print(df.to_latex(float_format="{:.4f}".format))
+    print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))
 
 
 def main():
     estimate_comparison()
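Note (not part of the patch): the rewritten error loop in the quacc/evaluation.py hunk above selects each report column through a dict of zero-argument lambdas keyed by column name, so only the requested metric is evaluated for each column. The same pattern in isolation, reusing the metrics from quacc/error.py with toy prevalence vectors:

    from quacc.error import f1, f1e, mae

    true_prev, estim_prev = [50, 10, 5, 35], [48, 12, 7, 33]   # toy [TN, FP, FN, TP] vectors

    row = {}
    for col in ["mae_estim", "mae_true", "f1_estim", "f1_true"]:
        thunks = {
            "mae_true": lambda: mae(true_prev),
            "mae_estim": lambda: mae(estim_prev),
            "f1_true": lambda: f1(true_prev),
            "f1_estim": lambda: f1(estim_prev),
            "f1e_true": lambda: f1e(true_prev),
            "f1e_estim": lambda: f1e(estim_prev),
        }
        row[("errors", col)] = thunks[col]()   # lazy: only the selected thunk runs
    print(row)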
diff --git a/quacc/utils.py b/quacc/utils.py
new file mode 100644
index 0000000..6da5b39
--- /dev/null
+++ b/quacc/utils.py
@@ -0,0 +1,31 @@
+
+import functools
+import pandas as pd
+
+def combine_dataframes(*dfs, df_index=[]) -> pd.DataFrame:
+    if len(dfs) < 1:
+        raise ValueError
+    if len(dfs) == 1:
+        return dfs[0]
+    df = dfs[0]
+    for ndf in dfs[1:]:
+        df = df.join(ndf.set_index(df_index), on=df_index)
+
+    return df
+
+
+def avg_group_report(df: pd.DataFrame) -> pd.DataFrame:
+    def _reduce_func(s1, s2):
+        return {
+            (n1, n2): v + s2[(n1, n2)] for ((n1, n2), v) in s1.items()
+        }
+
+    lst = df.to_dict(orient="records")[1:-1]
+    summed_series = functools.reduce(_reduce_func, lst)
+    idx = df.columns.drop([("base", "T"), ("base", "F")])
+    avg_report = {
+        (n1, n2): (v / len(lst))
+        for ((n1, n2), v) in summed_series.items()
+        if n1 != "base"
+    }
+    return pd.DataFrame([avg_report], columns=idx)
\ No newline at end of file
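Note (not part of the patch): a self-contained sketch of how the two new helpers compose, using toy frames shaped like the reports in this patch (MultiIndex column tuples as in quacc/evaluation.py). combine_dataframes joins every frame on the ("base", "F")/("base", "T") columns, and avg_group_report drops the first and last record, averages the remaining ones, and removes the ("base", *) columns:

    import pandas as pd
    from quacc.utils import combine_dataframes, avg_group_report

    cols_a = pd.MultiIndex.from_tuples([("base", "F"), ("base", "T"), ("errors", "mae_estim")])
    cols_b = pd.MultiIndex.from_tuples([("base", "F"), ("base", "T"), ("errors", "atc_mc")])
    a = pd.DataFrame([[0.9, 0.1, 0.05], [0.7, 0.3, 0.06], [0.5, 0.5, 0.08]], columns=cols_a)
    b = pd.DataFrame([[0.9, 0.1, 0.04], [0.7, 0.3, 0.05], [0.5, 0.5, 0.07]], columns=cols_b)

    df = combine_dataframes(b, a, df_index=[("base", "F"), ("base", "T")])
    print(df.to_latex(float_format="{:.4f}".format))   # same rendering used in quacc/main.py
    print(avg_group_report(df))                        # averages the middle record(s) only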