diff --git a/.gitignore b/.gitignore index 1ae9719..b199a8a 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,6 @@ lipton_bbse/__pycache__/* elsahar19_rca/__pycache__/* *.coverage .coverage -scp_sync.py \ No newline at end of file +scp_sync.py +out/* +output/* \ No newline at end of file diff --git a/TODO.html b/TODO.html index ddfdc17..35dd2f8 100644 --- a/TODO.html +++ b/TODO.html @@ -41,12 +41,12 @@ diff --git a/TODO.md b/TODO.md index 2f6d846..d7012a1 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,6 @@ -- [ ] add averages to tables -- [ ] plot; 3 types (notes + email + garg) +- [x] add averages to tables +- [x] plot; 3 types (notes + email + garg) - [ ] fix kfcv baseline -- [ ] add method with CC besides SLD +- [x] add method with CC besides SLD - [x] take the most populous class of rcv1, remove negatives until reaching 50/50; then subsample with 9 training prevalences (from 0.1-0.9 to 0.9-0.1) -- [ ] vary the recalibration parameter in SLD \ No newline at end of file +- [x] vary the recalibration parameter in SLD \ No newline at end of file diff --git a/conf.yaml b/conf.yaml new file mode 100644 index 0000000..50a5dd0 --- /dev/null +++ b/conf.yaml @@ -0,0 +1,71 @@ + +exec: [] + +commons: + - DATASET_NAME: rcv1 + DATASET_TARGET: CCAT + METRICS: + - acc + - f1 + DATASET_N_PREVS: 9 + - DATASET_NAME: imdb + METRICS: + - acc + - f1 + DATASET_N_PREVS: 9 + +confs: + + all_mul_vs_atc: + COMP_ESTIMATORS: + - our_mul_SLD + - our_mul_SLD_nbvs + - our_mul_SLD_bcts + - our_mul_SLD_ts + - our_mul_SLD_vs + - our_mul_CC + - ref + - atc_mc + - atc_ne + + all_bin_vs_atc: + COMP_ESTIMATORS: + - our_bin_SLD + - our_bin_SLD_nbvs + - our_bin_SLD_bcts + - our_bin_SLD_ts + - our_bin_SLD_vs + - our_bin_CC + - ref + - atc_mc + - atc_ne + + best_our_vs_atc: + COMP_ESTIMATORS: + - our_bin_SLD + - our_bin_SLD_bcts + - our_bin_SLD_vs + - our_bin_CC + - our_mul_SLD + - our_mul_SLD_bcts + - our_mul_SLD_vs + - our_mul_CC + - ref + - atc_mc + - atc_ne + + best_our_vs_all: + COMP_ESTIMATORS: + - our_bin_SLD + - our_bin_SLD_bcts + - our_bin_SLD_vs + - our_bin_CC + - our_mul_SLD + - our_mul_SLD_bcts + - our_mul_SLD_vs + - our_mul_CC + - ref + - kfcv + - atc_mc + - atc_ne + - doc_feat diff --git a/out/plot/rcv1_CCAT_10_acc.png b/out/plot/rcv1_CCAT_10_acc.png deleted file mode 100644 index 2994b60..0000000 Binary files a/out/plot/rcv1_CCAT_10_acc.png and /dev/null differ diff --git a/out/plot/rcv1_CCAT_20_acc.png b/out/plot/rcv1_CCAT_20_acc.png deleted file mode 100644 index 83a7991..0000000 Binary files a/out/plot/rcv1_CCAT_20_acc.png and /dev/null differ diff --git a/out/plot/rcv1_CCAT_30_acc.png b/out/plot/rcv1_CCAT_30_acc.png deleted file mode 100644 index 2e34308..0000000 Binary files a/out/plot/rcv1_CCAT_30_acc.png and /dev/null differ diff --git a/out/plot/rcv1_CCAT_40_acc.png b/out/plot/rcv1_CCAT_40_acc.png deleted file mode 100644 index 031feda..0000000 Binary files a/out/plot/rcv1_CCAT_40_acc.png and /dev/null differ diff --git a/out/plot/rcv1_CCAT_50_acc.png b/out/plot/rcv1_CCAT_50_acc.png deleted file mode 100644 index 86d23e7..0000000 Binary files a/out/plot/rcv1_CCAT_50_acc.png and /dev/null differ diff --git a/out/plot/rcv1_CCAT_60_acc.png b/out/plot/rcv1_CCAT_60_acc.png deleted file mode 100644 index 374cb70..0000000 Binary files a/out/plot/rcv1_CCAT_60_acc.png and /dev/null differ diff --git a/out/plot/rcv1_CCAT_70_acc.png b/out/plot/rcv1_CCAT_70_acc.png deleted file mode 100644 index 3af314e..0000000 Binary files a/out/plot/rcv1_CCAT_70_acc.png and /dev/null differ diff --git
a/out/plot/rcv1_CCAT_80_acc.png b/out/plot/rcv1_CCAT_80_acc.png deleted file mode 100644 index 2f525d0..0000000 Binary files a/out/plot/rcv1_CCAT_80_acc.png and /dev/null differ diff --git a/out/plot/rcv1_CCAT_90_acc.png b/out/plot/rcv1_CCAT_90_acc.png deleted file mode 100644 index 0150b84..0000000 Binary files a/out/plot/rcv1_CCAT_90_acc.png and /dev/null differ diff --git a/out/rcv1_CCAT.md b/out/rcv1_CCAT.md deleted file mode 100644 index 1eff4e7..0000000 --- a/out/rcv1_CCAT.md +++ /dev/null @@ -1,1955 +0,0 @@ -rcv1_CCAT - -> train: [0.09996662 0.90003338] -> validation: [0.09996662 0.90003338] -> evaluate_bin_sld: 198.301s -> evaluate_mul_sld: 53.156s -> kfcv: 41.095s -> atc_mc: 42.167s -> atc_ne: 41.909s -> doc_feat: 35.796s -> tot: 202.108s - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| test prev. | bin | mul | kfcv | atc_mc | atc_ne | doc_feat |
|---|---|---|---|---|---|---|
| (0.0, 1.0) | 0.0048 | 0.0040 | 0.0866 | 0.0243 | 0.0243 | 0.0830 |
| (0.05, 0.95) | 0.0060 | 0.0072 | 0.0441 | 0.0134 | 0.0134 | 0.0407 |
| (0.1, 0.9) | 0.0084 | 0.0103 | 0.0032 | 0.0070 | 0.0070 | 0.0036 |
| (0.15, 0.85) | 0.0127 | 0.0172 | 0.0418 | 0.0090 | 0.0090 | 0.0450 |
| (0.2, 0.8) | 0.0184 | 0.0246 | 0.0841 | 0.0168 | 0.0168 | 0.0872 |
| (0.25, 0.75) | 0.0231 | 0.0318 | 0.1246 | 0.0239 | 0.0239 | 0.1276 |
| (0.3, 0.7) | 0.0313 | 0.0426 | 0.1678 | 0.0334 | 0.0334 | 0.1706 |
| (0.35, 0.65) | 0.0392 | 0.0536 | 0.2110 | 0.0422 | 0.0422 | 0.2137 |
| (0.4, 0.6) | 0.0418 | 0.0563 | 0.2528 | 0.0541 | 0.0541 | 0.2555 |
| (0.45, 0.55) | 0.0527 | 0.0715 | 0.2966 | 0.0622 | 0.0622 | 0.2991 |
| (0.5, 0.5) | 0.0569 | 0.0771 | 0.3383 | 0.0749 | 0.0749 | 0.3407 |
| (0.55, 0.45) | 0.0637 | 0.0867 | 0.3817 | 0.0847 | 0.0847 | 0.3840 |
| (0.6, 0.4) | 0.0727 | 0.0999 | 0.4250 | 0.0943 | 0.0943 | 0.4272 |
| (0.65, 0.35) | 0.0778 | 0.1062 | 0.4662 | 0.1040 | 0.1040 | 0.4683 |
| (0.7, 0.3) | 0.0825 | 0.1118 | 0.5099 | 0.1131 | 0.1131 | 0.5119 |
| (0.75, 0.25) | 0.0879 | 0.1197 | 0.5519 | 0.1217 | 0.1217 | 0.5537 |
| (0.8, 0.2) | 0.0887 | 0.1192 | 0.5945 | 0.1334 | 0.1334 | 0.5963 |
| (0.85, 0.15) | 0.0926 | 0.1269 | 0.6368 | 0.1426 | 0.1426 | 0.6384 |
| (0.9, 0.1) | 0.0887 | 0.1250 | 0.6791 | 0.1528 | 0.1528 | 0.6806 |
| (0.95, 0.05) | 0.0501 | 0.0961 | 0.7227 | 0.1614 | 0.1614 | 0.7241 |
| (1.0, 0.0) | 0.0004 | 0.0358 | 0.7631 | 0.1704 | 0.1704 | 0.7643 |
- - - -> train: [0.19993324 0.80006676] -> validation: [0.20010013 0.79989987] -> evaluate_bin_sld: 199.250s -> evaluate_mul_sld: 55.414s -> kfcv: 41.131s -> atc_mc: 42.125s -> atc_ne: 41.892s -> doc_feat: 35.279s -> tot: 202.707s - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| test prev. | bin | mul | kfcv | atc_mc | atc_ne | doc_feat |
|---|---|---|---|---|---|---|
| (0.0, 1.0) | 0.0055 | 0.0058 | 0.0915 | 0.0147 | 0.0147 | 0.0775 |
| (0.05, 0.95) | 0.0157 | 0.0084 | 0.0719 | 0.0130 | 0.0130 | 0.0581 |
| (0.1, 0.9) | 0.0154 | 0.0099 | 0.0503 | 0.0108 | 0.0108 | 0.0365 |
| (0.15, 0.85) | 0.0141 | 0.0111 | 0.0292 | 0.0104 | 0.0104 | 0.0158 |
| (0.2, 0.8) | 0.0120 | 0.0116 | 0.0103 | 0.0100 | 0.0100 | 0.0068 |
| (0.25, 0.75) | 0.0098 | 0.0124 | 0.0115 | 0.0091 | 0.0091 | 0.0243 |
| (0.3, 0.7) | 0.0079 | 0.0131 | 0.0312 | 0.0106 | 0.0106 | 0.0445 |
| (0.35, 0.65) | 0.0087 | 0.0154 | 0.0529 | 0.0097 | 0.0097 | 0.0660 |
| (0.4, 0.6) | 0.0074 | 0.0143 | 0.0729 | 0.0110 | 0.0110 | 0.0859 |
| (0.45, 0.55) | 0.0082 | 0.0148 | 0.0933 | 0.0111 | 0.0111 | 0.1062 |
| (0.5, 0.5) | 0.0081 | 0.0152 | 0.1152 | 0.0136 | 0.0136 | 0.1280 |
| (0.55, 0.45) | 0.0104 | 0.0164 | 0.1384 | 0.0147 | 0.0147 | 0.1511 |
| (0.6, 0.4) | 0.0108 | 0.0193 | 0.1567 | 0.0168 | 0.0168 | 0.1692 |
| (0.65, 0.35) | 0.0129 | 0.0212 | 0.1806 | 0.0196 | 0.0196 | 0.1930 |
| (0.7, 0.3) | 0.0134 | 0.0242 | 0.2005 | 0.0178 | 0.0178 | 0.2128 |
| (0.75, 0.25) | 0.0162 | 0.0238 | 0.2196 | 0.0201 | 0.0201 | 0.2318 |
| (0.8, 0.2) | 0.0161 | 0.0248 | 0.2425 | 0.0214 | 0.0214 | 0.2546 |
| (0.85, 0.15) | 0.0207 | 0.0320 | 0.2620 | 0.0227 | 0.0227 | 0.2740 |
| (0.9, 0.1) | 0.0233 | 0.0340 | 0.2841 | 0.0267 | 0.0267 | 0.2960 |
| (0.95, 0.05) | 0.0261 | 0.0393 | 0.3034 | 0.0274 | 0.0274 | 0.3151 |
| (1.0, 0.0) | 0.0019 | 0.0162 | 0.3217 | 0.0311 | 0.0311 | 0.3333 |
- - - -> train: [0.29989987 0.70010013] -> validation: [0.30006676 0.69993324] -> evaluate_bin_sld: 197.848s -> evaluate_mul_sld: 55.610s -> kfcv: 40.783s -> atc_mc: 42.124s -> atc_ne: 41.370s -> doc_feat: 35.340s -> tot: 199.287s - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| test prev. | bin | mul | kfcv | atc_mc | atc_ne | doc_feat |
|---|---|---|---|---|---|---|
| (0.0, 1.0) | 0.0051 | 0.0059 | 0.0530 | 0.0059 | 0.0059 | 0.0422 |
| (0.05, 0.95) | 0.0108 | 0.0082 | 0.0455 | 0.0063 | 0.0063 | 0.0347 |
| (0.1, 0.9) | 0.0127 | 0.0110 | 0.0356 | 0.0062 | 0.0062 | 0.0250 |
| (0.15, 0.85) | 0.0147 | 0.0145 | 0.0265 | 0.0076 | 0.0076 | 0.0160 |
| (0.2, 0.8) | 0.0158 | 0.0162 | 0.0173 | 0.0071 | 0.0071 | 0.0086 |
| (0.25, 0.75) | 0.0147 | 0.0158 | 0.0091 | 0.0070 | 0.0070 | 0.0075 |
| (0.3, 0.7) | 0.0134 | 0.0162 | 0.0073 | 0.0080 | 0.0080 | 0.0127 |
| (0.35, 0.65) | 0.0138 | 0.0178 | 0.0132 | 0.0100 | 0.0100 | 0.0230 |
| (0.4, 0.6) | 0.0130 | 0.0180 | 0.0204 | 0.0096 | 0.0096 | 0.0306 |
| (0.45, 0.55) | 0.0102 | 0.0149 | 0.0297 | 0.0102 | 0.0102 | 0.0397 |
| (0.5, 0.5) | 0.0094 | 0.0160 | 0.0405 | 0.0111 | 0.0111 | 0.0504 |
| (0.55, 0.45) | 0.0095 | 0.0135 | 0.0516 | 0.0123 | 0.0123 | 0.0615 |
| (0.6, 0.4) | 0.0086 | 0.0132 | 0.0596 | 0.0122 | 0.0122 | 0.0693 |
| (0.65, 0.35) | 0.0102 | 0.0123 | 0.0717 | 0.0149 | 0.0149 | 0.0814 |
| (0.7, 0.3) | 0.0098 | 0.0115 | 0.0797 | 0.0160 | 0.0160 | 0.0894 |
| (0.75, 0.25) | 0.0111 | 0.0108 | 0.0880 | 0.0160 | 0.0160 | 0.0975 |
| (0.8, 0.2) | 0.0112 | 0.0093 | 0.0996 | 0.0206 | 0.0206 | 0.1091 |
| (0.85, 0.15) | 0.0149 | 0.0119 | 0.1094 | 0.0197 | 0.0197 | 0.1187 |
| (0.9, 0.1) | 0.0167 | 0.0137 | 0.1178 | 0.0216 | 0.0216 | 0.1271 |
| (0.95, 0.05) | 0.0184 | 0.0145 | 0.1275 | 0.0222 | 0.0222 | 0.1367 |
| (1.0, 0.0) | 0.0007 | 0.0099 | 0.1371 | 0.0238 | 0.0238 | 0.1462 |
- - - -> train: [0.40003338 0.59996662] -> validation: [0.40003338 0.59996662] -> evaluate_bin_sld: 197.597s -> evaluate_mul_sld: 55.556s -> kfcv: 40.650s -> atc_mc: 41.687s -> atc_ne: 41.375s -> doc_feat: 34.998s -> tot: 198.892s - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| test prev. | bin | mul | kfcv | atc_mc | atc_ne | doc_feat |
|---|---|---|---|---|---|---|
| (0.0, 1.0) | 0.0013 | 0.0048 | 0.0194 | 0.0071 | 0.0071 | 0.0126 |
| (0.05, 0.95) | 0.0076 | 0.0084 | 0.0184 | 0.0071 | 0.0071 | 0.0111 |
| (0.1, 0.9) | 0.0092 | 0.0107 | 0.0161 | 0.0078 | 0.0078 | 0.0093 |
| (0.15, 0.85) | 0.0127 | 0.0149 | 0.0134 | 0.0070 | 0.0070 | 0.0077 |
| (0.2, 0.8) | 0.0183 | 0.0200 | 0.0110 | 0.0066 | 0.0066 | 0.0075 |
| (0.25, 0.75) | 0.0208 | 0.0230 | 0.0090 | 0.0075 | 0.0075 | 0.0069 |
| (0.3, 0.7) | 0.0235 | 0.0260 | 0.0080 | 0.0076 | 0.0076 | 0.0073 |
| (0.35, 0.65) | 0.0273 | 0.0306 | 0.0065 | 0.0079 | 0.0079 | 0.0095 |
| (0.4, 0.6) | 0.0296 | 0.0335 | 0.0074 | 0.0072 | 0.0072 | 0.0099 |
| (0.45, 0.55) | 0.0283 | 0.0313 | 0.0080 | 0.0085 | 0.0085 | 0.0116 |
| (0.5, 0.5) | 0.0267 | 0.0317 | 0.0087 | 0.0085 | 0.0085 | 0.0147 |
| (0.55, 0.45) | 0.0273 | 0.0331 | 0.0131 | 0.0086 | 0.0086 | 0.0196 |
| (0.6, 0.4) | 0.0239 | 0.0320 | 0.0136 | 0.0082 | 0.0082 | 0.0202 |
| (0.65, 0.35) | 0.0208 | 0.0290 | 0.0171 | 0.0084 | 0.0084 | 0.0241 |
| (0.7, 0.3) | 0.0186 | 0.0288 | 0.0213 | 0.0084 | 0.0084 | 0.0281 |
| (0.75, 0.25) | 0.0158 | 0.0261 | 0.0219 | 0.0090 | 0.0090 | 0.0288 |
| (0.8, 0.2) | 0.0130 | 0.0235 | 0.0269 | 0.0089 | 0.0089 | 0.0338 |
| (0.85, 0.15) | 0.0084 | 0.0180 | 0.0284 | 0.0083 | 0.0083 | 0.0352 |
| (0.9, 0.1) | 0.0057 | 0.0134 | 0.0322 | 0.0092 | 0.0092 | 0.0390 |
| (0.95, 0.05) | 0.0050 | 0.0091 | 0.0339 | 0.0101 | 0.0101 | 0.0406 |
| (1.0, 0.0) | 0.0007 | 0.0064 | 0.0379 | 0.0106 | 0.0106 | 0.0447 |
- - - -> train: [0.5 0.5] -> validation: [0.5 0.5] -> evaluate_bin_sld: 197.283s -> evaluate_mul_sld: 54.736s -> kfcv: 40.375s -> atc_mc: 41.898s -> atc_ne: 41.366s -> doc_feat: 35.145s -> tot: 198.630s - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| test prev. | bin | mul | kfcv | atc_mc | atc_ne | doc_feat |
|---|---|---|---|---|---|---|
| (0.0, 1.0) | 0.0004 | 0.0035 | 0.0257 | 0.0289 | 0.0289 | 0.0344 |
| (0.05, 0.95) | 0.0075 | 0.0085 | 0.0224 | 0.0253 | 0.0253 | 0.0310 |
| (0.1, 0.9) | 0.0081 | 0.0122 | 0.0205 | 0.0239 | 0.0239 | 0.0292 |
| (0.15, 0.85) | 0.0102 | 0.0148 | 0.0180 | 0.0205 | 0.0205 | 0.0267 |
| (0.2, 0.8) | 0.0139 | 0.0198 | 0.0165 | 0.0211 | 0.0211 | 0.0248 |
| (0.25, 0.75) | 0.0194 | 0.0245 | 0.0141 | 0.0170 | 0.0170 | 0.0224 |
| (0.3, 0.7) | 0.0230 | 0.0287 | 0.0137 | 0.0164 | 0.0164 | 0.0222 |
| (0.35, 0.65) | 0.0309 | 0.0338 | 0.0132 | 0.0168 | 0.0168 | 0.0210 |
| (0.4, 0.6) | 0.0350 | 0.0371 | 0.0097 | 0.0144 | 0.0144 | 0.0164 |
| (0.45, 0.55) | 0.0358 | 0.0390 | 0.0086 | 0.0125 | 0.0125 | 0.0150 |
| (0.5, 0.5) | 0.0369 | 0.0386 | 0.0073 | 0.0122 | 0.0122 | 0.0138 |
| (0.55, 0.45) | 0.0373 | 0.0398 | 0.0071 | 0.0110 | 0.0110 | 0.0128 |
| (0.6, 0.4) | 0.0368 | 0.0398 | 0.0064 | 0.0085 | 0.0085 | 0.0103 |
| (0.65, 0.35) | 0.0357 | 0.0385 | 0.0074 | 0.0103 | 0.0103 | 0.0105 |
| (0.7, 0.3) | 0.0319 | 0.0370 | 0.0067 | 0.0082 | 0.0082 | 0.0086 |
| (0.75, 0.25) | 0.0298 | 0.0358 | 0.0079 | 0.0066 | 0.0066 | 0.0070 |
| (0.8, 0.2) | 0.0235 | 0.0302 | 0.0073 | 0.0083 | 0.0083 | 0.0069 |
| (0.85, 0.15) | 0.0154 | 0.0244 | 0.0097 | 0.0077 | 0.0077 | 0.0066 |
| (0.9, 0.1) | 0.0083 | 0.0157 | 0.0108 | 0.0082 | 0.0082 | 0.0069 |
| (0.95, 0.05) | 0.0055 | 0.0098 | 0.0131 | 0.0080 | 0.0080 | 0.0066 |
| (1.0, 0.0) | 0.0007 | 0.0046 | 0.0145 | 0.0088 | 0.0088 | 0.0082 |
- - - -> train: [0.59996662 0.40003338] -> validation: [0.59996662 0.40003338] -> evaluate_bin_sld: 194.960s -> evaluate_mul_sld: 53.330s -> kfcv: 40.320s -> atc_mc: 41.904s -> atc_ne: 41.423s -> doc_feat: 35.289s -> tot: 196.151s - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| test prev. | bin | mul | kfcv | atc_mc | atc_ne | doc_feat |
|---|---|---|---|---|---|---|
| (0.0, 1.0) | 0.0003 | 0.0055 | 0.0815 | 0.0285 | 0.0285 | 0.0825 |
| (0.05, 0.95) | 0.0065 | 0.0127 | 0.0747 | 0.0278 | 0.0278 | 0.0758 |
| (0.1, 0.9) | 0.0072 | 0.0172 | 0.0677 | 0.0224 | 0.0224 | 0.0688 |
| (0.15, 0.85) | 0.0100 | 0.0257 | 0.0627 | 0.0218 | 0.0218 | 0.0638 |
| (0.2, 0.8) | 0.0135 | 0.0308 | 0.0548 | 0.0180 | 0.0180 | 0.0560 |
| (0.25, 0.75) | 0.0165 | 0.0338 | 0.0491 | 0.0160 | 0.0160 | 0.0503 |
| (0.3, 0.7) | 0.0205 | 0.0409 | 0.0438 | 0.0168 | 0.0168 | 0.0450 |
| (0.35, 0.65) | 0.0248 | 0.0459 | 0.0374 | 0.0156 | 0.0156 | 0.0386 |
| (0.4, 0.6) | 0.0284 | 0.0491 | 0.0277 | 0.0112 | 0.0112 | 0.0290 |
| (0.45, 0.55) | 0.0318 | 0.0515 | 0.0224 | 0.0099 | 0.0099 | 0.0237 |
| (0.5, 0.5) | 0.0342 | 0.0516 | 0.0159 | 0.0081 | 0.0081 | 0.0170 |
| (0.55, 0.45) | 0.0374 | 0.0519 | 0.0111 | 0.0073 | 0.0073 | 0.0121 |
| (0.6, 0.4) | 0.0410 | 0.0537 | 0.0069 | 0.0079 | 0.0079 | 0.0075 |
| (0.65, 0.35) | 0.0444 | 0.0517 | 0.0064 | 0.0076 | 0.0076 | 0.0064 |
| (0.7, 0.3) | 0.0438 | 0.0502 | 0.0100 | 0.0085 | 0.0085 | 0.0090 |
| (0.75, 0.25) | 0.0458 | 0.0483 | 0.0171 | 0.0089 | 0.0089 | 0.0157 |
| (0.8, 0.2) | 0.0412 | 0.0419 | 0.0218 | 0.0105 | 0.0105 | 0.0204 |
| (0.85, 0.15) | 0.0319 | 0.0348 | 0.0291 | 0.0117 | 0.0117 | 0.0276 |
| (0.9, 0.1) | 0.0192 | 0.0254 | 0.0358 | 0.0147 | 0.0147 | 0.0343 |
| (0.95, 0.05) | 0.0079 | 0.0154 | 0.0427 | 0.0166 | 0.0166 | 0.0412 |
| (1.0, 0.0) | 0.0005 | 0.0034 | 0.0490 | 0.0190 | 0.0190 | 0.0474 |
- - - -> train: [0.69993324 0.30006676] -> validation: [0.70010013 0.29989987] -> evaluate_bin_sld: 196.856s -> evaluate_mul_sld: 54.245s -> kfcv: 41.167s -> atc_mc: 42.203s -> atc_ne: 41.565s -> doc_feat: 34.998s -> tot: 198.332s - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| test prev. | bin | mul | kfcv | atc_mc | atc_ne | doc_feat |
|---|---|---|---|---|---|---|
| (0.0, 1.0) | 0.0003 | 0.0071 | 0.1570 | 0.0625 | 0.0625 | 0.1677 |
| (0.05, 0.95) | 0.0089 | 0.0102 | 0.1428 | 0.0548 | 0.0548 | 0.1536 |
| (0.1, 0.9) | 0.0078 | 0.0121 | 0.1327 | 0.0521 | 0.0521 | 0.1435 |
| (0.15, 0.85) | 0.0073 | 0.0155 | 0.1227 | 0.0517 | 0.0517 | 0.1336 |
| (0.2, 0.8) | 0.0081 | 0.0196 | 0.1094 | 0.0464 | 0.0464 | 0.1203 |
| (0.25, 0.75) | 0.0095 | 0.0225 | 0.1001 | 0.0427 | 0.0427 | 0.1111 |
| (0.3, 0.7) | 0.0117 | 0.0272 | 0.0885 | 0.0400 | 0.0400 | 0.0995 |
| (0.35, 0.65) | 0.0131 | 0.0309 | 0.0774 | 0.0368 | 0.0368 | 0.0885 |
| (0.4, 0.6) | 0.0144 | 0.0333 | 0.0626 | 0.0307 | 0.0307 | 0.0737 |
| (0.45, 0.55) | 0.0179 | 0.0365 | 0.0528 | 0.0297 | 0.0297 | 0.0640 |
| (0.5, 0.5) | 0.0183 | 0.0359 | 0.0418 | 0.0259 | 0.0259 | 0.0531 |
| (0.55, 0.45) | 0.0189 | 0.0369 | 0.0313 | 0.0222 | 0.0222 | 0.0426 |
| (0.6, 0.4) | 0.0220 | 0.0379 | 0.0201 | 0.0190 | 0.0190 | 0.0314 |
| (0.65, 0.35) | 0.0218 | 0.0364 | 0.0104 | 0.0160 | 0.0160 | 0.0208 |
| (0.7, 0.3) | 0.0229 | 0.0371 | 0.0067 | 0.0119 | 0.0119 | 0.0096 |
| (0.75, 0.25) | 0.0250 | 0.0378 | 0.0161 | 0.0101 | 0.0101 | 0.0067 |
| (0.8, 0.2) | 0.0237 | 0.0333 | 0.0259 | 0.0082 | 0.0082 | 0.0143 |
| (0.85, 0.15) | 0.0227 | 0.0282 | 0.0381 | 0.0060 | 0.0060 | 0.0265 |
| (0.9, 0.1) | 0.0180 | 0.0202 | 0.0499 | 0.0049 | 0.0049 | 0.0382 |
| (0.95, 0.05) | 0.0097 | 0.0117 | 0.0607 | 0.0072 | 0.0072 | 0.0489 |
| (1.0, 0.0) | 0.0014 | 0.0024 | 0.0724 | 0.0103 | 0.0103 | 0.0606 |
- - - -> train: [0.79989987 0.20010013] -> validation: [0.80006676 0.19993324] -> evaluate_bin_sld: 197.725s -> evaluate_mul_sld: 53.526s -> kfcv: 40.971s -> atc_mc: 41.975s -> atc_ne: 41.358s -> doc_feat: 35.091s -> tot: 199.051s - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| test prev. | bin | mul | kfcv | atc_mc | atc_ne | doc_feat |
|---|---|---|---|---|---|---|
| (0.0, 1.0) | 0.0009 | 0.0082 | 0.3148 | 0.0571 | 0.0571 | 0.3213 |
| (0.05, 0.95) | 0.0297 | 0.0223 | 0.2925 | 0.0492 | 0.0492 | 0.2991 |
| (0.1, 0.9) | 0.0283 | 0.0209 | 0.2733 | 0.0493 | 0.0493 | 0.2800 |
| (0.15, 0.85) | 0.0247 | 0.0182 | 0.2528 | 0.0447 | 0.0447 | 0.2596 |
| (0.2, 0.8) | 0.0216 | 0.0156 | 0.2328 | 0.0407 | 0.0407 | 0.2397 |
| (0.25, 0.75) | 0.0170 | 0.0136 | 0.2136 | 0.0425 | 0.0425 | 0.2205 |
| (0.3, 0.7) | 0.0146 | 0.0126 | 0.1941 | 0.0384 | 0.0384 | 0.2012 |
| (0.35, 0.65) | 0.0125 | 0.0113 | 0.1734 | 0.0331 | 0.0331 | 0.1806 |
| (0.4, 0.6) | 0.0113 | 0.0110 | 0.1510 | 0.0272 | 0.0272 | 0.1583 |
| (0.45, 0.55) | 0.0093 | 0.0135 | 0.1328 | 0.0247 | 0.0247 | 0.1402 |
| (0.5, 0.5) | 0.0088 | 0.0135 | 0.1131 | 0.0222 | 0.0222 | 0.1206 |
| (0.55, 0.45) | 0.0092 | 0.0155 | 0.0919 | 0.0207 | 0.0207 | 0.0995 |
| (0.6, 0.4) | 0.0092 | 0.0173 | 0.0742 | 0.0190 | 0.0190 | 0.0819 |
| (0.65, 0.35) | 0.0087 | 0.0178 | 0.0544 | 0.0161 | 0.0161 | 0.0621 |
| (0.7, 0.3) | 0.0093 | 0.0197 | 0.0323 | 0.0124 | 0.0124 | 0.0401 |
| (0.75, 0.25) | 0.0101 | 0.0218 | 0.0114 | 0.0093 | 0.0093 | 0.0187 |
| (0.8, 0.2) | 0.0117 | 0.0208 | 0.0098 | 0.0088 | 0.0088 | 0.0063 |
| (0.85, 0.15) | 0.0103 | 0.0178 | 0.0285 | 0.0064 | 0.0064 | 0.0204 |
| (0.9, 0.1) | 0.0103 | 0.0164 | 0.0480 | 0.0062 | 0.0062 | 0.0398 |
| (0.95, 0.05) | 0.0092 | 0.0117 | 0.0684 | 0.0071 | 0.0071 | 0.0601 |
| (1.0, 0.0) | 0.0011 | 0.0019 | 0.0887 | 0.0097 | 0.0097 | 0.0803 |
- - - -> train: [0.90003338 0.09996662] -> validation: [0.90003338 0.09996662] -> evaluate_bin_sld: 201.315s -> evaluate_mul_sld: 50.974s -> kfcv: 40.175s -> atc_mc: 41.663s -> atc_ne: 41.058s -> doc_feat: 35.055s -> tot: 202.573s - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| test prev. | bin | mul | kfcv | atc_mc | atc_ne | doc_feat |
|---|---|---|---|---|---|---|
| (0.0, 1.0) | 0.0321 | 0.0184 | 0.6421 | 0.1336 | 0.1336 | 0.6454 |
| (0.05, 0.95) | 0.0835 | 0.0729 | 0.6056 | 0.1244 | 0.1244 | 0.6090 |
| (0.1, 0.9) | 0.1080 | 0.0976 | 0.5703 | 0.1204 | 0.1204 | 0.5739 |
| (0.15, 0.85) | 0.1154 | 0.0971 | 0.5354 | 0.1147 | 0.1147 | 0.5390 |
| (0.2, 0.8) | 0.1081 | 0.0916 | 0.5007 | 0.1064 | 0.1064 | 0.5045 |
| (0.25, 0.75) | 0.1032 | 0.0830 | 0.4632 | 0.1005 | 0.1005 | 0.4671 |
| (0.3, 0.7) | 0.0945 | 0.0775 | 0.4274 | 0.0916 | 0.0916 | 0.4313 |
| (0.35, 0.65) | 0.0966 | 0.0709 | 0.3914 | 0.0843 | 0.0843 | 0.3954 |
| (0.4, 0.6) | 0.0795 | 0.0639 | 0.3543 | 0.0748 | 0.0748 | 0.3584 |
| (0.45, 0.55) | 0.0735 | 0.0533 | 0.3210 | 0.0728 | 0.0728 | 0.3253 |
| (0.5, 0.5) | 0.0716 | 0.0473 | 0.2829 | 0.0633 | 0.0633 | 0.2873 |
| (0.55, 0.45) | 0.0550 | 0.0393 | 0.2465 | 0.0568 | 0.0568 | 0.2509 |
| (0.6, 0.4) | 0.0505 | 0.0317 | 0.2117 | 0.0509 | 0.0509 | 0.2162 |
| (0.65, 0.35) | 0.0403 | 0.0226 | 0.1741 | 0.0438 | 0.0438 | 0.1788 |
| (0.7, 0.3) | 0.0372 | 0.0178 | 0.1387 | 0.0348 | 0.0348 | 0.1434 |
| (0.75, 0.25) | 0.0262 | 0.0122 | 0.1009 | 0.0256 | 0.0256 | 0.1057 |
| (0.8, 0.2) | 0.0248 | 0.0110 | 0.0651 | 0.0194 | 0.0194 | 0.0701 |
| (0.85, 0.15) | 0.0181 | 0.0075 | 0.0298 | 0.0128 | 0.0128 | 0.0348 |
| (0.9, 0.1) | 0.0129 | 0.0093 | 0.0069 | 0.0080 | 0.0080 | 0.0037 |
| (0.95, 0.05) | 0.0077 | 0.0085 | 0.0426 | 0.0046 | 0.0046 | 0.0373 |
| (1.0, 0.0) | 0.0010 | 0.0010 | 0.0789 | 0.0088 | 0.0088 | 0.0735 |
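Each cell in the tables above reports an absolute error |true metric − estimated metric| for one test prevalence, where accuracy and F1 are derived from the four-cell extended prevalence vector defined in quacc/error.py further down in this diff. The snippet below is a minimal sketch of that computation; the cell ordering (positions 0 and 3 holding the correctly classified mass, position 3 the true positives) and the example prevalence values are illustrative assumptions, not taken from the report.

```python
import numpy as np

def acc(prev):
    # accuracy from a 4-cell extended prevalence vector; cells 0 and 3 are
    # assumed to hold the correctly classified mass (as in quacc/error.py)
    return (prev[0] + prev[3]) / sum(prev)

def f1(prev):
    # F1 from the same vector; cell 3 is assumed to be the true-positive cell
    den = 2 * prev[3] + prev[1] + prev[2]
    return 0.0 if den == 0 else 2 * prev[3] / den

# hypothetical true vs. estimated extended prevalences for a single test sample
true_prev = np.array([0.10, 0.05, 0.05, 0.80])
estim_prev = np.array([0.11, 0.06, 0.05, 0.78])

# the kind of value reported in each cell above: |true metric - estimated metric|
print(f"acc error: {abs(acc(true_prev) - acc(estim_prev)):.4f}")
print(f"f1 error:  {abs(f1(true_prev) - f1(estim_prev)):.4f}")
```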
- - - diff --git a/poetry.lock b/poetry.lock index 7cb982e..7d7365d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -956,6 +956,65 @@ files = [ {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"}, ] +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = 
"sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = 
"PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "quapy" version = "0.1.7" @@ -1164,4 +1223,4 @@ test = ["pytest", "pytest-cov"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "72e3afd9a24b88fc8a8f5f55e1c408f65090fce9015a442f6f41638191276b6f" +content-hash = "0ce0e6b058900e7db2939e7eb047a1f868c88de67def370c1c1fa0ba532df0b0" diff --git a/pyproject.toml b/pyproject.toml index 9805ca9..d9ce79a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ python = "^3.11" quapy = "^0.1.7" pandas = "^2.0.3" jinja2 = "^3.1.2" +pyyaml = "^6.0.1" [tool.poetry.scripts] main = "quacc.main:main" diff --git a/quacc/environ.py b/quacc/environ.py index cc2f13c..1177964 100644 --- a/quacc/environ.py +++ b/quacc/environ.py @@ -1,21 +1,33 @@ -from pathlib import Path +import yaml defalut_env = { "DATASET_NAME": "rcv1", "DATASET_TARGET": "CCAT", + "METRICS": ["acc", "f1"], "COMP_ESTIMATORS": [ - "OUR_BIN_SLD", - "OUR_MUL_SLD", - "KFCV", - "ATC_MC", - "ATC_NE", - "DOC_FEAT", - # "RCA", - # "RCA_STAR", + "our_bin_SLD", + "our_bin_SLD_nbvs", + "our_bin_SLD_bcts", + "our_bin_SLD_ts", + "our_bin_SLD_vs", + "our_bin_CC", + "our_mul_SLD", + "our_mul_SLD_nbvs", + "our_mul_SLD_bcts", + "our_mul_SLD_ts", + "our_mul_SLD_vs", + "our_mul_CC", + "ref", + "kfcv", + "atc_mc", + "atc_ne", + "doc_feat", + "rca", + "rca_star", ], "DATASET_N_PREVS": 9, - "OUT_DIR": Path("out"), - "PLOT_OUT_DIR": Path("out/plot"), + "OUT_DIR_NAME": "output", + "PLOT_DIR_NAME": "plot", "PROTOCOL_N_PREVS": 21, "PROTOCOL_REPEATS": 100, "SAMPLE_SIZE": 1000, @@ -24,8 +36,37 @@ defalut_env = { class Environ: def __init__(self, **kwargs): - for k, v in kwargs.items(): + self.exec = [] + self.confs = {} + self.__setdict(kwargs) + + def __setdict(self, d): + for k, v in d.items(): self.__setattr__(k, v) + def load_conf(self): + with open("conf.yaml", "r") as f: + confs = yaml.safe_load(f) + + for common in confs["commons"]: + name = common["DATASET_NAME"] + if "DATASET_TARGET" in common: + name += "_" + common["DATASET_TARGET"] + for k, d in confs["confs"].items(): + _k = f"{name}_{k}" + self.confs[_k] = common | d + self.exec.append(_k) + + if "exec" in confs: + if len(confs["exec"]) > 0: + self.exec = confs["exec"] + + def __iter__(self): + self.load_conf() + for _conf in self.exec: + if _conf in self.confs: + self.__setdict(self.confs[_conf]) + yield _conf + env = Environ(**defalut_env) diff --git a/quacc/error.py b/quacc/error.py index 116cc42..6ed7dd4 100644 --- a/quacc/error.py +++ b/quacc/error.py @@ -1,13 +1,15 @@ import quapy as qp + def from_name(err_name): - if err_name == 'f1e': + if err_name == "f1e": return f1e - elif err_name == 'f1': + elif err_name == "f1": return f1 else: return 
qp.error.from_name(err_name) - + + # def f1(prev): # # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure # if prev[0] == 0 and prev[1] == 0 and prev[2] == 0: @@ -18,18 +20,21 @@ def from_name(err_name): # return float('NaN') # else: # recall = prev[0] / (prev[0] + prev[1]) -# precision = prev[0] / (prev[0] + prev[2]) +# precision = prev[0] / (prev[0] + prev[2]) # return 2 * (precision * recall) / (precision + recall) + def f1(prev): - den = (2*prev[3]) + prev[1] + prev[2] + den = (2 * prev[3]) + prev[1] + prev[2] if den == 0: return 0.0 else: - return (2*prev[3])/den + return (2 * prev[3]) / den + def f1e(prev): return 1 - f1(prev) + def acc(prev): - return (prev[1] + prev[2]) / sum(prev) \ No newline at end of file + return (prev[0] + prev[3]) / sum(prev) diff --git a/quacc/estimator.py b/quacc/estimator.py index 4516b6d..2f9a92c 100644 --- a/quacc/estimator.py +++ b/quacc/estimator.py @@ -1,9 +1,9 @@ -from abc import abstractmethod import math +from abc import abstractmethod import numpy as np from quapy.data import LabelledCollection -from quapy.method.aggregative import SLD +from quapy.method.aggregative import CC, SLD from sklearn.base import BaseEstimator from sklearn.linear_model import LogisticRegression from sklearn.model_selection import cross_val_predict @@ -15,7 +15,7 @@ class AccuracyEstimator: def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection: if not pred_proba: pred_proba = self.c_model.predict_proba(base.X) - return ExtendedCollection.extend_collection(base, pred_proba) + return ExtendedCollection.extend_collection(base, pred_proba), pred_proba @abstractmethod def fit(self, train: LabelledCollection | ExtendedCollection): @@ -27,9 +27,15 @@ class AccuracyEstimator: class MulticlassAccuracyEstimator(AccuracyEstimator): - def __init__(self, c_model: BaseEstimator): + def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs): self.c_model = c_model - self.q_model = SLD(LogisticRegression()) + if q_model == "SLD": + available_args = ["recalib"] + sld_args = {k: v for k, v in kwargs.items() if k in available_args} + self.q_model = SLD(LogisticRegression(), **sld_args) + elif q_model == "CC": + self.q_model = CC(LogisticRegression()) + self.e_train = None def fit(self, train: LabelledCollection | ExtendedCollection): @@ -67,10 +73,17 @@ class MulticlassAccuracyEstimator(AccuracyEstimator): class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): - def __init__(self, c_model: BaseEstimator): + def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs): self.c_model = c_model - self.q_model_0 = SLD(LogisticRegression()) - self.q_model_1 = SLD(LogisticRegression()) + if q_model == "SLD": + available_args = ["recalib"] + sld_args = {k: v for k, v in kwargs.items() if k in available_args} + self.q_model_0 = SLD(LogisticRegression(), **sld_args) + self.q_model_1 = SLD(LogisticRegression(), **sld_args) + elif q_model == "CC": + self.q_model_0 = CC(LogisticRegression()) + self.q_model_1 = CC(LogisticRegression()) + self.e_train = None def fit(self, train: LabelledCollection | ExtendedCollection): @@ -83,7 +96,7 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train) elif isinstance(train, ExtendedCollection): - self.e_train = train + self.e_train = train self.n_classes = self.e_train.n_classes [e_train_0, e_train_1] = self.e_train.split_by_pred() diff --git a/quacc/evaluation/baseline.py b/quacc/evaluation/baseline.py index 
f4e969d..e36a492 100644 --- a/quacc/evaluation/baseline.py +++ b/quacc/evaluation/baseline.py @@ -34,14 +34,14 @@ def kfcv( # ensure that the protocol returns a LabelledCollection for each iteration protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(prefix="kfcv") + report = EvaluationReport(name="kfcv") for test in protocol(): test_preds = c_model_predict(test.X) meta_acc = abs(acc_score - metrics.accuracy_score(test.y, test_preds)) meta_f1 = abs(f1_score - metrics.f1_score(test.y, test_preds)) report.append_row( test.prevalence(), - acc_score=(1.0 - acc_score), + acc_score=acc_score, f1_score=f1_score, acc=meta_acc, f1=meta_f1, @@ -57,13 +57,13 @@ def reference( ): protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") c_model_predict = getattr(c_model, "predict_proba") - report = EvaluationReport(prefix="ref") + report = EvaluationReport(name="ref") for test in protocol(): test_probs = c_model_predict(test.X) test_preds = np.argmax(test_probs, axis=-1) report.append_row( test.prevalence(), - acc_score=(1 - metrics.accuracy_score(test.y, test_preds)), + acc_score=metrics.accuracy_score(test.y, test_preds), f1_score=metrics.f1_score(test.y, test_preds), ) @@ -89,7 +89,7 @@ def atc_mc( # ensure that the protocol returns a LabelledCollection for each iteration protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(prefix="atc_mc") + report = EvaluationReport(name="atc_mc") for test in protocol(): ## Load OOD test data probs test_probs = c_model_predict(test.X) @@ -102,7 +102,7 @@ def atc_mc( report.append_row( test.prevalence(), acc=meta_acc, - acc_score=1.0 - atc_accuracy, + acc_score=atc_accuracy, f1_score=f1_score, f1=meta_f1, ) @@ -129,7 +129,7 @@ def atc_ne( # ensure that the protocol returns a LabelledCollection for each iteration protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(prefix="atc_ne") + report = EvaluationReport(name="atc_ne") for test in protocol(): ## Load OOD test data probs test_probs = c_model_predict(test.X) @@ -142,7 +142,7 @@ def atc_ne( report.append_row( test.prevalence(), acc=meta_acc, - acc_score=(1.0 - atc_accuracy), + acc_score=atc_accuracy, f1_score=f1_score, f1=meta_f1, ) @@ -182,14 +182,14 @@ def doc_feat( # ensure that the protocol returns a LabelledCollection for each iteration protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(prefix="doc_feat") + report = EvaluationReport(name="doc_feat") for test in protocol(): test_probs = c_model_predict(test.X) test_preds = np.argmax(test_probs, axis=-1) test_scores = np.max(test_probs, axis=-1) score = (v1acc + doc.get_doc(val_scores, test_scores)) / 100.0 meta_acc = abs(score - metrics.accuracy_score(test.y, test_preds)) - report.append_row(test.prevalence(), acc=meta_acc, acc_score=(1.0 - score)) + report.append_row(test.prevalence(), acc=meta_acc, acc_score=score) return report @@ -206,17 +206,15 @@ def rca_score( # ensure that the protocol returns a LabelledCollection for each iteration protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(prefix="rca") + report = EvaluationReport(name="rca") for test in protocol(): try: test_pred = c_model_predict(test.X) c_model2 = rca.clone_fit(c_model, test.X, test_pred) c_model2_predict = getattr(c_model2, predict_method) val_pred2 = 
c_model2_predict(validation.X) - rca_score = rca.get_score(val_pred1, val_pred2, validation.y) - meta_score = abs( - rca_score - (1 - metrics.accuracy_score(test.y, test_pred)) - ) + rca_score = 1.0 - rca.get_score(val_pred1, val_pred2, validation.y) + meta_score = abs(rca_score - metrics.accuracy_score(test.y, test_pred)) report.append_row(test.prevalence(), acc=meta_score, acc_score=rca_score) except ValueError: report.append_row( @@ -244,17 +242,15 @@ def rca_star_score( # ensure that the protocol returns a LabelledCollection for each iteration protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(prefix="rca_star") + report = EvaluationReport(name="rca_star") for test in protocol(): try: test_pred = c_model_predict(test.X) c_model2 = rca.clone_fit(c_model, test.X, test_pred) c_model2_predict = getattr(c_model2, predict_method) val2_pred2 = c_model2_predict(validation2.X) - rca_star_score = rca.get_score(val2_pred1, val2_pred2, validation2.y) - meta_score = abs( - rca_star_score - (1 - metrics.accuracy_score(test.y, test_pred)) - ) + rca_star_score = 1.0 - rca.get_score(val2_pred1, val2_pred2, validation2.y) + meta_score = abs(rca_star_score - metrics.accuracy_score(test.y, test_pred)) report.append_row( test.prevalence(), acc=meta_score, acc_score=rca_star_score ) diff --git a/quacc/evaluation/comp.py b/quacc/evaluation/comp.py index ccc4e18..b8c403b 100644 --- a/quacc/evaluation/comp.py +++ b/quacc/evaluation/comp.py @@ -1,5 +1,6 @@ import multiprocessing import time +import traceback from typing import List import pandas as pd @@ -19,14 +20,25 @@ pd.set_option("display.float_format", "{:.4f}".format) class CompEstimator: __dict = { - "OUR_BIN_SLD": method.evaluate_bin_sld, - "OUR_MUL_SLD": method.evaluate_mul_sld, - "KFCV": baseline.kfcv, - "ATC_MC": baseline.atc_mc, - "ATC_NE": baseline.atc_ne, - "DOC_FEAT": baseline.doc_feat, - "RCA": baseline.rca_score, - "RCA_STAR": baseline.rca_star_score, + "our_bin_SLD": method.evaluate_bin_sld, + "our_mul_SLD": method.evaluate_mul_sld, + "our_bin_SLD_nbvs": method.evaluate_bin_sld_nbvs, + "our_mul_SLD_nbvs": method.evaluate_mul_sld_nbvs, + "our_bin_SLD_bcts": method.evaluate_bin_sld_bcts, + "our_mul_SLD_bcts": method.evaluate_mul_sld_bcts, + "our_bin_SLD_ts": method.evaluate_bin_sld_ts, + "our_mul_SLD_ts": method.evaluate_mul_sld_ts, + "our_bin_SLD_vs": method.evaluate_bin_sld_vs, + "our_mul_SLD_vs": method.evaluate_mul_sld_vs, + "our_bin_CC": method.evaluate_bin_cc, + "our_mul_CC": method.evaluate_mul_cc, + "ref": baseline.reference, + "kfcv": baseline.kfcv, + "atc_mc": baseline.atc_mc, + "atc_ne": baseline.atc_ne, + "doc_feat": baseline.doc_feat, + "rca": baseline.rca_score, + "rca_star": baseline.rca_star_score, } def __class_getitem__(cls, e: str | List[str]): @@ -55,7 +67,17 @@ def fit_and_estimate(_estimate, train, validation, test): test, n_prevalences=env.PROTOCOL_N_PREVS, repeats=env.PROTOCOL_REPEATS ) start = time.time() - result = _estimate(model, validation, protocol) + try: + result = _estimate(model, validation, protocol) + except Exception as e: + print(f"Method {_estimate.__name__} failed.") + traceback(e) + return { + "name": _estimate.__name__, + "result": None, + "time": 0, + } + end = time.time() print(f"{_estimate.__name__}: {end-start:.2f}s") @@ -69,22 +91,33 @@ def fit_and_estimate(_estimate, train, validation, test): def evaluate_comparison( dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"] ) -> EvaluationReport: - with multiprocessing.Pool(8) as 
pool: + with multiprocessing.Pool(len(estimators)) as pool: dr = DatasetReport(dataset.name) for d in dataset(): print(f"train prev.: {d.train_prev}") start = time.time() tasks = [(estim, d.train, d.validation, d.test) for estim in CE[estimators]] results = [pool.apply_async(fit_and_estimate, t) for t in tasks] - results = list(map(lambda r: r.get(), results)) + + results_got = [] + for _r in results: + try: + r = _r.get() + if r["result"] is not None: + results_got.append(r) + except Exception as e: + print(e) + er = EvaluationReport.combine_reports( - *list(map(lambda r: r["result"], results)), name=dataset.name + *[r["result"] for r in results_got], + name=dataset.name, + train_prev=d.train_prev, + valid_prev=d.validation_prev, ) - times = {r["name"]: r["time"] for r in results} + times = {r["name"]: r["time"] for r in results_got} end = time.time() times["tot"] = end - start er.times = times - er.train_prevs = d.prevs dr.add(er) print() diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py index e42f203..67f8878 100644 --- a/quacc/evaluation/method.py +++ b/quacc/evaluation/method.py @@ -1,3 +1,5 @@ +import numpy as np +import sklearn.metrics as metrics from quapy.data import LabelledCollection from quapy.protocol import ( AbstractStochasticSeededProtocol, @@ -22,15 +24,17 @@ def estimate( # ensure that the protocol returns a LabelledCollection for each iteration protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - base_prevs, true_prevs, estim_prevs = [], [], [] + base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], [] for sample in protocol(): - e_sample = estimator.extend(sample) + e_sample, pred_proba = estimator.extend(sample) estim_prev = estimator.estimate(e_sample.X, ext=True) base_prevs.append(sample.prevalence()) true_prevs.append(e_sample.prevalence()) estim_prevs.append(estim_prev) + pred_probas.append(pred_proba) + labels.append(sample.y) - return base_prevs, true_prevs, estim_prevs + return base_prevs, true_prevs, estim_prevs, pred_probas, labels def evaluation_report( @@ -38,16 +42,21 @@ def evaluation_report( protocol: AbstractStochasticSeededProtocol, method: str, ) -> EvaluationReport: - base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol) - report = EvaluationReport(prefix=method) + base_prevs, true_prevs, estim_prevs, pred_probas, labels = estimate( + estimator, protocol + ) + report = EvaluationReport(name=method) - for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs): + for base_prev, true_prev, estim_prev, pred_proba, label in zip( + base_prevs, true_prevs, estim_prevs, pred_probas, labels + ): + pred = np.argmax(pred_proba, axis=-1) acc_score = error.acc(estim_prev) f1_score = error.f1(estim_prev) report.append_row( base_prev, - acc_score=1.0 - acc_score, - acc=abs(error.acc(true_prev) - acc_score), + acc_score=acc_score, + acc=abs(metrics.accuracy_score(label, pred) - acc_score), f1_score=f1_score, f1=abs(error.f1(true_prev) - f1_score), ) @@ -60,13 +69,18 @@ def evaluate( validation: LabelledCollection, protocol: AbstractStochasticSeededProtocol, method: str, + q_model: str, + **kwargs, ): estimator: AccuracyEstimator = { "bin": BinaryQuantifierAccuracyEstimator, "mul": MulticlassAccuracyEstimator, - }[method](c_model) + }[method](c_model, q_model=q_model, **kwargs) estimator.fit(validation) - return evaluation_report(estimator, protocol, method) + _method = f"{method}_{q_model}" + for k, v in kwargs.items(): + _method += f"_{v}" + return 
evaluation_report(estimator, protocol, _method) def evaluate_bin_sld( @@ -74,7 +88,7 @@ def evaluate_bin_sld( validation: LabelledCollection, protocol: AbstractStochasticSeededProtocol, ) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin") + return evaluate(c_model, validation, protocol, "bin", "SLD") def evaluate_mul_sld( @@ -82,4 +96,84 @@ def evaluate_mul_sld( validation: LabelledCollection, protocol: AbstractStochasticSeededProtocol, ) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "mul") + return evaluate(c_model, validation, protocol, "mul", "SLD") + + +def evaluate_bin_sld_nbvs( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, +) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="nbvs") + + +def evaluate_mul_sld_nbvs( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, +) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="nbvs") + + +def evaluate_bin_sld_bcts( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, +) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="bcts") + + +def evaluate_mul_sld_bcts( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, +) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="bcts") + + +def evaluate_bin_sld_ts( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, +) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="ts") + + +def evaluate_mul_sld_ts( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, +) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="ts") + + +def evaluate_bin_sld_vs( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, +) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="vs") + + +def evaluate_mul_sld_vs( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, +) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="vs") + + +def evaluate_bin_cc( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, +) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "bin", "CC") + + +def evaluate_mul_cc( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, +) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "mul", "CC") diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index ff8862f..3d14203 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -1,22 +1,24 @@ -import statistics as stats +from pathlib import Path from typing import List, Tuple import numpy as np import pandas as pd from quacc import plot +from quacc.environ import env from quacc.utils import fmt_line_md class EvaluationReport: - def __init__(self, prefix=None): + def __init__(self, name=None): self._prevs = [] self._dict = {} self._g_prevs = None self._g_dict = None - self.name = prefix if prefix is 
not None else "default" + self.name = name if name is not None else "default" self.times = {} - self.train_prevs = {} + self.train_prev = None + self.valid_prev = None self.target = "default" def append_row(self, base: np.ndarray | Tuple, **row): @@ -34,23 +36,40 @@ class EvaluationReport: def columns(self): return self._dict.keys() - def groupby_prevs(self, metric: str = None): + def group_by_prevs(self, metric: str = None): if self._g_dict is None: self._g_prevs = [] self._g_dict = {k: [] for k in self._dict.keys()} - last_end = 0 - for ind, bp in enumerate(self._prevs): - if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]: - continue + for col, vals in self._dict.items(): + col_grouped = {} + for bp, v in zip(self._prevs, vals): + if bp not in col_grouped: + col_grouped[bp] = [] + col_grouped[bp].append(v) - self._g_prevs.append(bp) - for col in self._dict.keys(): - self._g_dict[col].append( - stats.mean(self._dict[col][last_end : ind + 1]) - ) + self._g_dict[col] = [ + vs + for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1]) + ] - last_end = ind + 1 + self._g_prevs = sorted( + [(p0, p1) for [p0, p1] in np.unique(self._prevs, axis=0).tolist()], + key=lambda bp: bp[1], + ) + + # last_end = 0 + # for ind, bp in enumerate(self._prevs): + # if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]: + # continue + + # self._g_prevs.append(bp) + # for col in self._dict.keys(): + # self._g_dict[col].append( + # stats.mean(self._dict[col][last_end : ind + 1]) + # ) + + # last_end = ind + 1 filtered_g_dict = self._g_dict if metric is not None: @@ -60,30 +79,83 @@ class EvaluationReport: return self._g_prevs, filtered_g_dict + def avg_by_prevs(self, metric: str = None): + g_prevs, g_dict = self.group_by_prevs(metric=metric) + + a_dict = {} + for col, vals in g_dict.items(): + a_dict[col] = [np.mean(vs) for vs in vals] + + return g_prevs, a_dict + + def avg_all(self, metric: str = None): + f_dict = self._dict + if metric is not None: + f_dict = {c1: ls for ((c0, c1), ls) in self._dict.items() if c0 == metric} + + a_dict = {} + for col, vals in f_dict.items(): + a_dict[col] = [np.mean(vals)] + + return a_dict + def get_dataframe(self, metric="acc"): - g_prevs, g_dict = self.groupby_prevs(metric=metric) + g_prevs, g_dict = self.avg_by_prevs(metric=metric) + a_dict = self.avg_all(metric=metric) + for col in g_dict.keys(): + g_dict[col].extend(a_dict[col]) return pd.DataFrame( g_dict, - index=g_prevs, + index=g_prevs + ["tot"], columns=g_dict.keys(), ) - def get_plot(self, mode="delta", metric="acc"): - g_prevs, g_dict = self.groupby_prevs(metric=metric) - t_prev = int(round(self.train_prevs["train"][0] * 100)) - title = f"{self.name}_{t_prev}_{metric}" - plot.plot_delta(g_prevs, g_dict, metric, title) + def get_plot(self, mode="delta", metric="acc") -> Path: + if mode == "delta": + g_prevs, g_dict = self.group_by_prevs(metric=metric) + return plot.plot_delta( + g_prevs, + g_dict, + metric=metric, + name=self.name, + train_prev=self.train_prev, + ) + elif mode == "diagonal": + _, g_dict = self.avg_by_prevs(metric=metric + "_score") + f_dict = {k: v for k, v in g_dict.items() if k != "ref"} + referece = g_dict["ref"] + return plot.plot_diagonal( + referece, + f_dict, + metric=metric, + name=self.name, + train_prev=self.train_prev, + ) + elif mode == "shift": + g_prevs, g_dict = self.avg_by_prevs(metric=metric) + return plot.plot_shift( + g_prevs, + g_dict, + metric=metric, + name=self.name, + train_prev=self.train_prev, + ) def to_md(self, *metrics): res = "" - for k, v 
in self.train_prevs.items(): - res += fmt_line_md(f"{k}: {str(v)}") + res += fmt_line_md(f"train: {str(self.train_prev)}") + res += fmt_line_md(f"validation: {str(self.valid_prev)}") for k, v in self.times.items(): res += fmt_line_md(f"{k}: {v:.3f}s") res += "\n" for m in metrics: res += self.get_dataframe(metric=m).to_html() + "\n\n" - self.get_plot(metric=m) + op_delta = self.get_plot(mode="delta", metric=m) + res += f"![plot_delta]({str(op_delta.relative_to(env.OUT_DIR))})\n" + op_diag = self.get_plot(mode="diagonal", metric=m) + res += f"![plot_diagonal]({str(op_diag.relative_to(env.OUT_DIR))})\n" + op_shift = self.get_plot(mode="shift", metric=m) + res += f"![plot_shift]({str(op_shift.relative_to(env.OUT_DIR))})\n" return res @@ -91,8 +163,9 @@ class EvaluationReport: if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)): raise ValueError("other has not same base prevalences of self") - if len(set(self._dict.keys()).intersection(set(other._dict.keys()))) > 0: - raise ValueError("self and other have matching keys") + inters_keys = set(self._dict.keys()).intersection(set(other._dict.keys())) + if len(inters_keys) > 0: + raise ValueError(f"self and other have matching keys {str(inters_keys)}.") report = EvaluationReport() report._prevs = self._prevs @@ -100,12 +173,14 @@ class EvaluationReport: return report @staticmethod - def combine_reports(*args, name="default"): + def combine_reports(*args, name="default", train_prev=None, valid_prev=None): er = args[0] for r in args[1:]: er = er.merge(r) er.name = name + er.train_prev = train_prev + er.valid_prev = valid_prev return er diff --git a/quacc/main.py b/quacc/main.py index c900a98..6c2cc4c 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -1,16 +1,39 @@ +import os +import shutil +from pathlib import Path + import quacc.evaluation.comp as comp from quacc.dataset import Dataset from quacc.environ import env +def create_out_dir(dir_name): + dir_path = Path(env.OUT_DIR_NAME) / dir_name + env.OUT_DIR = dir_path + shutil.rmtree(dir_path, ignore_errors=True) + os.mkdir(dir_path) + plot_dir_path = dir_path / "plot" + env.PLOT_OUT_DIR = plot_dir_path + os.mkdir(plot_dir_path) + + def estimate_comparison(): - dataset = Dataset( - env.DATASET_NAME, target=env.DATASET_TARGET, n_prevalences=env.DATASET_N_PREVS - ) - output_path = env.OUT_DIR / f"{dataset.name}.md" - with open(output_path, "w") as f: - dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS) - f.write(dr.to_md("acc")) + for conf in env: + create_out_dir(conf) + dataset = Dataset( + env.DATASET_NAME, + target=env.DATASET_TARGET, + n_prevalences=env.DATASET_N_PREVS, + ) + output_path = env.OUT_DIR / f"{dataset.name}.md" + try: + dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS) + for m in env.METRICS: + output_path = env.OUT_DIR / f"{conf}_{m}.md" + with open(output_path, "w") as f: + f.write(dr.to_md(m)) + except Exception as e: + print(f"Configuration {conf} failed. 
{e}") # print(df.to_latex(float_format="{:.4f}".format)) # print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format)) diff --git a/quacc/plot.py b/quacc/plot.py index 79977d7..93170f2 100644 --- a/quacc/plot.py +++ b/quacc/plot.py @@ -1,16 +1,191 @@ +from pathlib import Path + import matplotlib.pyplot as plt +import numpy as np from quacc.environ import env -def plot_delta(base_prevs, dict_vals, metric, title): - fig, ax = plt.subplots() +def _get_markers(n: int): + return [ + "o", + "v", + "x", + "+", + "s", + "D", + "p", + "h", + "*", + "^", + ][:n] - base_prevs = [f for f, p in base_prevs] + +def plot_delta( + base_prevs, + dict_vals, + *, + pos_class=1, + metric="acc", + name="default", + train_prev=None, + legend=True, +) -> Path: + if train_prev is not None: + t_prev_pos = int(round(train_prev[pos_class] * 100)) + title = f"delta_{name}_{t_prev_pos}_{metric}" + else: + title = f"delta_{name}_{metric}" + + fig, ax = plt.subplots() + ax.set_aspect("auto") + ax.grid() + + NUM_COLORS = len(dict_vals) + cm = plt.get_cmap("tab10") + if NUM_COLORS > 10: + cm = plt.get_cmap("tab20") + ax.set_prop_cycle( + color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)], + ) + + base_prevs = [bp[pos_class] for bp in base_prevs] for method, deltas in dict_vals.items(): + avg = np.array([np.mean(d, axis=-1) for d in deltas]) + # std = np.array([np.std(d, axis=-1) for d in deltas]) ax.plot( base_prevs, + avg, + label=method, + linestyle="-", + marker="o", + markersize=3, + zorder=2, + ) + # ax.fill_between(base_prevs, avg - std, avg + std, alpha=0.25) + + ax.set(xlabel="test prevalence", ylabel=metric, title=title) + + if legend: + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + output_path = env.PLOT_OUT_DIR / f"{title}.png" + fig.savefig(output_path, bbox_inches="tight") + + return output_path + + +def plot_diagonal( + reference, + dict_vals, + *, + pos_class=1, + metric="acc", + name="default", + train_prev=None, + legend=True, +): + if train_prev is not None: + t_prev_pos = int(round(train_prev[pos_class] * 100)) + title = f"diagonal_{name}_{t_prev_pos}_{metric}" + else: + title = f"diagonal_{name}_{metric}" + + fig, ax = plt.subplots() + ax.set_aspect("auto") + ax.grid() + + NUM_COLORS = len(dict_vals) + cm = plt.get_cmap("tab10") + ax.set_prop_cycle( + marker=_get_markers(NUM_COLORS) * 2, + color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)] * 2, + ) + + reference = np.array(reference) + x_ticks = np.unique(reference) + x_ticks.sort() + + for _, deltas in dict_vals.items(): + deltas = np.array(deltas) + ax.plot( + reference, deltas, + linestyle="None", + markersize=3, + zorder=2, + ) + + for method, deltas in dict_vals.items(): + deltas = np.array(deltas) + x_interp = x_ticks[[0, -1]] + y_interp = np.interp(x_interp, reference, deltas) + ax.plot( + x_interp, + y_interp, + label=method, + linestyle="-", + markersize="0", + zorder=1, + ) + + ax.set(xlabel="test prevalence", ylabel=metric, title=title) + + if legend: + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + output_path = env.PLOT_OUT_DIR / f"{title}.png" + fig.savefig(output_path, bbox_inches="tight") + return output_path + + +def plot_shift( + base_prevs, + dict_vals, + *, + pos_class=1, + metric="acc", + name="default", + train_prev=None, + legend=True, +) -> Path: + if train_prev is None: + raise AttributeError("train_prev cannot be None.") + + train_prev = train_prev[pos_class] + t_prev_pos = int(round(train_prev * 100)) + title = f"shift_{name}_{t_prev_pos}_{metric}" + + fig, ax = plt.subplots() 
+ ax.set_aspect("auto") + ax.grid() + + NUM_COLORS = len(dict_vals) + cm = plt.get_cmap("tab10") + if NUM_COLORS > 10: + cm = plt.get_cmap("tab20") + ax.set_prop_cycle( + color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)], + ) + + base_prevs = np.around( + [abs(bp[pos_class] - train_prev) for bp in base_prevs], decimals=2 + ) + for method, deltas in dict_vals.items(): + delta_bins = {} + for bp, delta in zip(base_prevs, deltas): + if bp not in delta_bins: + delta_bins[bp] = [] + delta_bins[bp].append(delta) + + bp_unique, delta_avg = zip( + *sorted( + {k: np.mean(v) for k, v in delta_bins.items()}.items(), + key=lambda db: db[0], + ) + ) + + ax.plot( + bp_unique, + delta_avg, label=method, linestyle="-", marker="o", @@ -19,8 +194,10 @@ def plot_delta(base_prevs, dict_vals, metric, title): ) ax.set(xlabel="test prevalence", ylabel=metric, title=title) - # ax.set_ylim(0, 1) - # ax.set_xlim(0, 1) - ax.legend() + + if legend: + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) output_path = env.PLOT_OUT_DIR / f"{title}.png" - plt.savefig(output_path) + fig.savefig(output_path, bbox_inches="tight") + + return output_path
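
The new conf.yaml, together with Environ.load_conf shown above, determines which comparisons are run: every entry under commons is merged with every entry under confs, the merged settings are stored under a name of the form <DATASET_NAME>[_<DATASET_TARGET>]_<conf_key>, and a non-empty top-level exec list restricts execution to the listed names. The sketch below reproduces that expansion in isolation; the YAML string is a trimmed, hypothetical excerpt of the real conf.yaml, used only to illustrate the naming scheme.

```python
import yaml

# trimmed, hypothetical excerpt of conf.yaml for illustration
conf_src = """
exec: []
commons:
  - DATASET_NAME: rcv1
    DATASET_TARGET: CCAT
    METRICS: [acc, f1]
    DATASET_N_PREVS: 9
  - DATASET_NAME: imdb
    METRICS: [acc, f1]
    DATASET_N_PREVS: 9
confs:
  all_mul_vs_atc:
    COMP_ESTIMATORS: [our_mul_SLD, our_mul_CC, ref, atc_mc, atc_ne]
  best_our_vs_all:
    COMP_ESTIMATORS: [our_bin_SLD, our_mul_SLD, ref, kfcv, atc_mc, atc_ne, doc_feat]
"""

confs = yaml.safe_load(conf_src)
exec_list, merged = [], {}
for common in confs["commons"]:
    # dataset name plus optional target becomes the configuration prefix
    name = common["DATASET_NAME"]
    if "DATASET_TARGET" in common:
        name += "_" + common["DATASET_TARGET"]
    for key, diff in confs["confs"].items():
        full_key = f"{name}_{key}"
        merged[full_key] = common | diff  # per-conf settings override the commons
        exec_list.append(full_key)

# a non-empty top-level `exec` restricts the run to the listed configurations
if confs.get("exec"):
    exec_list = confs["exec"]

print(exec_list)
# ['rcv1_CCAT_all_mul_vs_atc', 'rcv1_CCAT_best_our_vs_all',
#  'imdb_all_mul_vs_atc', 'imdb_best_our_vs_all']
```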