diff --git a/baselines/atc.py b/baselines/atc.py index 9e27706..744c284 100644 --- a/baselines/atc.py +++ b/baselines/atc.py @@ -1,41 +1,44 @@ -import numpy as np +import numpy as np from sklearn.metrics import f1_score -def get_entropy(probs): - return np.sum( np.multiply(probs, np.log(probs + 1e-20)) , axis=1) + +def get_entropy(probs): + return np.sum(np.multiply(probs, np.log(probs + 1e-20)), axis=1) + def get_max_conf(probs): - return np.max(probs, axis=-1) - -def find_ATC_threshold(scores, labels): + return np.max(probs, axis=-1) + + +def find_ATC_threshold(scores, labels): sorted_idx = np.argsort(scores) - + sorted_scores = scores[sorted_idx] sorted_labels = labels[sorted_idx] - - fp = np.sum(labels==0) + + fp = np.sum(labels == 0) fn = 0.0 - + min_fp_fn = np.abs(fp - fn) thres = 0.0 - for i in range(len(labels)): - if sorted_labels[i] == 0: + for i in range(len(labels)): + if sorted_labels[i] == 0: fp -= 1 - else: + else: fn += 1 - - if np.abs(fp - fn) < min_fp_fn: + + if np.abs(fp - fn) < min_fp_fn: min_fp_fn = np.abs(fp - fn) thres = sorted_scores[i] - + return min_fp_fn, thres -def get_ATC_acc(thres, scores): - return np.mean(scores>=thres) +def get_ATC_acc(thres, scores): + return np.mean(scores >= thres) + def get_ATC_f1(thres, scores, probs): preds = np.argmax(probs, axis=-1) - estim_y = abs(1 - (scores>=thres)^preds) + estim_y = np.abs(1 - (scores >= thres) ^ preds) return f1_score(estim_y, preds) - \ No newline at end of file diff --git a/conf.yaml b/conf.yaml index 6ec7356..ef29e4c 100644 --- a/conf.yaml +++ b/conf.yaml @@ -12,7 +12,30 @@ debug_conf: &debug_conf plot_confs: debug: PLOT_ESTIMATORS: - - mul_sld + - bin_sld_gs + PLOT_STDEV: true + +mc_conf: &mc_conf + global: + METRICS: + - acc + DATASET_N_PREVS: 9 + DATASET_PREVS: + - 0.4 + - 0.5 + - 0.6 + + confs: + - DATASET_NAME: rcv1 + DATASET_TARGET: CCAT + + plot_confs: + debug: + PLOT_ESTIMATORS: + - mulmc_sld + - mul_sld_gs + - bin_sld + - bin_sld_gs - atc_mc PLOT_STDEV: true @@ -29,13 +52,20 @@ test_conf: &test_conf # - DATASET_NAME: imdb plot_confs: - 2gs_vs_atc: + gs_vs_gsq: PLOT_ESTIMATORS: + - bin_sld - bin_sld_gs - - bin_sld_qgs + - bin_sld_gsq + - mul_sld + - mul_sld_gs + - mul_sld_gsq + gs_vs_atc: + PLOT_ESTIMATORS: + - bin_sld + - bin_sld_gs + - mul_sld - mul_sld_gs - - mul_sld_qgs - - ref - atc_mc - atc_ne sld_vs_pacc: @@ -44,11 +74,23 @@ test_conf: &test_conf - bin_sld_gs - mul_sld - mul_sld_gs - - ref + - bin_pacc + - bin_pacc_gs + - mul_pacc + - mul_pacc_gs + - atc_mc + - atc_ne + pacc_vs_atc: + PLOT_ESTIMATORS: + - bin_pacc + - bin_pacc_gs + - mul_pacc + - mul_pacc_gs - atc_mc - atc_ne main_conf: &main_conf + global: METRICS: - acc @@ -106,4 +148,4 @@ main_conf: &main_conf - atc_ne - doc_feat -exec: *debug_conf \ No newline at end of file +exec: *mc_conf \ No newline at end of file diff --git a/quacc.log b/quacc.log index e229551..be89552 100644 --- a/quacc.log +++ b/quacc.log @@ -1636,3 +1636,641 @@ 04/11/23 00:05:20| INFO atc_mc finished [took 26.4278s] 04/11/23 00:05:29| INFO mul_sld finished [took 35.3110s] 04/11/23 00:05:29| INFO Dataset sample 0.50 of dataset imdb_1prevs finished [took 36.4422s] +---------------------------------------------------------------------------------------------------- +04/11/23 00:19:43| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:19:49| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 00:19:53| WARNING Method bin_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input. +04/11/23 00:19:55| WARNING Method mul_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input. +04/11/23 00:19:57| WARNING Method bin_pacc failed. Exception: PACC.__init__() got an unexpected keyword argument 'recalib' +04/11/23 00:19:59| WARNING Method mul_pacc failed. Exception: PACC.__init__() got an unexpected keyword argument 'recalib' +04/11/23 00:20:00| WARNING Method bin_pacc_gs failed. Exception: Invalid parameter 'recalib' for estimator PACC(classifier=LogisticRegression(), n_jobs=1). Valid parameters are: ['classifier', 'n_jobs', 'val_split']. +04/11/23 00:20:01| WARNING Method mul_pacc_gs failed. Exception: Invalid parameter 'recalib' for estimator PACC(classifier=LogisticRegression(), n_jobs=1). Valid parameters are: ['classifier', 'n_jobs', 'val_split']. +---------------------------------------------------------------------------------------------------- +04/11/23 00:22:45| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:22:50| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 00:22:54| WARNING Method bin_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input. +04/11/23 00:22:55| WARNING Method mul_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input. +---------------------------------------------------------------------------------------------------- +04/11/23 00:28:11| INFO dataset rcv1_CCAT_9prevs +---------------------------------------------------------------------------------------------------- +04/11/23 00:29:39| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:29:45| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 00:29:49| WARNING Method bin_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input. +04/11/23 00:29:51| WARNING Method mul_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input. +04/11/23 00:30:39| WARNING Method mul_pacc_gs failed. Exception: evaluation_report() got an unexpected keyword argument 'method_name' +04/11/23 00:31:00| INFO ref finished [took 60.5788s] +04/11/23 00:31:09| INFO atc_mc finished [took 64.6156s] +04/11/23 00:31:09| INFO mul_pacc finished [took 75.1821s] +04/11/23 00:31:12| INFO atc_ne finished [took 62.8665s] +04/11/23 00:31:24| INFO mul_sld finished [took 96.8624s] +---------------------------------------------------------------------------------------------------- +04/11/23 00:33:26| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:33:31| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 00:33:35| WARNING Method bin_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input. +04/11/23 00:33:37| WARNING Method mul_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input. +---------------------------------------------------------------------------------------------------- +04/11/23 00:38:42| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:38:48| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 00:38:51| WARNING Method bin_sld_qgs failed. Exception: ExtendedCollection.extend_collection() missing 1 required positional argument: 'pred_proba' +04/11/23 00:38:52| WARNING Method mul_sld_qgs failed. Exception: ExtendedCollection.extend_collection() missing 1 required positional argument: 'pred_proba' +04/11/23 00:39:41| WARNING Method mul_pacc_gs failed. Exception: evaluation_report() got an unexpected keyword argument 'method_name' +---------------------------------------------------------------------------------------------------- +04/11/23 00:46:33| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:46:39| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 00:46:40| WARNING Method bin_sld failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:46:41| WARNING Method mul_sld failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:46:42| WARNING Method bin_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba' +04/11/23 00:46:43| WARNING Method mul_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba' +04/11/23 00:46:44| WARNING Method bin_sld_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:46:45| WARNING Method mul_sld_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:46:46| WARNING Method bin_pacc failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:46:47| WARNING Method mul_pacc failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:46:47| WARNING Method bin_pacc_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:46:48| WARNING Method mul_pacc_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:47:27| INFO ref finished [took 37.5294s] +04/11/23 00:47:31| INFO atc_mc finished [took 40.5777s] +04/11/23 00:47:32| INFO atc_ne finished [took 40.7565s] +04/11/23 00:47:32| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 52.8106s] +04/11/23 00:47:32| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started +04/11/23 00:47:33| WARNING Method bin_sld failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:47:34| WARNING Method mul_sld failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:47:35| WARNING Method bin_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba' +04/11/23 00:47:36| WARNING Method mul_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba' +04/11/23 00:47:37| WARNING Method bin_sld_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:47:38| WARNING Method mul_sld_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:47:39| WARNING Method bin_pacc failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:47:39| WARNING Method mul_pacc failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:47:40| WARNING Method bin_pacc_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +04/11/23 00:47:41| WARNING Method mul_pacc_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() +---------------------------------------------------------------------------------------------------- +04/11/23 00:48:05| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:48:10| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 00:48:13| WARNING Method bin_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba' +04/11/23 00:48:14| WARNING Method mul_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba' +---------------------------------------------------------------------------------------------------- +04/11/23 00:49:18| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:49:24| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 00:49:27| WARNING Method bin_sld_qgs failed. Exception: GridSearchQ.__init__() missing 1 required positional argument: 'model' +04/11/23 00:49:28| WARNING Method mul_sld_qgs failed. Exception: X has 47238 features, but LogisticRegression is expecting 47236 features as input. +---------------------------------------------------------------------------------------------------- +04/11/23 00:51:27| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:51:32| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 00:51:36| WARNING Method bin_sld_qgs failed. Exception: X has 47238 features, but LogisticRegression is expecting 47236 features as input. +04/11/23 00:51:37| WARNING Method mul_sld_qgs failed. Exception: X has 47238 features, but LogisticRegression is expecting 47236 features as input. +---------------------------------------------------------------------------------------------------- +04/11/23 00:54:47| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:54:53| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 00:54:57| WARNING Method bin_sld_qgs failed. Exception: a must be greater than 0 unless no samples are taken +04/11/23 00:54:58| WARNING Method mul_sld_qgs failed. Exception: a must be greater than 0 unless no samples are taken +---------------------------------------------------------------------------------------------------- +04/11/23 00:58:47| INFO dataset rcv1_CCAT_9prevs +04/11/23 00:58:52| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 01:00:04| INFO ref finished [took 61.6328s] +04/11/23 01:00:11| INFO atc_mc finished [took 65.4916s] +04/11/23 01:00:13| INFO atc_ne finished [took 63.2288s] +04/11/23 01:00:14| INFO mul_pacc finished [took 75.5101s] +04/11/23 01:00:30| INFO mul_sld finished [took 96.6656s] +04/11/23 01:00:41| INFO mul_pacc_gs finished [took 99.7211s] +04/11/23 01:03:02| INFO bin_pacc finished [took 244.6260s] +04/11/23 01:03:07| INFO bin_sld finished [took 254.3478s] +04/11/23 01:04:51| INFO mul_sld_gs finished [took 354.7477s] +04/11/23 01:05:02| INFO bin_pacc_gs finished [took 362.1808s] +04/11/23 01:09:24| INFO bin_sld_gs finished [took 628.6714s] +04/11/23 01:09:24| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 631.8421s] +04/11/23 01:09:24| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started +04/11/23 01:10:39| INFO ref finished [took 63.5158s] +04/11/23 01:10:44| INFO atc_mc finished [took 66.4279s] +04/11/23 01:10:46| INFO mul_pacc finished [took 75.3281s] +04/11/23 01:10:47| INFO atc_ne finished [took 67.5374s] +04/11/23 01:10:52| INFO mul_sld finished [took 86.6592s] +04/11/23 01:11:19| INFO mul_pacc_gs finished [took 104.6374s] +04/11/23 01:13:58| INFO bin_sld finished [took 273.4932s] +04/11/23 01:14:01| INFO bin_pacc finished [took 271.3481s] +04/11/23 01:15:42| INFO mul_sld_gs finished [took 374.2416s] +04/11/23 01:16:01| INFO bin_pacc_gs finished [took 388.0839s] +04/11/23 01:20:29| INFO bin_sld_gs finished [took 661.9729s] +04/11/23 01:20:29| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 665.2874s] +04/11/23 01:20:29| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started +04/11/23 01:21:46| INFO ref finished [took 63.8544s] +04/11/23 01:21:50| INFO atc_mc finished [took 66.6917s] +04/11/23 01:21:52| INFO atc_ne finished [took 65.0860s] +04/11/23 01:21:53| INFO mul_pacc finished [took 77.2630s] +04/11/23 01:21:55| INFO mul_sld finished [took 83.3146s] +04/11/23 01:22:23| INFO mul_pacc_gs finished [took 102.3761s] +04/11/23 01:24:47| INFO bin_pacc finished [took 252.0964s] +04/11/23 01:24:49| INFO bin_sld finished [took 258.6998s] +04/11/23 01:26:37| INFO mul_sld_gs finished [took 363.7500s] +04/11/23 01:26:49| INFO bin_pacc_gs finished [took 370.5817s] +04/11/23 01:31:27| INFO bin_sld_gs finished [took 654.3921s] +04/11/23 01:31:27| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 658.0041s] +04/11/23 01:31:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started +04/11/23 01:32:33| INFO ref finished [took 55.7749s] +04/11/23 01:32:38| INFO atc_mc finished [took 59.4190s] +04/11/23 01:32:40| INFO atc_ne finished [took 59.5155s] +04/11/23 01:32:42| INFO mul_pacc finished [took 68.8994s] +04/11/23 01:32:44| INFO mul_sld finished [took 74.6470s] +04/11/23 01:33:09| INFO mul_pacc_gs finished [took 92.6473s] +04/11/23 01:35:32| INFO bin_pacc finished [took 239.7541s] +04/11/23 01:35:34| INFO bin_sld finished [took 245.7504s] +04/11/23 01:37:19| INFO mul_sld_gs finished [took 348.1188s] +04/11/23 01:37:30| INFO bin_pacc_gs finished [took 355.4729s] +04/11/23 01:42:07| INFO bin_sld_gs finished [took 636.8598s] +04/11/23 01:42:07| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 639.9201s] +04/11/23 01:42:07| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started +04/11/23 01:43:14| INFO ref finished [took 56.1531s] +04/11/23 01:43:19| INFO atc_mc finished [took 59.7473s] +04/11/23 01:43:20| INFO atc_ne finished [took 59.0606s] +04/11/23 01:43:23| INFO mul_pacc finished [took 69.4266s] +04/11/23 01:43:25| INFO mul_sld finished [took 76.3328s] +04/11/23 01:43:49| INFO mul_pacc_gs finished [took 92.3926s] +04/11/23 01:46:05| INFO bin_pacc finished [took 233.1877s] +04/11/23 01:46:08| INFO bin_sld finished [took 239.8757s] +04/11/23 01:47:51| INFO mul_sld_gs finished [took 339.5911s] +04/11/23 01:48:00| INFO bin_pacc_gs finished [took 345.7788s] +04/11/23 01:52:44| INFO bin_sld_gs finished [took 633.8407s] +04/11/23 01:52:44| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 637.0648s] +04/11/23 01:52:44| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started +04/11/23 01:53:52| INFO ref finished [took 57.4958s] +04/11/23 01:53:57| INFO atc_mc finished [took 60.9998s] +04/11/23 01:53:58| INFO atc_ne finished [took 60.4847s] +04/11/23 01:54:01| INFO mul_pacc finished [took 70.5216s] +04/11/23 01:54:04| INFO mul_sld finished [took 78.2910s] +04/11/23 01:54:27| INFO mul_pacc_gs finished [took 94.4726s] +04/11/23 01:56:48| INFO bin_pacc finished [took 238.5969s] +04/11/23 01:56:50| INFO bin_sld finished [took 244.5679s] +04/11/23 01:58:31| INFO mul_sld_gs finished [took 342.4843s] +04/11/23 01:58:44| INFO bin_pacc_gs finished [took 352.8264s] +04/11/23 02:03:32| INFO bin_sld_gs finished [took 644.7046s] +04/11/23 02:03:32| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 647.8055s] +04/11/23 02:03:32| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started +04/11/23 02:04:37| INFO ref finished [took 55.4488s] +04/11/23 02:04:42| INFO atc_mc finished [took 59.2634s] +04/11/23 02:04:44| INFO atc_ne finished [took 59.1371s] +04/11/23 02:04:46| INFO mul_pacc finished [took 68.0960s] +04/11/23 02:04:50| INFO mul_sld finished [took 76.4282s] +04/11/23 02:05:12| INFO mul_pacc_gs finished [took 91.7735s] +04/11/23 02:07:30| INFO bin_pacc finished [took 232.7650s] +04/11/23 02:07:36| INFO bin_sld finished [took 242.4077s] +04/11/23 02:09:14| INFO mul_sld_gs finished [took 338.1418s] +04/11/23 02:09:26| INFO bin_pacc_gs finished [took 347.2033s] +04/11/23 02:13:59| INFO bin_sld_gs finished [took 624.6098s] +04/11/23 02:13:59| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 627.7979s] +04/11/23 02:13:59| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started +04/11/23 02:15:05| INFO ref finished [took 55.1962s] +04/11/23 02:15:10| INFO atc_mc finished [took 59.0907s] +04/11/23 02:15:11| INFO atc_ne finished [took 59.1531s] +04/11/23 02:15:13| INFO mul_pacc finished [took 67.6705s] +04/11/23 02:15:17| INFO mul_sld finished [took 75.4559s] +04/11/23 02:15:41| INFO mul_pacc_gs finished [took 92.4901s] +04/11/23 02:17:59| INFO bin_pacc finished [took 233.8600s] +04/11/23 02:18:04| INFO bin_sld finished [took 243.2382s] +04/11/23 02:19:40| INFO mul_sld_gs finished [took 336.0961s] +04/11/23 02:19:51| INFO bin_pacc_gs finished [took 344.4075s] +04/11/23 02:24:30| INFO bin_sld_gs finished [took 627.6209s] +04/11/23 02:24:30| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 630.8251s] +04/11/23 02:24:30| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started +04/11/23 02:25:35| INFO ref finished [took 54.8513s] +04/11/23 02:25:40| INFO atc_mc finished [took 58.8528s] +04/11/23 02:25:41| INFO atc_ne finished [took 58.6035s] +04/11/23 02:25:43| INFO mul_pacc finished [took 66.9030s] +04/11/23 02:25:57| INFO mul_sld finished [took 84.2072s] +04/11/23 02:26:10| INFO mul_pacc_gs finished [took 91.0973s] +04/11/23 02:28:31| INFO bin_pacc finished [took 235.7331s] +04/11/23 02:28:35| INFO bin_sld finished [took 243.6260s] +04/11/23 02:30:09| INFO mul_sld_gs finished [took 334.4842s] +04/11/23 02:30:22| INFO bin_pacc_gs finished [took 344.6874s] +04/11/23 02:34:46| INFO bin_sld_gs finished [took 612.1219s] +04/11/23 02:34:46| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 615.2004s] +---------------------------------------------------------------------------------------------------- +04/11/23 02:57:35| INFO dataset rcv1_CCAT_9prevs +04/11/23 02:57:39| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 02:57:47| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid' +04/11/23 02:58:59| INFO ref finished [took 64.5948s] +04/11/23 02:59:06| INFO atc_mc finished [took 69.5808s] +04/11/23 02:59:12| INFO mul_pacc finished [took 82.8518s] +04/11/23 02:59:13| INFO atc_ne finished [took 72.1303s] +04/11/23 02:59:26| INFO mul_sld finished [took 103.4201s] +04/11/23 02:59:30| WARNING Method bin_sld_gsq failed. Exception: This solver needs samples of at least 2 classes in the data, but the data contains only one class: 1 +04/11/23 02:59:41| INFO mul_pacc_gs finished [took 109.4672s] +04/11/23 03:01:59| INFO bin_pacc finished [took 251.3945s] +04/11/23 03:02:02| INFO bin_sld finished [took 260.0226s] +04/11/23 03:03:35| INFO mul_sld_gs finished [took 350.1705s] +04/11/23 03:03:48| INFO bin_pacc_gs finished [took 357.9668s] +04/11/23 03:07:59| INFO bin_sld_gs finished [took 615.8087s] +04/11/23 03:07:59| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 620.4985s] +04/11/23 03:07:59| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started +04/11/23 03:08:06| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid' +04/11/23 03:09:17| INFO ref finished [took 64.4692s] +04/11/23 03:09:25| INFO atc_mc finished [took 71.3766s] +04/11/23 03:09:27| INFO atc_ne finished [took 71.0947s] +04/11/23 03:09:28| INFO mul_pacc finished [took 80.0201s] +04/11/23 03:09:31| INFO mul_sld finished [took 89.4295s] +04/11/23 03:09:55| INFO mul_pacc_gs finished [took 104.7292s] +04/11/23 03:12:25| INFO bin_sld finished [took 263.6824s] +04/11/23 03:12:25| INFO bin_pacc finished [took 258.6502s] +04/11/23 03:14:01| INFO mul_sld_gs finished [took 357.3344s] +04/11/23 03:14:14| INFO bin_sld_gsq finished [took 369.1636s] +04/11/23 03:14:22| INFO bin_pacc_gs finished [took 372.8646s] +04/11/23 03:18:40| INFO bin_sld_gs finished [took 636.9190s] +04/11/23 03:18:40| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 640.2322s] +04/11/23 03:18:40| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started +04/11/23 03:18:46| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid' +04/11/23 03:19:58| INFO ref finished [took 65.9462s] +04/11/23 03:20:02| INFO atc_mc finished [took 68.5710s] +04/11/23 03:20:04| INFO atc_ne finished [took 68.9466s] +04/11/23 03:20:06| INFO mul_pacc finished [took 77.9039s] +04/11/23 03:20:06| INFO mul_sld finished [took 84.0917s] +04/11/23 03:20:37| INFO mul_pacc_gs finished [took 106.2536s] +04/11/23 03:23:04| INFO bin_pacc finished [took 257.4211s] +04/11/23 03:23:05| INFO bin_sld finished [took 264.3442s] +04/11/23 03:24:49| INFO mul_sld_gs finished [took 365.1691s] +04/11/23 03:25:01| INFO bin_pacc_gs finished [took 371.9184s] +04/11/23 03:25:02| INFO bin_sld_gsq finished [took 377.0442s] +04/11/23 03:29:37| INFO bin_sld_gs finished [took 654.0366s] +04/11/23 03:29:37| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 657.0840s] +04/11/23 03:29:37| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started +04/11/23 03:29:42| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid' +04/11/23 03:30:51| INFO ref finished [took 62.7217s] +04/11/23 03:30:58| INFO atc_mc finished [took 67.8613s] +04/11/23 03:31:00| INFO atc_ne finished [took 68.5026s] +04/11/23 03:31:03| INFO mul_sld finished [took 83.8857s] +04/11/23 03:31:03| INFO mul_pacc finished [took 78.6340s] +04/11/23 03:31:30| INFO mul_pacc_gs finished [took 103.4683s] +04/11/23 03:34:00| INFO bin_sld finished [took 262.4457s] +04/11/23 03:34:02| INFO bin_pacc finished [took 258.2247s] +04/11/23 03:35:44| INFO mul_sld_gs finished [took 363.8135s] +04/11/23 03:35:58| INFO bin_pacc_gs finished [took 372.0485s] +04/11/23 03:36:05| INFO bin_sld_gsq finished [took 382.9585s] +04/11/23 03:40:39| INFO bin_sld_gs finished [took 659.6222s] +04/11/23 03:40:39| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 662.5763s] +04/11/23 03:40:39| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started +04/11/23 03:40:45| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid' +04/11/23 03:41:56| INFO ref finished [took 64.5923s] +04/11/23 03:42:01| INFO atc_mc finished [took 68.0148s] +04/11/23 03:42:03| INFO atc_ne finished [took 68.3119s] +04/11/23 03:42:04| INFO mul_pacc finished [took 76.9397s] +04/11/23 03:42:07| INFO mul_sld finished [took 85.5363s] +04/11/23 03:42:34| INFO mul_pacc_gs finished [took 103.4448s] +04/11/23 03:45:01| INFO bin_sld finished [took 260.0814s] +04/11/23 03:45:03| INFO bin_pacc finished [took 256.9386s] +04/11/23 03:46:45| INFO mul_sld_gs finished [took 361.5910s] +04/11/23 03:47:01| INFO bin_pacc_gs finished [took 371.9657s] +04/11/23 03:47:13| INFO bin_sld_gsq finished [took 388.2498s] +04/11/23 03:51:40| INFO bin_sld_gs finished [took 657.4008s] +04/11/23 03:51:40| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 660.5115s] +04/11/23 03:51:40| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started +04/11/23 03:51:46| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid' +04/11/23 03:52:54| INFO ref finished [took 61.9225s] +04/11/23 03:53:00| INFO atc_mc finished [took 66.3156s] +04/11/23 03:53:02| INFO atc_ne finished [took 66.5025s] +04/11/23 03:53:04| INFO mul_pacc finished [took 75.8808s] +04/11/23 03:53:06| INFO mul_sld finished [took 84.3204s] +04/11/23 03:53:33| INFO mul_pacc_gs finished [took 102.5763s] +04/11/23 03:56:04| INFO bin_sld finished [took 263.2781s] +04/11/23 03:56:04| INFO bin_pacc finished [took 257.7298s] +04/11/23 03:57:44| INFO mul_sld_gs finished [took 359.7910s] +04/11/23 03:58:00| INFO bin_pacc_gs finished [took 371.3848s] +04/11/23 03:58:11| INFO bin_sld_gsq finished [took 386.0904s] +04/11/23 04:02:50| INFO bin_sld_gs finished [took 667.6623s] +04/11/23 04:02:50| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 670.7255s] +04/11/23 04:02:50| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started +04/11/23 04:02:57| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid' +04/11/23 04:04:05| INFO ref finished [took 62.3256s] +04/11/23 04:04:13| INFO atc_mc finished [took 68.9525s] +04/11/23 04:04:15| INFO atc_ne finished [took 68.8750s] +04/11/23 04:04:16| INFO mul_pacc finished [took 77.5049s] +04/11/23 04:04:19| INFO mul_sld finished [took 86.0694s] +04/11/23 04:04:45| INFO mul_pacc_gs finished [took 103.3513s] +04/11/23 04:07:15| INFO bin_pacc finished [took 257.6456s] +04/11/23 04:07:16| INFO bin_sld finished [took 263.9914s] +04/11/23 04:08:55| INFO mul_sld_gs finished [took 360.5634s] +04/11/23 04:09:12| INFO bin_pacc_gs finished [took 372.2665s] +04/11/23 04:09:18| INFO bin_sld_gsq finished [took 381.8311s] +04/11/23 04:13:39| INFO bin_sld_gs finished [took 645.3599s] +04/11/23 04:13:39| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 648.5328s] +04/11/23 04:13:39| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started +04/11/23 04:13:45| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid' +04/11/23 04:14:51| INFO ref finished [took 59.8110s] +04/11/23 04:14:58| INFO atc_mc finished [took 65.2666s] +04/11/23 04:14:59| INFO atc_ne finished [took 64.5173s] +04/11/23 04:15:01| INFO mul_pacc finished [took 73.8332s] +04/11/23 04:15:04| INFO mul_sld finished [took 82.3509s] +04/11/23 04:15:29| INFO mul_pacc_gs finished [took 99.3541s] +04/11/23 04:18:00| INFO bin_pacc finished [took 254.3308s] +04/11/23 04:18:03| INFO bin_sld finished [took 262.3008s] +04/11/23 04:19:40| INFO mul_sld_gs finished [took 357.1229s] +04/11/23 04:19:57| INFO bin_pacc_gs finished [took 368.4516s] +04/11/23 04:20:03| INFO bin_sld_gsq finished [took 378.7658s] +04/11/23 04:24:37| INFO bin_sld_gs finished [took 655.1931s] +04/11/23 04:24:37| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 658.3505s] +04/11/23 04:24:37| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started +04/11/23 04:24:43| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid' +04/11/23 04:25:49| INFO ref finished [took 59.4546s] +04/11/23 04:25:55| INFO atc_mc finished [took 63.5805s] +04/11/23 04:25:58| INFO atc_ne finished [took 63.2985s] +04/11/23 04:25:58| INFO mul_pacc finished [took 72.5198s] +04/11/23 04:26:11| INFO mul_sld finished [took 91.7136s] +04/11/23 04:26:27| INFO mul_pacc_gs finished [took 98.8722s] +04/11/23 04:28:57| INFO bin_pacc finished [took 252.8144s] +04/11/23 04:29:02| INFO bin_sld finished [took 263.8013s] +04/11/23 04:30:35| INFO mul_sld_gs finished [took 353.3693s] +04/11/23 04:30:51| INFO bin_sld_gsq finished [took 368.8564s] +04/11/23 04:30:54| INFO bin_pacc_gs finished [took 367.5592s] +04/11/23 04:35:11| INFO bin_sld_gs finished [took 630.6700s] +04/11/23 04:35:11| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 633.7494s] +---------------------------------------------------------------------------------------------------- +04/11/23 19:09:42| INFO dataset rcv1_CCAT_9prevs +04/11/23 19:09:47| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 19:10:28| INFO ref finished [took 36.0351s] +04/11/23 19:10:32| INFO atc_mc finished [took 38.9507s] +04/11/23 19:10:35| INFO mulmc_sld finished [took 43.7869s] +04/11/23 19:10:50| INFO mul_sld finished [took 60.8007s] +04/11/23 19:10:50| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 62.9600s] +04/11/23 19:10:50| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started +04/11/23 19:11:29| INFO ref finished [took 36.3632s] +04/11/23 19:11:34| INFO atc_mc finished [took 39.5928s] +04/11/23 19:11:36| INFO mulmc_sld finished [took 44.2915s] +04/11/23 19:11:44| INFO mul_sld finished [took 52.6727s] +04/11/23 19:11:44| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 54.0362s] +04/11/23 19:11:44| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started +04/11/23 19:12:24| INFO ref finished [took 36.4303s] +04/11/23 19:12:27| INFO atc_mc finished [took 39.2329s] +04/11/23 19:12:30| INFO mulmc_sld finished [took 43.6247s] +04/11/23 19:12:36| INFO mul_sld finished [took 50.2041s] +04/11/23 19:12:36| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 51.6412s] +04/11/23 19:12:36| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started +04/11/23 19:13:16| INFO ref finished [took 36.7551s] +04/11/23 19:13:19| INFO atc_mc finished [took 39.2806s] +04/11/23 19:13:21| INFO mulmc_sld finished [took 43.6120s] +04/11/23 19:13:27| INFO mul_sld finished [took 50.4446s] +04/11/23 19:13:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 51.6672s] +04/11/23 19:13:27| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started +04/11/23 19:14:07| INFO ref finished [took 35.8789s] +04/11/23 19:14:11| INFO atc_mc finished [took 39.2168s] +04/11/23 19:14:13| INFO mulmc_sld finished [took 43.4580s] +04/11/23 19:14:20| INFO mul_sld finished [took 51.2902s] +04/11/23 19:14:20| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 52.6303s] +04/11/23 19:14:20| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started +04/11/23 19:15:00| INFO ref finished [took 36.3735s] +04/11/23 19:15:04| INFO atc_mc finished [took 39.7035s] +04/11/23 19:15:06| INFO mulmc_sld finished [took 43.6364s] +04/11/23 19:15:13| INFO mul_sld finished [took 52.0138s] +04/11/23 19:15:13| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 53.3303s] +04/11/23 19:15:13| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started +04/11/23 19:15:54| INFO ref finished [took 37.3366s] +04/11/23 19:15:57| INFO atc_mc finished [took 39.8921s] +04/11/23 19:16:00| INFO mulmc_sld finished [took 44.5159s] +04/11/23 19:16:08| INFO mul_sld finished [took 53.0806s] +04/11/23 19:16:08| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 54.4117s] +04/11/23 19:16:08| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started +04/11/23 19:16:47| INFO ref finished [took 35.7800s] +04/11/23 19:16:50| INFO atc_mc finished [took 38.4484s] +04/11/23 19:16:53| INFO mulmc_sld finished [took 42.7405s] +04/11/23 19:17:01| INFO mul_sld finished [took 51.5556s] +04/11/23 19:17:01| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 52.9684s] +04/11/23 19:17:01| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started +04/11/23 19:17:39| INFO ref finished [took 35.0919s] +04/11/23 19:17:43| INFO atc_mc finished [took 38.1718s] +04/11/23 19:17:45| INFO mulmc_sld finished [took 42.4413s] +04/11/23 19:17:59| INFO mul_sld finished [took 57.0766s] +04/11/23 19:17:59| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 58.3668s] +---------------------------------------------------------------------------------------------------- +04/11/23 19:42:38| INFO dataset rcv1_CCAT_9prevs +04/11/23 19:42:43| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 19:43:27| INFO ref finished [took 38.7664s] +04/11/23 19:43:31| INFO atc_mc finished [took 42.4000s] +04/11/23 19:43:33| INFO mulmc_sld finished [took 47.0913s] +04/11/23 19:43:34| INFO binmc_sld finished [took 47.1675s] +04/11/23 19:43:49| INFO mul_sld finished [took 64.1382s] +04/11/23 19:46:00| INFO bin_sld finished [took 195.9822s] +04/11/23 19:46:00| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 197.2916s] +04/11/23 19:46:00| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started +04/11/23 19:46:44| INFO ref finished [took 38.5976s] +04/11/23 19:46:48| INFO atc_mc finished [took 41.9465s] +04/11/23 19:46:49| INFO mulmc_sld finished [took 46.2205s] +04/11/23 19:46:51| INFO binmc_sld finished [took 46.7475s] +04/11/23 19:46:58| INFO mul_sld finished [took 56.3552s] +04/11/23 19:49:14| INFO bin_sld finished [took 193.2923s] +04/11/23 19:49:14| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 194.6251s] +04/11/23 19:49:14| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started +04/11/23 19:49:58| INFO ref finished [took 38.3754s] +04/11/23 19:50:02| INFO atc_mc finished [took 41.0091s] +04/11/23 19:50:03| INFO mulmc_sld finished [took 45.6205s] +04/11/23 19:50:05| INFO binmc_sld finished [took 46.1852s] +04/11/23 19:50:10| INFO mul_sld finished [took 52.9704s] +04/11/23 19:52:27| INFO bin_sld finished [took 190.6101s] +04/11/23 19:52:27| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 192.0378s] +04/11/23 19:52:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started +04/11/23 19:53:10| INFO ref finished [took 38.4467s] +04/11/23 19:53:13| INFO atc_mc finished [took 41.2602s] +04/11/23 19:53:15| INFO mulmc_sld finished [took 45.7496s] +04/11/23 19:53:16| INFO binmc_sld finished [took 45.5531s] +04/11/23 19:53:21| INFO mul_sld finished [took 52.5067s] +04/11/23 19:55:38| INFO bin_sld finished [took 190.7744s] +04/11/23 19:55:38| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 191.9715s] +04/11/23 19:55:39| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started +04/11/23 19:56:21| INFO ref finished [took 37.9420s] +04/11/23 19:56:26| INFO atc_mc finished [took 41.2056s] +04/11/23 19:56:27| INFO mulmc_sld finished [took 45.7577s] +04/11/23 19:56:28| INFO binmc_sld finished [took 45.6411s] +04/11/23 19:56:34| INFO mul_sld finished [took 53.5219s] +04/11/23 19:58:51| INFO bin_sld finished [took 191.1772s] +04/11/23 19:58:51| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 192.4566s] +04/11/23 19:58:51| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started +04/11/23 19:59:34| INFO ref finished [took 37.8604s] +04/11/23 19:59:38| INFO atc_mc finished [took 41.0334s] +04/11/23 19:59:39| INFO mulmc_sld finished [took 45.1999s] +04/11/23 19:59:40| INFO binmc_sld finished [took 45.4846s] +04/11/23 19:59:47| INFO mul_sld finished [took 54.3166s] +04/11/23 20:02:04| INFO bin_sld finished [took 191.4002s] +04/11/23 20:02:04| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 192.6275s] +04/11/23 20:02:04| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started +04/11/23 20:02:48| INFO ref finished [took 38.8313s] +04/11/23 20:02:52| INFO atc_mc finished [took 42.1162s] +04/11/23 20:02:54| INFO mulmc_sld finished [took 47.0413s] +04/11/23 20:02:55| INFO binmc_sld finished [took 46.8891s] +04/11/23 20:03:02| INFO mul_sld finished [took 55.8821s] +04/11/23 20:05:19| INFO bin_sld finished [took 193.7571s] +04/11/23 20:05:19| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 195.2404s] +04/11/23 20:05:19| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started +04/11/23 20:06:03| INFO ref finished [took 38.7982s] +04/11/23 20:06:06| INFO atc_mc finished [took 41.6213s] +04/11/23 20:06:08| INFO mulmc_sld finished [took 46.2646s] +04/11/23 20:06:09| INFO binmc_sld finished [took 46.2453s] +04/11/23 20:06:16| INFO mul_sld finished [took 54.8621s] +04/11/23 20:08:35| INFO bin_sld finished [took 194.5226s] +04/11/23 20:08:35| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 195.9251s] +04/11/23 20:08:35| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started +04/11/23 20:09:18| INFO ref finished [took 38.3873s] +04/11/23 20:09:22| INFO atc_mc finished [took 41.2537s] +04/11/23 20:09:24| INFO mulmc_sld finished [took 46.2211s] +04/11/23 20:09:25| INFO binmc_sld finished [took 46.6421s] +04/11/23 20:09:38| INFO mul_sld finished [took 60.9539s] +04/11/23 20:11:51| INFO bin_sld finished [took 195.1888s] +04/11/23 20:11:51| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 196.4776s] +---------------------------------------------------------------------------------------------------- +04/11/23 20:56:32| INFO dataset rcv1_CCAT_9prevs +04/11/23 20:56:37| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started +04/11/23 20:57:33| INFO ref finished [took 49.2697s] +04/11/23 20:57:38| INFO atc_mc finished [took 53.2068s] +04/11/23 20:57:39| INFO mulmc_sld finished [took 58.6224s] +04/11/23 20:58:59| INFO mulmc_sld_gs finished [took 136.0930s] +04/11/23 21:00:30| INFO binmc_sld finished [took 230.3290s] +04/11/23 21:02:12| INFO mul_sld_gs finished [took 333.4899s] +04/11/23 21:06:49| INFO bin_sld_gs finished [took 610.5751s] +04/11/23 21:06:54| INFO binmc_sld_gs finished [took 612.8900s] +04/11/23 21:06:55| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 617.6873s] +04/11/23 21:06:55| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started +04/11/23 21:07:52| INFO ref finished [took 49.8077s] +04/11/23 21:07:56| INFO atc_mc finished [took 53.3303s] +04/11/23 21:07:57| INFO mulmc_sld finished [took 58.9345s] +04/11/23 21:09:17| INFO mulmc_sld_gs finished [took 136.5258s] +04/11/23 21:10:51| INFO binmc_sld finished [took 233.4049s] +04/11/23 21:12:35| INFO mul_sld_gs finished [took 338.2751s] +04/11/23 21:17:38| INFO bin_sld_gs finished [took 641.8524s] +04/11/23 21:18:19| INFO binmc_sld_gs finished [took 679.9471s] +04/11/23 21:18:19| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 684.7098s] +04/11/23 21:18:19| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started +04/11/23 21:19:24| INFO ref finished [took 55.3767s] +04/11/23 21:19:28| INFO mulmc_sld finished [took 64.2789s] +04/11/23 21:19:29| INFO atc_mc finished [took 59.5610s] +04/11/23 21:20:57| INFO mulmc_sld_gs finished [took 150.1392s] +04/11/23 21:22:36| INFO binmc_sld finished [took 253.0960s] +04/11/23 21:24:16| INFO mul_sld_gs finished [took 354.6283s] +04/11/23 21:29:15| INFO bin_sld_gs finished [took 654.3325s] +04/11/23 21:29:50| INFO binmc_sld_gs finished [took 684.5074s] +04/11/23 21:29:50| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 690.4897s] +04/11/23 21:29:50| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started +04/11/23 21:30:45| INFO ref finished [took 48.2647s] +04/11/23 21:30:51| INFO atc_mc finished [took 52.2724s] +04/11/23 21:30:51| INFO mulmc_sld finished [took 57.5142s] +04/11/23 21:32:07| INFO mulmc_sld_gs finished [took 131.4908s] +04/11/23 21:33:38| INFO binmc_sld finished [took 224.9620s] +04/11/23 21:35:22| INFO mul_sld_gs finished [took 329.9053s] +04/11/23 21:40:25| INFO bin_sld_gs finished [took 634.4342s] +04/11/23 21:41:08| INFO binmc_sld_gs finished [took 673.6071s] +04/11/23 21:41:08| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 678.4725s] +04/11/23 21:41:08| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started +04/11/23 21:42:03| INFO ref finished [took 47.4381s] +04/11/23 21:42:08| INFO atc_mc finished [took 51.3566s] +04/11/23 21:42:09| INFO mulmc_sld finished [took 56.6180s] +04/11/23 21:43:23| INFO mulmc_sld_gs finished [took 128.6413s] +04/11/23 21:44:54| INFO binmc_sld finished [took 222.7951s] +04/11/23 21:46:39| INFO mul_sld_gs finished [took 328.8118s] +04/11/23 21:51:37| INFO bin_sld_gs finished [took 627.4937s] +04/11/23 21:52:17| INFO binmc_sld_gs finished [took 663.8116s] +04/11/23 21:52:17| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 668.8948s] +04/11/23 21:52:17| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started +04/11/23 21:53:12| INFO ref finished [took 47.6269s] +04/11/23 21:53:16| INFO atc_mc finished [took 51.1109s] +04/11/23 21:53:17| INFO mulmc_sld finished [took 56.5728s] +04/11/23 21:54:31| INFO mulmc_sld_gs finished [took 128.0358s] +04/11/23 21:56:00| INFO binmc_sld finished [took 220.0811s] +04/11/23 21:57:46| INFO mul_sld_gs finished [took 327.0856s] +04/11/23 22:02:58| INFO bin_sld_gs finished [took 639.3432s] +04/11/23 22:03:48| INFO binmc_sld_gs finished [took 686.2326s] +04/11/23 22:03:48| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 690.9677s] +04/11/23 22:03:48| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started +04/11/23 22:04:42| INFO ref finished [took 47.2804s] +04/11/23 22:04:48| INFO atc_mc finished [took 51.6888s] +04/11/23 22:04:48| INFO mulmc_sld finished [took 56.1465s] +04/11/23 22:06:06| INFO mulmc_sld_gs finished [took 132.4278s] +04/11/23 22:07:33| INFO binmc_sld finished [took 221.9299s] +04/11/23 22:09:19| INFO mul_sld_gs finished [took 329.1446s] +04/11/23 22:14:09| INFO bin_sld_gs finished [took 619.3584s] +04/11/23 22:14:32| INFO binmc_sld_gs finished [took 638.7326s] +04/11/23 22:14:32| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 643.6278s] +04/11/23 22:14:32| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started +04/11/23 22:15:26| INFO ref finished [took 47.3139s] +04/11/23 22:15:30| INFO atc_mc finished [took 50.8602s] +04/11/23 22:15:32| INFO mulmc_sld finished [took 56.5107s] +04/11/23 22:16:47| INFO mulmc_sld_gs finished [took 129.5292s] +04/11/23 22:18:22| INFO binmc_sld finished [took 226.9238s] +04/11/23 22:20:02| INFO mul_sld_gs finished [took 327.7014s] +04/11/23 22:24:57| INFO bin_sld_gs finished [took 624.4254s] +04/11/23 22:25:13| INFO binmc_sld_gs finished [took 636.2675s] +04/11/23 22:25:13| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 641.0382s] +04/11/23 22:25:13| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started +04/11/23 22:26:07| INFO ref finished [took 47.3224s] +04/11/23 22:26:12| INFO atc_mc finished [took 51.1828s] +04/11/23 22:26:13| INFO mulmc_sld finished [took 56.6133s] +04/11/23 22:27:30| INFO mulmc_sld_gs finished [took 131.3662s] +04/11/23 22:29:05| INFO binmc_sld finished [took 229.3002s] +04/11/23 22:30:38| INFO mul_sld_gs finished [took 323.5271s] +04/11/23 22:35:21| INFO bin_sld_gs finished [took 606.6430s] +04/11/23 22:35:30| INFO binmc_sld_gs finished [took 612.5966s] +04/11/23 22:35:30| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 617.3109s] +---------------------------------------------------------------------------------------------------- +04/11/23 22:49:37| ERROR Evaluation over rcv1_CCAT_3prevs failed. Exception: 'Invalid estimator: estimator binmc_sld_gs does not exist' +04/11/23 22:49:37| ERROR Failed while saving configuration rcv1_CCAT_debug of rcv1_CCAT_3prevs. Exception: cannot access local variable 'dr' where it is not associated with a value +---------------------------------------------------------------------------------------------------- +04/11/23 22:50:07| INFO dataset rcv1_CCAT_3prevs +04/11/23 22:50:12| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started +---------------------------------------------------------------------------------------------------- +04/11/23 22:55:55| INFO dataset rcv1_CCAT_3prevs +04/11/23 22:55:59| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started +04/11/23 22:56:48| INFO ref finished [took 44.4275s] +---------------------------------------------------------------------------------------------------- +04/11/23 22:56:59| INFO dataset rcv1_CCAT_3prevs +04/11/23 22:57:03| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started +04/11/23 22:57:09| WARNING Method mul_sld_gs failed. Exception: '>=' not supported between instances of 'TypeError' and 'int' +04/11/23 22:57:17| WARNING Method bin_sld_gs failed. Exception: '>=' not supported between instances of 'TypeError' and 'int' +---------------------------------------------------------------------------------------------------- +04/11/23 22:58:04| INFO dataset rcv1_CCAT_3prevs +04/11/23 22:58:09| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started +04/11/23 22:58:58| INFO ref finished [took 43.7541s] +04/11/23 22:59:05| INFO atc_mc finished [took 50.0628s] +---------------------------------------------------------------------------------------------------- +04/11/23 23:01:22| INFO dataset rcv1_CCAT_3prevs +04/11/23 23:01:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started +04/11/23 23:02:16| INFO ref finished [took 43.9765s] +04/11/23 23:02:23| INFO atc_mc finished [took 50.5568s] +---------------------------------------------------------------------------------------------------- +04/11/23 23:09:33| INFO dataset rcv1_CCAT_3prevs +04/11/23 23:09:37| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started +04/11/23 23:09:38| WARNING Method binmc_sld failed. Exception: classifier and pred_proba cannot be both None +04/11/23 23:09:39| WARNING Method mulmc_sld failed. Exception: classifier and pred_proba cannot be both None +04/11/23 23:09:40| WARNING Method bin_sld_gs failed. Exception: no combination of hyperparameters seem to work +04/11/23 23:09:41| WARNING Method mul_sld_gs failed. Exception: no combination of hyperparameters seem to work +---------------------------------------------------------------------------------------------------- +04/11/23 23:10:23| INFO dataset rcv1_CCAT_3prevs +04/11/23 23:10:28| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started +04/11/23 23:11:15| INFO ref finished [took 42.4887s] +04/11/23 23:11:20| INFO atc_mc finished [took 45.6262s] +04/11/23 23:11:21| INFO mulmc_sld finished [took 50.9790s] +04/11/23 23:13:57| INFO binmc_sld finished [took 208.3159s] +---------------------------------------------------------------------------------------------------- +04/11/23 23:16:22| INFO dataset rcv1_CCAT_3prevs +04/11/23 23:16:26| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started +04/11/23 23:17:12| INFO ref finished [took 40.5978s] +04/11/23 23:17:16| INFO atc_mc finished [took 43.6933s] +04/11/23 23:17:17| INFO mulmc_sld finished [took 49.0808s] +04/11/23 23:19:53| INFO binmc_sld finished [took 205.5731s] +04/11/23 23:22:24| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.00672) [took 354.1411s] +04/11/23 23:23:05| INFO mul_sld_gs finished [took 394.8240s] +04/11/23 23:30:41| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': 'vs', 'confidence': None} (score=0.00891) [took 852.1465s] +04/11/23 23:33:44| INFO bin_sld_gs finished [took 1035.2071s] +04/11/23 23:33:44| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs finished [took 1038.1845s] +04/11/23 23:33:44| INFO Dataset sample 0.50 of dataset rcv1_CCAT_3prevs started +04/11/23 23:34:33| INFO ref finished [took 43.6409s] +04/11/23 23:34:37| INFO atc_mc finished [took 46.7818s] +04/11/23 23:34:38| INFO mulmc_sld finished [took 51.3459s] +04/11/23 23:37:15| INFO binmc_sld finished [took 209.5746s] +04/11/23 23:39:48| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.00553) [took 359.3210s] +04/11/23 23:40:28| INFO mul_sld_gs finished [took 399.5320s] +04/11/23 23:48:02| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': None} (score=0.01058) [took 855.1289s] +04/11/23 23:51:06| INFO bin_sld_gs finished [took 1038.6344s] +04/11/23 23:51:06| INFO Dataset sample 0.50 of dataset rcv1_CCAT_3prevs finished [took 1041.6478s] +04/11/23 23:51:06| INFO Dataset sample 0.60 of dataset rcv1_CCAT_3prevs started +04/11/23 23:51:51| INFO ref finished [took 40.0694s] +04/11/23 23:51:55| INFO atc_mc finished [took 42.4882s] +04/11/23 23:51:56| INFO mulmc_sld finished [took 47.7936s] +04/11/23 23:54:29| INFO binmc_sld finished [took 201.3777s] +04/11/23 23:57:03| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.00429) [took 352.7820s] +04/11/23 23:57:43| INFO mul_sld_gs finished [took 392.5201s] +05/11/23 00:05:21| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1000.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': None, 'confidence': None} (score=0.00552) [took 851.9361s] +05/11/23 00:08:24| INFO bin_sld_gs finished [took 1034.7353s] +05/11/23 00:08:24| INFO Dataset sample 0.60 of dataset rcv1_CCAT_3prevs finished [took 1037.8033s] +---------------------------------------------------------------------------------------------------- +05/11/23 00:11:07| INFO dataset rcv1_CCAT_3prevs +05/11/23 00:11:12| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started diff --git a/quacc/data.py b/quacc/data.py index 1a0ae3f..aa76053 100644 --- a/quacc/data.py +++ b/quacc/data.py @@ -1,9 +1,10 @@ +import math from typing import List, Optional import numpy as np -import math import scipy.sparse as sp from quapy.data import LabelledCollection +from sklearn.base import BaseEstimator # Extended classes @@ -128,8 +129,22 @@ class ExtendedCollection(LabelledCollection): @classmethod def extend_collection( - cls, base: LabelledCollection, pred_proba: np.ndarray + cls, + base: LabelledCollection, + classifier: BaseEstimator = None, + pred_proba: np.ndarray = None, ): + if classifier is None and pred_proba is None: + raise AttributeError("classifier and pred_proba cannot be both None") + + if classifier is not None and pred_proba is not None: + raise AttributeError( + "Not needed parameters: just one of classifier or pred_proba is needed" + ) + + if classifier: + pred_proba = classifier.predict_proba(base.X) + n_classes = base.n_classes # n_X = [ X | predicted probs. ] @@ -145,4 +160,3 @@ class ExtendedCollection(LabelledCollection): ) return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)]) - diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py index ac6b624..a66f60a 100644 --- a/quacc/evaluation/method.py +++ b/quacc/evaluation/method.py @@ -8,7 +8,7 @@ from sklearn.linear_model import LogisticRegression import quacc as qc from quacc.evaluation.report import EvaluationReport -from quacc.method.model_selection import GridSearchAE +from quacc.method.model_selection import BQAEgsq, GridSearchAE, MCAEgsq from ..method.base import BQAE, MCAE, BaseAccuracyEstimator @@ -49,8 +49,7 @@ def evaluation_report( @method def bin_sld(c_model, validation, protocol) -> EvaluationReport: - est = BQAE(c_model, SLD(LogisticRegression())) - est.fit(validation) + est = BQAE(c_model, SLD(LogisticRegression())).fit(validation) return evaluation_report( estimator=est, protocol=protocol, @@ -59,8 +58,7 @@ def bin_sld(c_model, validation, protocol) -> EvaluationReport: @method def mul_sld(c_model, validation, protocol) -> EvaluationReport: - est = MCAE(c_model, SLD(LogisticRegression())) - est.fit(validation) + est = MCAE(c_model, SLD(LogisticRegression())).fit(validation) return evaluation_report( estimator=est, protocol=protocol, @@ -68,9 +66,12 @@ def mul_sld(c_model, validation, protocol) -> EvaluationReport: @method -def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport: - est = BQAE(c_model, SLD(LogisticRegression(), recalib="bcts")) - est.fit(validation) +def binmc_sld(c_model, validation, protocol) -> EvaluationReport: + est = BQAE( + c_model, + SLD(LogisticRegression()), + confidence="max_conf", + ).fit(validation) return evaluation_report( estimator=est, protocol=protocol, @@ -78,9 +79,12 @@ def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport: @method -def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport: - est = MCAE(c_model, SLD(LogisticRegression(), recalib="bcts")) - est.fit(validation) +def mulmc_sld(c_model, validation, protocol) -> EvaluationReport: + est = MCAE( + c_model, + SLD(LogisticRegression()), + confidence="max_conf", + ).fit(validation) return evaluation_report( estimator=est, protocol=protocol, @@ -97,10 +101,11 @@ def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport: "q__classifier__C": np.logspace(-3, 3, 7), "q__classifier__class_weight": [None, "balanced"], "q__recalib": [None, "bcts", "vs"], + "confidence": [None, "max_conf"], }, refit=False, protocol=UPP(v_val, repeats=100), - verbose=False, + verbose=True, ).fit(v_train) return evaluation_report( estimator=est, @@ -118,10 +123,11 @@ def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport: "q__classifier__C": np.logspace(-3, 3, 7), "q__classifier__class_weight": [None, "balanced"], "q__recalib": [None, "bcts", "vs"], + "confidence": [None, "max_conf"], }, refit=False, protocol=UPP(v_val, repeats=100), - verbose=False, + verbose=True, ).fit(v_train) return evaluation_report( estimator=est, @@ -129,10 +135,47 @@ def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport: ) +@method +def bin_sld_gsq(c_model, validation, protocol) -> EvaluationReport: + est = BQAEgsq( + c_model, + SLD(LogisticRegression()), + param_grid={ + "classifier__C": np.logspace(-3, 3, 7), + "classifier__class_weight": [None, "balanced"], + "recalib": [None, "bcts", "vs"], + }, + refit=False, + verbose=False, + ).fit(validation) + return evaluation_report( + estimator=est, + protocol=protocol, + ) + + +@method +def mul_sld_gsq(c_model, validation, protocol) -> EvaluationReport: + est = MCAEgsq( + c_model, + SLD(LogisticRegression()), + param_grid={ + "classifier__C": np.logspace(-3, 3, 7), + "classifier__class_weight": [None, "balanced"], + "recalib": [None, "bcts", "vs"], + }, + refit=False, + verbose=False, + ).fit(validation) + return evaluation_report( + estimator=est, + protocol=protocol, + ) + + @method def bin_pacc(c_model, validation, protocol) -> EvaluationReport: - est = BQAE(c_model, PACC(LogisticRegression(), recalib="bcts")) - est.fit(validation) + est = BQAE(c_model, PACC(LogisticRegression())).fit(validation) return evaluation_report( estimator=est, protocol=protocol, @@ -141,8 +184,7 @@ def bin_pacc(c_model, validation, protocol) -> EvaluationReport: @method def mul_pacc(c_model, validation, protocol) -> EvaluationReport: - est = MCAE(c_model, PACC(LogisticRegression(), recalib="bcts")) - est.fit(validation) + est = MCAE(c_model, PACC(LogisticRegression())).fit(validation) return evaluation_report( estimator=est, protocol=protocol, @@ -158,7 +200,6 @@ def bin_pacc_gs(c_model, validation, protocol) -> EvaluationReport: param_grid={ "q__classifier__C": np.logspace(-3, 3, 7), "q__classifier__class_weight": [None, "balanced"], - "q__recalib": [None, "bcts", "vs"], }, refit=False, protocol=UPP(v_val, repeats=100), @@ -179,7 +220,6 @@ def mul_pacc_gs(c_model, validation, protocol) -> EvaluationReport: param_grid={ "q__classifier__C": np.logspace(-3, 3, 7), "q__classifier__class_weight": [None, "balanced"], - "q__recalib": [None, "bcts", "vs"], }, refit=False, protocol=UPP(v_val, repeats=100), @@ -188,5 +228,4 @@ def mul_pacc_gs(c_model, validation, protocol) -> EvaluationReport: return evaluation_report( estimator=est, protocol=protocol, - method_name="bin_sld_gs", ) diff --git a/quacc/evaluation/worker.py b/quacc/evaluation/worker.py index 1a96a5f..2ff93f6 100644 --- a/quacc/evaluation/worker.py +++ b/quacc/evaluation/worker.py @@ -27,7 +27,7 @@ def estimate_worker(_estimate, train, validation, test, _env=None, q=None): result = _estimate(model, validation, protocol) except Exception as e: log.warning(f"Method {_estimate.__name__} failed. Exception: {e}") - # traceback(e) + traceback(e) return { "name": _estimate.__name__, "result": None, diff --git a/quacc/method/base.py b/quacc/method/base.py index 8a51362..a57509f 100644 --- a/quacc/method/base.py +++ b/quacc/method/base.py @@ -17,9 +17,11 @@ class BaseAccuracyEstimator(BaseQuantifier): self, classifier: BaseEstimator, quantifier: BaseQuantifier, + confidence=None, ): self.__check_classifier(classifier) self.quantifier = quantifier + self.confidence = confidence def __check_classifier(self, classifier): if not hasattr(classifier, "predict_proba"): @@ -28,10 +30,37 @@ class BaseAccuracyEstimator(BaseQuantifier): ) self.classifier = classifier + def __get_confidence(self): + if self.confidence is None: + return None + + __confs = { + "max_conf": lambda probas: np.max(probas, axis=-1).reshape((len(probas), 1)) + } + return __confs.get(self.confidence, None) + + def __get_ext(self, pred_proba): + _ext = pred_proba + _f_conf = self.__get_confidence() + if _f_conf is not None: + _confs = _f_conf(pred_proba) + _ext = np.concatenate((_confs, pred_proba), axis=1) + + return _ext + def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection: - if not pred_proba: + if pred_proba is None: pred_proba = self.classifier.predict_proba(coll.X) - return ExtendedCollection.extend_collection(coll, pred_proba) + + _ext = self.__get_ext(pred_proba) + return ExtendedCollection.extend_collection(coll, pred_proba=_ext) + + def _extend_instances(self, instances: np.ndarray | csr_matrix, pred_proba=None): + if pred_proba is None: + pred_proba = self.classifier.predict_proba(instances) + + _ext = self.__get_ext(pred_proba) + return ExtendedCollection.extend_instances(instances, _ext) @abstractmethod def fit(self, train: LabelledCollection | ExtendedCollection): @@ -47,23 +76,24 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator): self, classifier: BaseEstimator, quantifier: BaseQuantifier, + confidence: str = None, ): - super().__init__(classifier, quantifier) + super().__init__( + classifier=classifier, + quantifier=quantifier, + confidence=confidence, + ) self.e_train = None def fit(self, train: LabelledCollection): - pred_probs = self.classifier.predict_proba(train.X) - self.e_train = ExtendedCollection.extend_collection(train, pred_probs) + self.e_train = self.extend(train) self.quantifier.fit(self.e_train) return self def estimate(self, instances, ext=False) -> np.ndarray: - e_inst = instances - if not ext: - pred_prob = self.classifier.predict_proba(instances) - e_inst = ExtendedCollection.extend_instances(instances, pred_prob) + e_inst = instances if ext else self._extend_instances(instances) estim_prev = self.quantifier.quantify(e_inst) return self._check_prevalence_classes(estim_prev) @@ -78,18 +108,25 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator): class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator): - def __init__(self, classifier: BaseEstimator, quantifier: BaseAccuracyEstimator): - super().__init__(classifier, quantifier) + def __init__( + self, + classifier: BaseEstimator, + quantifier: BaseAccuracyEstimator, + confidence: str = None, + ): + super().__init__( + classifier=classifier, + quantifier=quantifier, + confidence=confidence, + ) self.quantifiers = [] self.e_trains = [] def fit(self, train: LabelledCollection | ExtendedCollection): - pred_probs = self.classifier.predict_proba(train.X) - self.e_train = ExtendedCollection.extend_collection(train, pred_probs) + self.e_train = self.extend(train) self.n_classes = self.e_train.n_classes self.e_trains = self.e_train.split_by_pred() - self.quantifiers = [deepcopy(self.quantifier) for _ in self.e_trains] self.quantifiers = [] for train in self.e_trains: @@ -97,12 +134,11 @@ class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator): quant.fit(train) self.quantifiers.append(quant) + return self + def estimate(self, instances, ext=False): # TODO: test - e_inst = instances - if not ext: - pred_prob = self.classifier.predict_proba(instances) - e_inst = ExtendedCollection.extend_instances(instances, pred_prob) + e_inst = instances if ext else self._extend_instances(instances) _ncl = int(math.sqrt(self.n_classes)) s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst) diff --git a/quacc/method/model_selection.py b/quacc/method/model_selection.py index ba866f6..2db3f67 100644 --- a/quacc/method/model_selection.py +++ b/quacc/method/model_selection.py @@ -3,14 +3,22 @@ from copy import deepcopy from time import time from typing import Callable, Union +import quapy as qp from quapy.data import LabelledCollection -from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol +from quapy.model_selection import GridSearchQ +from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol +from sklearn.base import BaseEstimator import quacc as qc import quacc.error from quacc.data import ExtendedCollection from quacc.evaluation import evaluate -from quacc.method.base import BaseAccuracyEstimator +from quacc.logger import SubLogger +from quacc.method.base import ( + BaseAccuracyEstimator, + BinaryQuantifierAccuracyEstimator, + MultiClassAccuracyEstimator, +) class GridSearchAE(BaseAccuracyEstimator): @@ -106,6 +114,12 @@ class GridSearchAE(BaseAccuracyEstimator): f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) " f"[took {tend:.4f}s]" ) + log = SubLogger.logger() + log.debug( + f"[{self.model.__class__.__name__}] " + f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) " + f"[took {tend:.4f}s]" + ) if self.refit: if isinstance(protocol, OnLabelledCollectionProtocol): @@ -203,3 +217,84 @@ class GridSearchAE(BaseAccuracyEstimator): if hasattr(self, "best_model_"): return self.best_model_ raise ValueError("best_model called before fit") + + +class MCAEgsq(MultiClassAccuracyEstimator): + def __init__( + self, + classifier: BaseEstimator, + quantifier: BaseAccuracyEstimator, + param_grid: dict, + error: Union[Callable, str] = qp.error.mae, + refit=True, + timeout=-1, + n_jobs=None, + verbose=False, + ): + self.param_grid = param_grid + self.refit = refit + self.timeout = timeout + self.n_jobs = n_jobs + self.verbose = verbose + self.error = error + super().__init__(classifier, quantifier) + + def fit(self, train: LabelledCollection): + self.e_train = self.extend(train) + t_train, t_val = self.e_train.split_stratified(0.6, random_state=0) + self.quantifier = GridSearchQ( + deepcopy(self.quantifier), + param_grid=self.param_grid, + protocol=UPP(t_val, repeats=100), + error=self.error, + refit=self.refit, + timeout=self.timeout, + n_jobs=self.n_jobs, + verbose=self.verbose, + ).fit(self.e_train) + + return self + + +class BQAEgsq(BinaryQuantifierAccuracyEstimator): + def __init__( + self, + classifier: BaseEstimator, + quantifier: BaseAccuracyEstimator, + param_grid: dict, + error: Union[Callable, str] = qp.error.mae, + refit=True, + timeout=-1, + n_jobs=None, + verbose=False, + ): + self.param_grid = param_grid + self.refit = refit + self.timeout = timeout + self.n_jobs = n_jobs + self.verbose = verbose + self.error = error + super().__init__(classifier=classifier, quantifier=quantifier) + + def fit(self, train: LabelledCollection): + self.e_train = self.extend(train) + + self.n_classes = self.e_train.n_classes + self.e_trains = self.e_train.split_by_pred() + + self.quantifiers = [] + for e_train in self.e_trains: + t_train, t_val = e_train.split_stratified(0.6, random_state=0) + quantifier = GridSearchQ( + model=deepcopy(self.quantifier), + param_grid=self.param_grid, + protocol=UPP(t_val, repeats=100), + error=self.error, + refit=self.refit, + timeout=self.timeout, + n_jobs=self.n_jobs, + verbose=self.verbose, + ).fit(t_train) + self.quantifiers.append(quantifier) + + return self