method confidence added

This commit is contained in:
Lorenzo Volpi 2023-11-05 00:15:40 +01:00
parent 71ae40d63e
commit b96432f87b
8 changed files with 938 additions and 71 deletions

View File

@ -1,41 +1,44 @@
import numpy as np
import numpy as np
from sklearn.metrics import f1_score
def get_entropy(probs):
return np.sum( np.multiply(probs, np.log(probs + 1e-20)) , axis=1)
def get_entropy(probs):
return np.sum(np.multiply(probs, np.log(probs + 1e-20)), axis=1)
def get_max_conf(probs):
return np.max(probs, axis=-1)
def find_ATC_threshold(scores, labels):
return np.max(probs, axis=-1)
def find_ATC_threshold(scores, labels):
sorted_idx = np.argsort(scores)
sorted_scores = scores[sorted_idx]
sorted_labels = labels[sorted_idx]
fp = np.sum(labels==0)
fp = np.sum(labels == 0)
fn = 0.0
min_fp_fn = np.abs(fp - fn)
thres = 0.0
for i in range(len(labels)):
if sorted_labels[i] == 0:
for i in range(len(labels)):
if sorted_labels[i] == 0:
fp -= 1
else:
else:
fn += 1
if np.abs(fp - fn) < min_fp_fn:
if np.abs(fp - fn) < min_fp_fn:
min_fp_fn = np.abs(fp - fn)
thres = sorted_scores[i]
return min_fp_fn, thres
def get_ATC_acc(thres, scores):
return np.mean(scores>=thres)
def get_ATC_acc(thres, scores):
return np.mean(scores >= thres)
def get_ATC_f1(thres, scores, probs):
preds = np.argmax(probs, axis=-1)
estim_y = abs(1 - (scores>=thres)^preds)
estim_y = np.abs(1 - (scores >= thres) ^ preds)
return f1_score(estim_y, preds)

View File

@ -12,7 +12,30 @@ debug_conf: &debug_conf
plot_confs:
debug:
PLOT_ESTIMATORS:
- mul_sld
- bin_sld_gs
PLOT_STDEV: true
mc_conf: &mc_conf
global:
METRICS:
- acc
DATASET_N_PREVS: 9
DATASET_PREVS:
- 0.4
- 0.5
- 0.6
confs:
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
plot_confs:
debug:
PLOT_ESTIMATORS:
- mulmc_sld
- mul_sld_gs
- bin_sld
- bin_sld_gs
- atc_mc
PLOT_STDEV: true
@ -29,13 +52,20 @@ test_conf: &test_conf
# - DATASET_NAME: imdb
plot_confs:
2gs_vs_atc:
gs_vs_gsq:
PLOT_ESTIMATORS:
- bin_sld
- bin_sld_gs
- bin_sld_qgs
- bin_sld_gsq
- mul_sld
- mul_sld_gs
- mul_sld_gsq
gs_vs_atc:
PLOT_ESTIMATORS:
- bin_sld
- bin_sld_gs
- mul_sld
- mul_sld_gs
- mul_sld_qgs
- ref
- atc_mc
- atc_ne
sld_vs_pacc:
@ -44,11 +74,23 @@ test_conf: &test_conf
- bin_sld_gs
- mul_sld
- mul_sld_gs
- ref
- bin_pacc
- bin_pacc_gs
- mul_pacc
- mul_pacc_gs
- atc_mc
- atc_ne
pacc_vs_atc:
PLOT_ESTIMATORS:
- bin_pacc
- bin_pacc_gs
- mul_pacc
- mul_pacc_gs
- atc_mc
- atc_ne
main_conf: &main_conf
global:
METRICS:
- acc
@ -106,4 +148,4 @@ main_conf: &main_conf
- atc_ne
- doc_feat
exec: *debug_conf
exec: *mc_conf

638
quacc.log
View File

@ -1636,3 +1636,641 @@
04/11/23 00:05:20| INFO atc_mc finished [took 26.4278s]
04/11/23 00:05:29| INFO mul_sld finished [took 35.3110s]
04/11/23 00:05:29| INFO Dataset sample 0.50 of dataset imdb_1prevs finished [took 36.4422s]
----------------------------------------------------------------------------------------------------
04/11/23 00:19:43| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:19:49| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 00:19:53| WARNING Method bin_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input.
04/11/23 00:19:55| WARNING Method mul_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input.
04/11/23 00:19:57| WARNING Method bin_pacc failed. Exception: PACC.__init__() got an unexpected keyword argument 'recalib'
04/11/23 00:19:59| WARNING Method mul_pacc failed. Exception: PACC.__init__() got an unexpected keyword argument 'recalib'
04/11/23 00:20:00| WARNING Method bin_pacc_gs failed. Exception: Invalid parameter 'recalib' for estimator PACC(classifier=LogisticRegression(), n_jobs=1). Valid parameters are: ['classifier', 'n_jobs', 'val_split'].
04/11/23 00:20:01| WARNING Method mul_pacc_gs failed. Exception: Invalid parameter 'recalib' for estimator PACC(classifier=LogisticRegression(), n_jobs=1). Valid parameters are: ['classifier', 'n_jobs', 'val_split'].
----------------------------------------------------------------------------------------------------
04/11/23 00:22:45| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:22:50| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 00:22:54| WARNING Method bin_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input.
04/11/23 00:22:55| WARNING Method mul_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input.
----------------------------------------------------------------------------------------------------
04/11/23 00:28:11| INFO dataset rcv1_CCAT_9prevs
----------------------------------------------------------------------------------------------------
04/11/23 00:29:39| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:29:45| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 00:29:49| WARNING Method bin_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input.
04/11/23 00:29:51| WARNING Method mul_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input.
04/11/23 00:30:39| WARNING Method mul_pacc_gs failed. Exception: evaluation_report() got an unexpected keyword argument 'method_name'
04/11/23 00:31:00| INFO ref finished [took 60.5788s]
04/11/23 00:31:09| INFO atc_mc finished [took 64.6156s]
04/11/23 00:31:09| INFO mul_pacc finished [took 75.1821s]
04/11/23 00:31:12| INFO atc_ne finished [took 62.8665s]
04/11/23 00:31:24| INFO mul_sld finished [took 96.8624s]
----------------------------------------------------------------------------------------------------
04/11/23 00:33:26| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:33:31| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 00:33:35| WARNING Method bin_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input.
04/11/23 00:33:37| WARNING Method mul_sld_qgs failed. Exception: X has 47236 features, but LogisticRegression is expecting 47238 features as input.
----------------------------------------------------------------------------------------------------
04/11/23 00:38:42| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:38:48| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 00:38:51| WARNING Method bin_sld_qgs failed. Exception: ExtendedCollection.extend_collection() missing 1 required positional argument: 'pred_proba'
04/11/23 00:38:52| WARNING Method mul_sld_qgs failed. Exception: ExtendedCollection.extend_collection() missing 1 required positional argument: 'pred_proba'
04/11/23 00:39:41| WARNING Method mul_pacc_gs failed. Exception: evaluation_report() got an unexpected keyword argument 'method_name'
----------------------------------------------------------------------------------------------------
04/11/23 00:46:33| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:46:39| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 00:46:40| WARNING Method bin_sld failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:46:41| WARNING Method mul_sld failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:46:42| WARNING Method bin_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba'
04/11/23 00:46:43| WARNING Method mul_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba'
04/11/23 00:46:44| WARNING Method bin_sld_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:46:45| WARNING Method mul_sld_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:46:46| WARNING Method bin_pacc failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:46:47| WARNING Method mul_pacc failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:46:47| WARNING Method bin_pacc_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:46:48| WARNING Method mul_pacc_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:47:27| INFO ref finished [took 37.5294s]
04/11/23 00:47:31| INFO atc_mc finished [took 40.5777s]
04/11/23 00:47:32| INFO atc_ne finished [took 40.7565s]
04/11/23 00:47:32| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 52.8106s]
04/11/23 00:47:32| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started
04/11/23 00:47:33| WARNING Method bin_sld failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:47:34| WARNING Method mul_sld failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:47:35| WARNING Method bin_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba'
04/11/23 00:47:36| WARNING Method mul_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba'
04/11/23 00:47:37| WARNING Method bin_sld_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:47:38| WARNING Method mul_sld_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:47:39| WARNING Method bin_pacc failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:47:39| WARNING Method mul_pacc failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:47:40| WARNING Method bin_pacc_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
04/11/23 00:47:41| WARNING Method mul_pacc_gs failed. Exception: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
----------------------------------------------------------------------------------------------------
04/11/23 00:48:05| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:48:10| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 00:48:13| WARNING Method bin_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba'
04/11/23 00:48:14| WARNING Method mul_sld_qgs failed. Exception: 'LogisticRegression' object has no attribute 'pred_proba'
----------------------------------------------------------------------------------------------------
04/11/23 00:49:18| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:49:24| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 00:49:27| WARNING Method bin_sld_qgs failed. Exception: GridSearchQ.__init__() missing 1 required positional argument: 'model'
04/11/23 00:49:28| WARNING Method mul_sld_qgs failed. Exception: X has 47238 features, but LogisticRegression is expecting 47236 features as input.
----------------------------------------------------------------------------------------------------
04/11/23 00:51:27| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:51:32| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 00:51:36| WARNING Method bin_sld_qgs failed. Exception: X has 47238 features, but LogisticRegression is expecting 47236 features as input.
04/11/23 00:51:37| WARNING Method mul_sld_qgs failed. Exception: X has 47238 features, but LogisticRegression is expecting 47236 features as input.
----------------------------------------------------------------------------------------------------
04/11/23 00:54:47| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:54:53| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 00:54:57| WARNING Method bin_sld_qgs failed. Exception: a must be greater than 0 unless no samples are taken
04/11/23 00:54:58| WARNING Method mul_sld_qgs failed. Exception: a must be greater than 0 unless no samples are taken
----------------------------------------------------------------------------------------------------
04/11/23 00:58:47| INFO dataset rcv1_CCAT_9prevs
04/11/23 00:58:52| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 01:00:04| INFO ref finished [took 61.6328s]
04/11/23 01:00:11| INFO atc_mc finished [took 65.4916s]
04/11/23 01:00:13| INFO atc_ne finished [took 63.2288s]
04/11/23 01:00:14| INFO mul_pacc finished [took 75.5101s]
04/11/23 01:00:30| INFO mul_sld finished [took 96.6656s]
04/11/23 01:00:41| INFO mul_pacc_gs finished [took 99.7211s]
04/11/23 01:03:02| INFO bin_pacc finished [took 244.6260s]
04/11/23 01:03:07| INFO bin_sld finished [took 254.3478s]
04/11/23 01:04:51| INFO mul_sld_gs finished [took 354.7477s]
04/11/23 01:05:02| INFO bin_pacc_gs finished [took 362.1808s]
04/11/23 01:09:24| INFO bin_sld_gs finished [took 628.6714s]
04/11/23 01:09:24| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 631.8421s]
04/11/23 01:09:24| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started
04/11/23 01:10:39| INFO ref finished [took 63.5158s]
04/11/23 01:10:44| INFO atc_mc finished [took 66.4279s]
04/11/23 01:10:46| INFO mul_pacc finished [took 75.3281s]
04/11/23 01:10:47| INFO atc_ne finished [took 67.5374s]
04/11/23 01:10:52| INFO mul_sld finished [took 86.6592s]
04/11/23 01:11:19| INFO mul_pacc_gs finished [took 104.6374s]
04/11/23 01:13:58| INFO bin_sld finished [took 273.4932s]
04/11/23 01:14:01| INFO bin_pacc finished [took 271.3481s]
04/11/23 01:15:42| INFO mul_sld_gs finished [took 374.2416s]
04/11/23 01:16:01| INFO bin_pacc_gs finished [took 388.0839s]
04/11/23 01:20:29| INFO bin_sld_gs finished [took 661.9729s]
04/11/23 01:20:29| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 665.2874s]
04/11/23 01:20:29| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started
04/11/23 01:21:46| INFO ref finished [took 63.8544s]
04/11/23 01:21:50| INFO atc_mc finished [took 66.6917s]
04/11/23 01:21:52| INFO atc_ne finished [took 65.0860s]
04/11/23 01:21:53| INFO mul_pacc finished [took 77.2630s]
04/11/23 01:21:55| INFO mul_sld finished [took 83.3146s]
04/11/23 01:22:23| INFO mul_pacc_gs finished [took 102.3761s]
04/11/23 01:24:47| INFO bin_pacc finished [took 252.0964s]
04/11/23 01:24:49| INFO bin_sld finished [took 258.6998s]
04/11/23 01:26:37| INFO mul_sld_gs finished [took 363.7500s]
04/11/23 01:26:49| INFO bin_pacc_gs finished [took 370.5817s]
04/11/23 01:31:27| INFO bin_sld_gs finished [took 654.3921s]
04/11/23 01:31:27| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 658.0041s]
04/11/23 01:31:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
04/11/23 01:32:33| INFO ref finished [took 55.7749s]
04/11/23 01:32:38| INFO atc_mc finished [took 59.4190s]
04/11/23 01:32:40| INFO atc_ne finished [took 59.5155s]
04/11/23 01:32:42| INFO mul_pacc finished [took 68.8994s]
04/11/23 01:32:44| INFO mul_sld finished [took 74.6470s]
04/11/23 01:33:09| INFO mul_pacc_gs finished [took 92.6473s]
04/11/23 01:35:32| INFO bin_pacc finished [took 239.7541s]
04/11/23 01:35:34| INFO bin_sld finished [took 245.7504s]
04/11/23 01:37:19| INFO mul_sld_gs finished [took 348.1188s]
04/11/23 01:37:30| INFO bin_pacc_gs finished [took 355.4729s]
04/11/23 01:42:07| INFO bin_sld_gs finished [took 636.8598s]
04/11/23 01:42:07| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 639.9201s]
04/11/23 01:42:07| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
04/11/23 01:43:14| INFO ref finished [took 56.1531s]
04/11/23 01:43:19| INFO atc_mc finished [took 59.7473s]
04/11/23 01:43:20| INFO atc_ne finished [took 59.0606s]
04/11/23 01:43:23| INFO mul_pacc finished [took 69.4266s]
04/11/23 01:43:25| INFO mul_sld finished [took 76.3328s]
04/11/23 01:43:49| INFO mul_pacc_gs finished [took 92.3926s]
04/11/23 01:46:05| INFO bin_pacc finished [took 233.1877s]
04/11/23 01:46:08| INFO bin_sld finished [took 239.8757s]
04/11/23 01:47:51| INFO mul_sld_gs finished [took 339.5911s]
04/11/23 01:48:00| INFO bin_pacc_gs finished [took 345.7788s]
04/11/23 01:52:44| INFO bin_sld_gs finished [took 633.8407s]
04/11/23 01:52:44| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 637.0648s]
04/11/23 01:52:44| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
04/11/23 01:53:52| INFO ref finished [took 57.4958s]
04/11/23 01:53:57| INFO atc_mc finished [took 60.9998s]
04/11/23 01:53:58| INFO atc_ne finished [took 60.4847s]
04/11/23 01:54:01| INFO mul_pacc finished [took 70.5216s]
04/11/23 01:54:04| INFO mul_sld finished [took 78.2910s]
04/11/23 01:54:27| INFO mul_pacc_gs finished [took 94.4726s]
04/11/23 01:56:48| INFO bin_pacc finished [took 238.5969s]
04/11/23 01:56:50| INFO bin_sld finished [took 244.5679s]
04/11/23 01:58:31| INFO mul_sld_gs finished [took 342.4843s]
04/11/23 01:58:44| INFO bin_pacc_gs finished [took 352.8264s]
04/11/23 02:03:32| INFO bin_sld_gs finished [took 644.7046s]
04/11/23 02:03:32| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 647.8055s]
04/11/23 02:03:32| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
04/11/23 02:04:37| INFO ref finished [took 55.4488s]
04/11/23 02:04:42| INFO atc_mc finished [took 59.2634s]
04/11/23 02:04:44| INFO atc_ne finished [took 59.1371s]
04/11/23 02:04:46| INFO mul_pacc finished [took 68.0960s]
04/11/23 02:04:50| INFO mul_sld finished [took 76.4282s]
04/11/23 02:05:12| INFO mul_pacc_gs finished [took 91.7735s]
04/11/23 02:07:30| INFO bin_pacc finished [took 232.7650s]
04/11/23 02:07:36| INFO bin_sld finished [took 242.4077s]
04/11/23 02:09:14| INFO mul_sld_gs finished [took 338.1418s]
04/11/23 02:09:26| INFO bin_pacc_gs finished [took 347.2033s]
04/11/23 02:13:59| INFO bin_sld_gs finished [took 624.6098s]
04/11/23 02:13:59| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 627.7979s]
04/11/23 02:13:59| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
04/11/23 02:15:05| INFO ref finished [took 55.1962s]
04/11/23 02:15:10| INFO atc_mc finished [took 59.0907s]
04/11/23 02:15:11| INFO atc_ne finished [took 59.1531s]
04/11/23 02:15:13| INFO mul_pacc finished [took 67.6705s]
04/11/23 02:15:17| INFO mul_sld finished [took 75.4559s]
04/11/23 02:15:41| INFO mul_pacc_gs finished [took 92.4901s]
04/11/23 02:17:59| INFO bin_pacc finished [took 233.8600s]
04/11/23 02:18:04| INFO bin_sld finished [took 243.2382s]
04/11/23 02:19:40| INFO mul_sld_gs finished [took 336.0961s]
04/11/23 02:19:51| INFO bin_pacc_gs finished [took 344.4075s]
04/11/23 02:24:30| INFO bin_sld_gs finished [took 627.6209s]
04/11/23 02:24:30| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 630.8251s]
04/11/23 02:24:30| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
04/11/23 02:25:35| INFO ref finished [took 54.8513s]
04/11/23 02:25:40| INFO atc_mc finished [took 58.8528s]
04/11/23 02:25:41| INFO atc_ne finished [took 58.6035s]
04/11/23 02:25:43| INFO mul_pacc finished [took 66.9030s]
04/11/23 02:25:57| INFO mul_sld finished [took 84.2072s]
04/11/23 02:26:10| INFO mul_pacc_gs finished [took 91.0973s]
04/11/23 02:28:31| INFO bin_pacc finished [took 235.7331s]
04/11/23 02:28:35| INFO bin_sld finished [took 243.6260s]
04/11/23 02:30:09| INFO mul_sld_gs finished [took 334.4842s]
04/11/23 02:30:22| INFO bin_pacc_gs finished [took 344.6874s]
04/11/23 02:34:46| INFO bin_sld_gs finished [took 612.1219s]
04/11/23 02:34:46| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 615.2004s]
----------------------------------------------------------------------------------------------------
04/11/23 02:57:35| INFO dataset rcv1_CCAT_9prevs
04/11/23 02:57:39| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 02:57:47| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid'
04/11/23 02:58:59| INFO ref finished [took 64.5948s]
04/11/23 02:59:06| INFO atc_mc finished [took 69.5808s]
04/11/23 02:59:12| INFO mul_pacc finished [took 82.8518s]
04/11/23 02:59:13| INFO atc_ne finished [took 72.1303s]
04/11/23 02:59:26| INFO mul_sld finished [took 103.4201s]
04/11/23 02:59:30| WARNING Method bin_sld_gsq failed. Exception: This solver needs samples of at least 2 classes in the data, but the data contains only one class: 1
04/11/23 02:59:41| INFO mul_pacc_gs finished [took 109.4672s]
04/11/23 03:01:59| INFO bin_pacc finished [took 251.3945s]
04/11/23 03:02:02| INFO bin_sld finished [took 260.0226s]
04/11/23 03:03:35| INFO mul_sld_gs finished [took 350.1705s]
04/11/23 03:03:48| INFO bin_pacc_gs finished [took 357.9668s]
04/11/23 03:07:59| INFO bin_sld_gs finished [took 615.8087s]
04/11/23 03:07:59| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 620.4985s]
04/11/23 03:07:59| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started
04/11/23 03:08:06| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid'
04/11/23 03:09:17| INFO ref finished [took 64.4692s]
04/11/23 03:09:25| INFO atc_mc finished [took 71.3766s]
04/11/23 03:09:27| INFO atc_ne finished [took 71.0947s]
04/11/23 03:09:28| INFO mul_pacc finished [took 80.0201s]
04/11/23 03:09:31| INFO mul_sld finished [took 89.4295s]
04/11/23 03:09:55| INFO mul_pacc_gs finished [took 104.7292s]
04/11/23 03:12:25| INFO bin_sld finished [took 263.6824s]
04/11/23 03:12:25| INFO bin_pacc finished [took 258.6502s]
04/11/23 03:14:01| INFO mul_sld_gs finished [took 357.3344s]
04/11/23 03:14:14| INFO bin_sld_gsq finished [took 369.1636s]
04/11/23 03:14:22| INFO bin_pacc_gs finished [took 372.8646s]
04/11/23 03:18:40| INFO bin_sld_gs finished [took 636.9190s]
04/11/23 03:18:40| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 640.2322s]
04/11/23 03:18:40| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started
04/11/23 03:18:46| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid'
04/11/23 03:19:58| INFO ref finished [took 65.9462s]
04/11/23 03:20:02| INFO atc_mc finished [took 68.5710s]
04/11/23 03:20:04| INFO atc_ne finished [took 68.9466s]
04/11/23 03:20:06| INFO mul_pacc finished [took 77.9039s]
04/11/23 03:20:06| INFO mul_sld finished [took 84.0917s]
04/11/23 03:20:37| INFO mul_pacc_gs finished [took 106.2536s]
04/11/23 03:23:04| INFO bin_pacc finished [took 257.4211s]
04/11/23 03:23:05| INFO bin_sld finished [took 264.3442s]
04/11/23 03:24:49| INFO mul_sld_gs finished [took 365.1691s]
04/11/23 03:25:01| INFO bin_pacc_gs finished [took 371.9184s]
04/11/23 03:25:02| INFO bin_sld_gsq finished [took 377.0442s]
04/11/23 03:29:37| INFO bin_sld_gs finished [took 654.0366s]
04/11/23 03:29:37| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 657.0840s]
04/11/23 03:29:37| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
04/11/23 03:29:42| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid'
04/11/23 03:30:51| INFO ref finished [took 62.7217s]
04/11/23 03:30:58| INFO atc_mc finished [took 67.8613s]
04/11/23 03:31:00| INFO atc_ne finished [took 68.5026s]
04/11/23 03:31:03| INFO mul_sld finished [took 83.8857s]
04/11/23 03:31:03| INFO mul_pacc finished [took 78.6340s]
04/11/23 03:31:30| INFO mul_pacc_gs finished [took 103.4683s]
04/11/23 03:34:00| INFO bin_sld finished [took 262.4457s]
04/11/23 03:34:02| INFO bin_pacc finished [took 258.2247s]
04/11/23 03:35:44| INFO mul_sld_gs finished [took 363.8135s]
04/11/23 03:35:58| INFO bin_pacc_gs finished [took 372.0485s]
04/11/23 03:36:05| INFO bin_sld_gsq finished [took 382.9585s]
04/11/23 03:40:39| INFO bin_sld_gs finished [took 659.6222s]
04/11/23 03:40:39| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 662.5763s]
04/11/23 03:40:39| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
04/11/23 03:40:45| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid'
04/11/23 03:41:56| INFO ref finished [took 64.5923s]
04/11/23 03:42:01| INFO atc_mc finished [took 68.0148s]
04/11/23 03:42:03| INFO atc_ne finished [took 68.3119s]
04/11/23 03:42:04| INFO mul_pacc finished [took 76.9397s]
04/11/23 03:42:07| INFO mul_sld finished [took 85.5363s]
04/11/23 03:42:34| INFO mul_pacc_gs finished [took 103.4448s]
04/11/23 03:45:01| INFO bin_sld finished [took 260.0814s]
04/11/23 03:45:03| INFO bin_pacc finished [took 256.9386s]
04/11/23 03:46:45| INFO mul_sld_gs finished [took 361.5910s]
04/11/23 03:47:01| INFO bin_pacc_gs finished [took 371.9657s]
04/11/23 03:47:13| INFO bin_sld_gsq finished [took 388.2498s]
04/11/23 03:51:40| INFO bin_sld_gs finished [took 657.4008s]
04/11/23 03:51:40| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 660.5115s]
04/11/23 03:51:40| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
04/11/23 03:51:46| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid'
04/11/23 03:52:54| INFO ref finished [took 61.9225s]
04/11/23 03:53:00| INFO atc_mc finished [took 66.3156s]
04/11/23 03:53:02| INFO atc_ne finished [took 66.5025s]
04/11/23 03:53:04| INFO mul_pacc finished [took 75.8808s]
04/11/23 03:53:06| INFO mul_sld finished [took 84.3204s]
04/11/23 03:53:33| INFO mul_pacc_gs finished [took 102.5763s]
04/11/23 03:56:04| INFO bin_sld finished [took 263.2781s]
04/11/23 03:56:04| INFO bin_pacc finished [took 257.7298s]
04/11/23 03:57:44| INFO mul_sld_gs finished [took 359.7910s]
04/11/23 03:58:00| INFO bin_pacc_gs finished [took 371.3848s]
04/11/23 03:58:11| INFO bin_sld_gsq finished [took 386.0904s]
04/11/23 04:02:50| INFO bin_sld_gs finished [took 667.6623s]
04/11/23 04:02:50| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 670.7255s]
04/11/23 04:02:50| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
04/11/23 04:02:57| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid'
04/11/23 04:04:05| INFO ref finished [took 62.3256s]
04/11/23 04:04:13| INFO atc_mc finished [took 68.9525s]
04/11/23 04:04:15| INFO atc_ne finished [took 68.8750s]
04/11/23 04:04:16| INFO mul_pacc finished [took 77.5049s]
04/11/23 04:04:19| INFO mul_sld finished [took 86.0694s]
04/11/23 04:04:45| INFO mul_pacc_gs finished [took 103.3513s]
04/11/23 04:07:15| INFO bin_pacc finished [took 257.6456s]
04/11/23 04:07:16| INFO bin_sld finished [took 263.9914s]
04/11/23 04:08:55| INFO mul_sld_gs finished [took 360.5634s]
04/11/23 04:09:12| INFO bin_pacc_gs finished [took 372.2665s]
04/11/23 04:09:18| INFO bin_sld_gsq finished [took 381.8311s]
04/11/23 04:13:39| INFO bin_sld_gs finished [took 645.3599s]
04/11/23 04:13:39| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 648.5328s]
04/11/23 04:13:39| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
04/11/23 04:13:45| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid'
04/11/23 04:14:51| INFO ref finished [took 59.8110s]
04/11/23 04:14:58| INFO atc_mc finished [took 65.2666s]
04/11/23 04:14:59| INFO atc_ne finished [took 64.5173s]
04/11/23 04:15:01| INFO mul_pacc finished [took 73.8332s]
04/11/23 04:15:04| INFO mul_sld finished [took 82.3509s]
04/11/23 04:15:29| INFO mul_pacc_gs finished [took 99.3541s]
04/11/23 04:18:00| INFO bin_pacc finished [took 254.3308s]
04/11/23 04:18:03| INFO bin_sld finished [took 262.3008s]
04/11/23 04:19:40| INFO mul_sld_gs finished [took 357.1229s]
04/11/23 04:19:57| INFO bin_pacc_gs finished [took 368.4516s]
04/11/23 04:20:03| INFO bin_sld_gsq finished [took 378.7658s]
04/11/23 04:24:37| INFO bin_sld_gs finished [took 655.1931s]
04/11/23 04:24:37| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 658.3505s]
04/11/23 04:24:37| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
04/11/23 04:24:43| WARNING Method mul_sld_gsq failed. Exception: MultiClassAccuracyEstimator.__init__() got an unexpected keyword argument 'param_grid'
04/11/23 04:25:49| INFO ref finished [took 59.4546s]
04/11/23 04:25:55| INFO atc_mc finished [took 63.5805s]
04/11/23 04:25:58| INFO atc_ne finished [took 63.2985s]
04/11/23 04:25:58| INFO mul_pacc finished [took 72.5198s]
04/11/23 04:26:11| INFO mul_sld finished [took 91.7136s]
04/11/23 04:26:27| INFO mul_pacc_gs finished [took 98.8722s]
04/11/23 04:28:57| INFO bin_pacc finished [took 252.8144s]
04/11/23 04:29:02| INFO bin_sld finished [took 263.8013s]
04/11/23 04:30:35| INFO mul_sld_gs finished [took 353.3693s]
04/11/23 04:30:51| INFO bin_sld_gsq finished [took 368.8564s]
04/11/23 04:30:54| INFO bin_pacc_gs finished [took 367.5592s]
04/11/23 04:35:11| INFO bin_sld_gs finished [took 630.6700s]
04/11/23 04:35:11| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 633.7494s]
----------------------------------------------------------------------------------------------------
04/11/23 19:09:42| INFO dataset rcv1_CCAT_9prevs
04/11/23 19:09:47| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 19:10:28| INFO ref finished [took 36.0351s]
04/11/23 19:10:32| INFO atc_mc finished [took 38.9507s]
04/11/23 19:10:35| INFO mulmc_sld finished [took 43.7869s]
04/11/23 19:10:50| INFO mul_sld finished [took 60.8007s]
04/11/23 19:10:50| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 62.9600s]
04/11/23 19:10:50| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started
04/11/23 19:11:29| INFO ref finished [took 36.3632s]
04/11/23 19:11:34| INFO atc_mc finished [took 39.5928s]
04/11/23 19:11:36| INFO mulmc_sld finished [took 44.2915s]
04/11/23 19:11:44| INFO mul_sld finished [took 52.6727s]
04/11/23 19:11:44| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 54.0362s]
04/11/23 19:11:44| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started
04/11/23 19:12:24| INFO ref finished [took 36.4303s]
04/11/23 19:12:27| INFO atc_mc finished [took 39.2329s]
04/11/23 19:12:30| INFO mulmc_sld finished [took 43.6247s]
04/11/23 19:12:36| INFO mul_sld finished [took 50.2041s]
04/11/23 19:12:36| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 51.6412s]
04/11/23 19:12:36| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
04/11/23 19:13:16| INFO ref finished [took 36.7551s]
04/11/23 19:13:19| INFO atc_mc finished [took 39.2806s]
04/11/23 19:13:21| INFO mulmc_sld finished [took 43.6120s]
04/11/23 19:13:27| INFO mul_sld finished [took 50.4446s]
04/11/23 19:13:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 51.6672s]
04/11/23 19:13:27| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
04/11/23 19:14:07| INFO ref finished [took 35.8789s]
04/11/23 19:14:11| INFO atc_mc finished [took 39.2168s]
04/11/23 19:14:13| INFO mulmc_sld finished [took 43.4580s]
04/11/23 19:14:20| INFO mul_sld finished [took 51.2902s]
04/11/23 19:14:20| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 52.6303s]
04/11/23 19:14:20| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
04/11/23 19:15:00| INFO ref finished [took 36.3735s]
04/11/23 19:15:04| INFO atc_mc finished [took 39.7035s]
04/11/23 19:15:06| INFO mulmc_sld finished [took 43.6364s]
04/11/23 19:15:13| INFO mul_sld finished [took 52.0138s]
04/11/23 19:15:13| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 53.3303s]
04/11/23 19:15:13| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
04/11/23 19:15:54| INFO ref finished [took 37.3366s]
04/11/23 19:15:57| INFO atc_mc finished [took 39.8921s]
04/11/23 19:16:00| INFO mulmc_sld finished [took 44.5159s]
04/11/23 19:16:08| INFO mul_sld finished [took 53.0806s]
04/11/23 19:16:08| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 54.4117s]
04/11/23 19:16:08| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
04/11/23 19:16:47| INFO ref finished [took 35.7800s]
04/11/23 19:16:50| INFO atc_mc finished [took 38.4484s]
04/11/23 19:16:53| INFO mulmc_sld finished [took 42.7405s]
04/11/23 19:17:01| INFO mul_sld finished [took 51.5556s]
04/11/23 19:17:01| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 52.9684s]
04/11/23 19:17:01| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
04/11/23 19:17:39| INFO ref finished [took 35.0919s]
04/11/23 19:17:43| INFO atc_mc finished [took 38.1718s]
04/11/23 19:17:45| INFO mulmc_sld finished [took 42.4413s]
04/11/23 19:17:59| INFO mul_sld finished [took 57.0766s]
04/11/23 19:17:59| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 58.3668s]
----------------------------------------------------------------------------------------------------
04/11/23 19:42:38| INFO dataset rcv1_CCAT_9prevs
04/11/23 19:42:43| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 19:43:27| INFO ref finished [took 38.7664s]
04/11/23 19:43:31| INFO atc_mc finished [took 42.4000s]
04/11/23 19:43:33| INFO mulmc_sld finished [took 47.0913s]
04/11/23 19:43:34| INFO binmc_sld finished [took 47.1675s]
04/11/23 19:43:49| INFO mul_sld finished [took 64.1382s]
04/11/23 19:46:00| INFO bin_sld finished [took 195.9822s]
04/11/23 19:46:00| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 197.2916s]
04/11/23 19:46:00| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started
04/11/23 19:46:44| INFO ref finished [took 38.5976s]
04/11/23 19:46:48| INFO atc_mc finished [took 41.9465s]
04/11/23 19:46:49| INFO mulmc_sld finished [took 46.2205s]
04/11/23 19:46:51| INFO binmc_sld finished [took 46.7475s]
04/11/23 19:46:58| INFO mul_sld finished [took 56.3552s]
04/11/23 19:49:14| INFO bin_sld finished [took 193.2923s]
04/11/23 19:49:14| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 194.6251s]
04/11/23 19:49:14| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started
04/11/23 19:49:58| INFO ref finished [took 38.3754s]
04/11/23 19:50:02| INFO atc_mc finished [took 41.0091s]
04/11/23 19:50:03| INFO mulmc_sld finished [took 45.6205s]
04/11/23 19:50:05| INFO binmc_sld finished [took 46.1852s]
04/11/23 19:50:10| INFO mul_sld finished [took 52.9704s]
04/11/23 19:52:27| INFO bin_sld finished [took 190.6101s]
04/11/23 19:52:27| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 192.0378s]
04/11/23 19:52:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
04/11/23 19:53:10| INFO ref finished [took 38.4467s]
04/11/23 19:53:13| INFO atc_mc finished [took 41.2602s]
04/11/23 19:53:15| INFO mulmc_sld finished [took 45.7496s]
04/11/23 19:53:16| INFO binmc_sld finished [took 45.5531s]
04/11/23 19:53:21| INFO mul_sld finished [took 52.5067s]
04/11/23 19:55:38| INFO bin_sld finished [took 190.7744s]
04/11/23 19:55:38| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 191.9715s]
04/11/23 19:55:39| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
04/11/23 19:56:21| INFO ref finished [took 37.9420s]
04/11/23 19:56:26| INFO atc_mc finished [took 41.2056s]
04/11/23 19:56:27| INFO mulmc_sld finished [took 45.7577s]
04/11/23 19:56:28| INFO binmc_sld finished [took 45.6411s]
04/11/23 19:56:34| INFO mul_sld finished [took 53.5219s]
04/11/23 19:58:51| INFO bin_sld finished [took 191.1772s]
04/11/23 19:58:51| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 192.4566s]
04/11/23 19:58:51| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
04/11/23 19:59:34| INFO ref finished [took 37.8604s]
04/11/23 19:59:38| INFO atc_mc finished [took 41.0334s]
04/11/23 19:59:39| INFO mulmc_sld finished [took 45.1999s]
04/11/23 19:59:40| INFO binmc_sld finished [took 45.4846s]
04/11/23 19:59:47| INFO mul_sld finished [took 54.3166s]
04/11/23 20:02:04| INFO bin_sld finished [took 191.4002s]
04/11/23 20:02:04| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 192.6275s]
04/11/23 20:02:04| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
04/11/23 20:02:48| INFO ref finished [took 38.8313s]
04/11/23 20:02:52| INFO atc_mc finished [took 42.1162s]
04/11/23 20:02:54| INFO mulmc_sld finished [took 47.0413s]
04/11/23 20:02:55| INFO binmc_sld finished [took 46.8891s]
04/11/23 20:03:02| INFO mul_sld finished [took 55.8821s]
04/11/23 20:05:19| INFO bin_sld finished [took 193.7571s]
04/11/23 20:05:19| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 195.2404s]
04/11/23 20:05:19| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
04/11/23 20:06:03| INFO ref finished [took 38.7982s]
04/11/23 20:06:06| INFO atc_mc finished [took 41.6213s]
04/11/23 20:06:08| INFO mulmc_sld finished [took 46.2646s]
04/11/23 20:06:09| INFO binmc_sld finished [took 46.2453s]
04/11/23 20:06:16| INFO mul_sld finished [took 54.8621s]
04/11/23 20:08:35| INFO bin_sld finished [took 194.5226s]
04/11/23 20:08:35| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 195.9251s]
04/11/23 20:08:35| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
04/11/23 20:09:18| INFO ref finished [took 38.3873s]
04/11/23 20:09:22| INFO atc_mc finished [took 41.2537s]
04/11/23 20:09:24| INFO mulmc_sld finished [took 46.2211s]
04/11/23 20:09:25| INFO binmc_sld finished [took 46.6421s]
04/11/23 20:09:38| INFO mul_sld finished [took 60.9539s]
04/11/23 20:11:51| INFO bin_sld finished [took 195.1888s]
04/11/23 20:11:51| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 196.4776s]
----------------------------------------------------------------------------------------------------
04/11/23 20:56:32| INFO dataset rcv1_CCAT_9prevs
04/11/23 20:56:37| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs started
04/11/23 20:57:33| INFO ref finished [took 49.2697s]
04/11/23 20:57:38| INFO atc_mc finished [took 53.2068s]
04/11/23 20:57:39| INFO mulmc_sld finished [took 58.6224s]
04/11/23 20:58:59| INFO mulmc_sld_gs finished [took 136.0930s]
04/11/23 21:00:30| INFO binmc_sld finished [took 230.3290s]
04/11/23 21:02:12| INFO mul_sld_gs finished [took 333.4899s]
04/11/23 21:06:49| INFO bin_sld_gs finished [took 610.5751s]
04/11/23 21:06:54| INFO binmc_sld_gs finished [took 612.8900s]
04/11/23 21:06:55| INFO Dataset sample 0.10 of dataset rcv1_CCAT_9prevs finished [took 617.6873s]
04/11/23 21:06:55| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs started
04/11/23 21:07:52| INFO ref finished [took 49.8077s]
04/11/23 21:07:56| INFO atc_mc finished [took 53.3303s]
04/11/23 21:07:57| INFO mulmc_sld finished [took 58.9345s]
04/11/23 21:09:17| INFO mulmc_sld_gs finished [took 136.5258s]
04/11/23 21:10:51| INFO binmc_sld finished [took 233.4049s]
04/11/23 21:12:35| INFO mul_sld_gs finished [took 338.2751s]
04/11/23 21:17:38| INFO bin_sld_gs finished [took 641.8524s]
04/11/23 21:18:19| INFO binmc_sld_gs finished [took 679.9471s]
04/11/23 21:18:19| INFO Dataset sample 0.20 of dataset rcv1_CCAT_9prevs finished [took 684.7098s]
04/11/23 21:18:19| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs started
04/11/23 21:19:24| INFO ref finished [took 55.3767s]
04/11/23 21:19:28| INFO mulmc_sld finished [took 64.2789s]
04/11/23 21:19:29| INFO atc_mc finished [took 59.5610s]
04/11/23 21:20:57| INFO mulmc_sld_gs finished [took 150.1392s]
04/11/23 21:22:36| INFO binmc_sld finished [took 253.0960s]
04/11/23 21:24:16| INFO mul_sld_gs finished [took 354.6283s]
04/11/23 21:29:15| INFO bin_sld_gs finished [took 654.3325s]
04/11/23 21:29:50| INFO binmc_sld_gs finished [took 684.5074s]
04/11/23 21:29:50| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 690.4897s]
04/11/23 21:29:50| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
04/11/23 21:30:45| INFO ref finished [took 48.2647s]
04/11/23 21:30:51| INFO atc_mc finished [took 52.2724s]
04/11/23 21:30:51| INFO mulmc_sld finished [took 57.5142s]
04/11/23 21:32:07| INFO mulmc_sld_gs finished [took 131.4908s]
04/11/23 21:33:38| INFO binmc_sld finished [took 224.9620s]
04/11/23 21:35:22| INFO mul_sld_gs finished [took 329.9053s]
04/11/23 21:40:25| INFO bin_sld_gs finished [took 634.4342s]
04/11/23 21:41:08| INFO binmc_sld_gs finished [took 673.6071s]
04/11/23 21:41:08| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 678.4725s]
04/11/23 21:41:08| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
04/11/23 21:42:03| INFO ref finished [took 47.4381s]
04/11/23 21:42:08| INFO atc_mc finished [took 51.3566s]
04/11/23 21:42:09| INFO mulmc_sld finished [took 56.6180s]
04/11/23 21:43:23| INFO mulmc_sld_gs finished [took 128.6413s]
04/11/23 21:44:54| INFO binmc_sld finished [took 222.7951s]
04/11/23 21:46:39| INFO mul_sld_gs finished [took 328.8118s]
04/11/23 21:51:37| INFO bin_sld_gs finished [took 627.4937s]
04/11/23 21:52:17| INFO binmc_sld_gs finished [took 663.8116s]
04/11/23 21:52:17| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 668.8948s]
04/11/23 21:52:17| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
04/11/23 21:53:12| INFO ref finished [took 47.6269s]
04/11/23 21:53:16| INFO atc_mc finished [took 51.1109s]
04/11/23 21:53:17| INFO mulmc_sld finished [took 56.5728s]
04/11/23 21:54:31| INFO mulmc_sld_gs finished [took 128.0358s]
04/11/23 21:56:00| INFO binmc_sld finished [took 220.0811s]
04/11/23 21:57:46| INFO mul_sld_gs finished [took 327.0856s]
04/11/23 22:02:58| INFO bin_sld_gs finished [took 639.3432s]
04/11/23 22:03:48| INFO binmc_sld_gs finished [took 686.2326s]
04/11/23 22:03:48| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 690.9677s]
04/11/23 22:03:48| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
04/11/23 22:04:42| INFO ref finished [took 47.2804s]
04/11/23 22:04:48| INFO atc_mc finished [took 51.6888s]
04/11/23 22:04:48| INFO mulmc_sld finished [took 56.1465s]
04/11/23 22:06:06| INFO mulmc_sld_gs finished [took 132.4278s]
04/11/23 22:07:33| INFO binmc_sld finished [took 221.9299s]
04/11/23 22:09:19| INFO mul_sld_gs finished [took 329.1446s]
04/11/23 22:14:09| INFO bin_sld_gs finished [took 619.3584s]
04/11/23 22:14:32| INFO binmc_sld_gs finished [took 638.7326s]
04/11/23 22:14:32| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 643.6278s]
04/11/23 22:14:32| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
04/11/23 22:15:26| INFO ref finished [took 47.3139s]
04/11/23 22:15:30| INFO atc_mc finished [took 50.8602s]
04/11/23 22:15:32| INFO mulmc_sld finished [took 56.5107s]
04/11/23 22:16:47| INFO mulmc_sld_gs finished [took 129.5292s]
04/11/23 22:18:22| INFO binmc_sld finished [took 226.9238s]
04/11/23 22:20:02| INFO mul_sld_gs finished [took 327.7014s]
04/11/23 22:24:57| INFO bin_sld_gs finished [took 624.4254s]
04/11/23 22:25:13| INFO binmc_sld_gs finished [took 636.2675s]
04/11/23 22:25:13| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 641.0382s]
04/11/23 22:25:13| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
04/11/23 22:26:07| INFO ref finished [took 47.3224s]
04/11/23 22:26:12| INFO atc_mc finished [took 51.1828s]
04/11/23 22:26:13| INFO mulmc_sld finished [took 56.6133s]
04/11/23 22:27:30| INFO mulmc_sld_gs finished [took 131.3662s]
04/11/23 22:29:05| INFO binmc_sld finished [took 229.3002s]
04/11/23 22:30:38| INFO mul_sld_gs finished [took 323.5271s]
04/11/23 22:35:21| INFO bin_sld_gs finished [took 606.6430s]
04/11/23 22:35:30| INFO binmc_sld_gs finished [took 612.5966s]
04/11/23 22:35:30| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 617.3109s]
----------------------------------------------------------------------------------------------------
04/11/23 22:49:37| ERROR Evaluation over rcv1_CCAT_3prevs failed. Exception: 'Invalid estimator: estimator binmc_sld_gs does not exist'
04/11/23 22:49:37| ERROR Failed while saving configuration rcv1_CCAT_debug of rcv1_CCAT_3prevs. Exception: cannot access local variable 'dr' where it is not associated with a value
----------------------------------------------------------------------------------------------------
04/11/23 22:50:07| INFO dataset rcv1_CCAT_3prevs
04/11/23 22:50:12| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started
----------------------------------------------------------------------------------------------------
04/11/23 22:55:55| INFO dataset rcv1_CCAT_3prevs
04/11/23 22:55:59| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started
04/11/23 22:56:48| INFO ref finished [took 44.4275s]
----------------------------------------------------------------------------------------------------
04/11/23 22:56:59| INFO dataset rcv1_CCAT_3prevs
04/11/23 22:57:03| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started
04/11/23 22:57:09| WARNING Method mul_sld_gs failed. Exception: '>=' not supported between instances of 'TypeError' and 'int'
04/11/23 22:57:17| WARNING Method bin_sld_gs failed. Exception: '>=' not supported between instances of 'TypeError' and 'int'
----------------------------------------------------------------------------------------------------
04/11/23 22:58:04| INFO dataset rcv1_CCAT_3prevs
04/11/23 22:58:09| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started
04/11/23 22:58:58| INFO ref finished [took 43.7541s]
04/11/23 22:59:05| INFO atc_mc finished [took 50.0628s]
----------------------------------------------------------------------------------------------------
04/11/23 23:01:22| INFO dataset rcv1_CCAT_3prevs
04/11/23 23:01:27| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started
04/11/23 23:02:16| INFO ref finished [took 43.9765s]
04/11/23 23:02:23| INFO atc_mc finished [took 50.5568s]
----------------------------------------------------------------------------------------------------
04/11/23 23:09:33| INFO dataset rcv1_CCAT_3prevs
04/11/23 23:09:37| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started
04/11/23 23:09:38| WARNING Method binmc_sld failed. Exception: classifier and pred_proba cannot be both None
04/11/23 23:09:39| WARNING Method mulmc_sld failed. Exception: classifier and pred_proba cannot be both None
04/11/23 23:09:40| WARNING Method bin_sld_gs failed. Exception: no combination of hyperparameters seem to work
04/11/23 23:09:41| WARNING Method mul_sld_gs failed. Exception: no combination of hyperparameters seem to work
----------------------------------------------------------------------------------------------------
04/11/23 23:10:23| INFO dataset rcv1_CCAT_3prevs
04/11/23 23:10:28| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started
04/11/23 23:11:15| INFO ref finished [took 42.4887s]
04/11/23 23:11:20| INFO atc_mc finished [took 45.6262s]
04/11/23 23:11:21| INFO mulmc_sld finished [took 50.9790s]
04/11/23 23:13:57| INFO binmc_sld finished [took 208.3159s]
----------------------------------------------------------------------------------------------------
04/11/23 23:16:22| INFO dataset rcv1_CCAT_3prevs
04/11/23 23:16:26| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started
04/11/23 23:17:12| INFO ref finished [took 40.5978s]
04/11/23 23:17:16| INFO atc_mc finished [took 43.6933s]
04/11/23 23:17:17| INFO mulmc_sld finished [took 49.0808s]
04/11/23 23:19:53| INFO binmc_sld finished [took 205.5731s]
04/11/23 23:22:24| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.00672) [took 354.1411s]
04/11/23 23:23:05| INFO mul_sld_gs finished [took 394.8240s]
04/11/23 23:30:41| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': 'vs', 'confidence': None} (score=0.00891) [took 852.1465s]
04/11/23 23:33:44| INFO bin_sld_gs finished [took 1035.2071s]
04/11/23 23:33:44| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs finished [took 1038.1845s]
04/11/23 23:33:44| INFO Dataset sample 0.50 of dataset rcv1_CCAT_3prevs started
04/11/23 23:34:33| INFO ref finished [took 43.6409s]
04/11/23 23:34:37| INFO atc_mc finished [took 46.7818s]
04/11/23 23:34:38| INFO mulmc_sld finished [took 51.3459s]
04/11/23 23:37:15| INFO binmc_sld finished [took 209.5746s]
04/11/23 23:39:48| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.00553) [took 359.3210s]
04/11/23 23:40:28| INFO mul_sld_gs finished [took 399.5320s]
04/11/23 23:48:02| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': None} (score=0.01058) [took 855.1289s]
04/11/23 23:51:06| INFO bin_sld_gs finished [took 1038.6344s]
04/11/23 23:51:06| INFO Dataset sample 0.50 of dataset rcv1_CCAT_3prevs finished [took 1041.6478s]
04/11/23 23:51:06| INFO Dataset sample 0.60 of dataset rcv1_CCAT_3prevs started
04/11/23 23:51:51| INFO ref finished [took 40.0694s]
04/11/23 23:51:55| INFO atc_mc finished [took 42.4882s]
04/11/23 23:51:56| INFO mulmc_sld finished [took 47.7936s]
04/11/23 23:54:29| INFO binmc_sld finished [took 201.3777s]
04/11/23 23:57:03| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.00429) [took 352.7820s]
04/11/23 23:57:43| INFO mul_sld_gs finished [took 392.5201s]
05/11/23 00:05:21| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1000.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': None, 'confidence': None} (score=0.00552) [took 851.9361s]
05/11/23 00:08:24| INFO bin_sld_gs finished [took 1034.7353s]
05/11/23 00:08:24| INFO Dataset sample 0.60 of dataset rcv1_CCAT_3prevs finished [took 1037.8033s]
----------------------------------------------------------------------------------------------------
05/11/23 00:11:07| INFO dataset rcv1_CCAT_3prevs
05/11/23 00:11:12| INFO Dataset sample 0.40 of dataset rcv1_CCAT_3prevs started

View File

@ -1,9 +1,10 @@
import math
from typing import List, Optional
import numpy as np
import math
import scipy.sparse as sp
from quapy.data import LabelledCollection
from sklearn.base import BaseEstimator
# Extended classes
@ -128,8 +129,22 @@ class ExtendedCollection(LabelledCollection):
@classmethod
def extend_collection(
cls, base: LabelledCollection, pred_proba: np.ndarray
cls,
base: LabelledCollection,
classifier: BaseEstimator = None,
pred_proba: np.ndarray = None,
):
if classifier is None and pred_proba is None:
raise AttributeError("classifier and pred_proba cannot be both None")
if classifier is not None and pred_proba is not None:
raise AttributeError(
"Not needed parameters: just one of classifier or pred_proba is needed"
)
if classifier:
pred_proba = classifier.predict_proba(base.X)
n_classes = base.n_classes
# n_X = [ X | predicted probs. ]
@ -145,4 +160,3 @@ class ExtendedCollection(LabelledCollection):
)
return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])

View File

@ -8,7 +8,7 @@ from sklearn.linear_model import LogisticRegression
import quacc as qc
from quacc.evaluation.report import EvaluationReport
from quacc.method.model_selection import GridSearchAE
from quacc.method.model_selection import BQAEgsq, GridSearchAE, MCAEgsq
from ..method.base import BQAE, MCAE, BaseAccuracyEstimator
@ -49,8 +49,7 @@ def evaluation_report(
@method
def bin_sld(c_model, validation, protocol) -> EvaluationReport:
est = BQAE(c_model, SLD(LogisticRegression()))
est.fit(validation)
est = BQAE(c_model, SLD(LogisticRegression())).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
@ -59,8 +58,7 @@ def bin_sld(c_model, validation, protocol) -> EvaluationReport:
@method
def mul_sld(c_model, validation, protocol) -> EvaluationReport:
est = MCAE(c_model, SLD(LogisticRegression()))
est.fit(validation)
est = MCAE(c_model, SLD(LogisticRegression())).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
@ -68,9 +66,12 @@ def mul_sld(c_model, validation, protocol) -> EvaluationReport:
@method
def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
est = BQAE(c_model, SLD(LogisticRegression(), recalib="bcts"))
est.fit(validation)
def binmc_sld(c_model, validation, protocol) -> EvaluationReport:
est = BQAE(
c_model,
SLD(LogisticRegression()),
confidence="max_conf",
).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
@ -78,9 +79,12 @@ def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
@method
def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
est = MCAE(c_model, SLD(LogisticRegression(), recalib="bcts"))
est.fit(validation)
def mulmc_sld(c_model, validation, protocol) -> EvaluationReport:
est = MCAE(
c_model,
SLD(LogisticRegression()),
confidence="max_conf",
).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
@ -97,10 +101,11 @@ def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"q__recalib": [None, "bcts", "vs"],
"confidence": [None, "max_conf"],
},
refit=False,
protocol=UPP(v_val, repeats=100),
verbose=False,
verbose=True,
).fit(v_train)
return evaluation_report(
estimator=est,
@ -118,10 +123,11 @@ def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"q__recalib": [None, "bcts", "vs"],
"confidence": [None, "max_conf"],
},
refit=False,
protocol=UPP(v_val, repeats=100),
verbose=False,
verbose=True,
).fit(v_train)
return evaluation_report(
estimator=est,
@ -129,10 +135,47 @@ def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
)
@method
def bin_sld_gsq(c_model, validation, protocol) -> EvaluationReport:
est = BQAEgsq(
c_model,
SLD(LogisticRegression()),
param_grid={
"classifier__C": np.logspace(-3, 3, 7),
"classifier__class_weight": [None, "balanced"],
"recalib": [None, "bcts", "vs"],
},
refit=False,
verbose=False,
).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def mul_sld_gsq(c_model, validation, protocol) -> EvaluationReport:
est = MCAEgsq(
c_model,
SLD(LogisticRegression()),
param_grid={
"classifier__C": np.logspace(-3, 3, 7),
"classifier__class_weight": [None, "balanced"],
"recalib": [None, "bcts", "vs"],
},
refit=False,
verbose=False,
).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def bin_pacc(c_model, validation, protocol) -> EvaluationReport:
est = BQAE(c_model, PACC(LogisticRegression(), recalib="bcts"))
est.fit(validation)
est = BQAE(c_model, PACC(LogisticRegression())).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
@ -141,8 +184,7 @@ def bin_pacc(c_model, validation, protocol) -> EvaluationReport:
@method
def mul_pacc(c_model, validation, protocol) -> EvaluationReport:
est = MCAE(c_model, PACC(LogisticRegression(), recalib="bcts"))
est.fit(validation)
est = MCAE(c_model, PACC(LogisticRegression())).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
@ -158,7 +200,6 @@ def bin_pacc_gs(c_model, validation, protocol) -> EvaluationReport:
param_grid={
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"q__recalib": [None, "bcts", "vs"],
},
refit=False,
protocol=UPP(v_val, repeats=100),
@ -179,7 +220,6 @@ def mul_pacc_gs(c_model, validation, protocol) -> EvaluationReport:
param_grid={
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"q__recalib": [None, "bcts", "vs"],
},
refit=False,
protocol=UPP(v_val, repeats=100),
@ -188,5 +228,4 @@ def mul_pacc_gs(c_model, validation, protocol) -> EvaluationReport:
return evaluation_report(
estimator=est,
protocol=protocol,
method_name="bin_sld_gs",
)

View File

@ -27,7 +27,7 @@ def estimate_worker(_estimate, train, validation, test, _env=None, q=None):
result = _estimate(model, validation, protocol)
except Exception as e:
log.warning(f"Method {_estimate.__name__} failed. Exception: {e}")
# traceback(e)
traceback(e)
return {
"name": _estimate.__name__,
"result": None,

View File

@ -17,9 +17,11 @@ class BaseAccuracyEstimator(BaseQuantifier):
self,
classifier: BaseEstimator,
quantifier: BaseQuantifier,
confidence=None,
):
self.__check_classifier(classifier)
self.quantifier = quantifier
self.confidence = confidence
def __check_classifier(self, classifier):
if not hasattr(classifier, "predict_proba"):
@ -28,10 +30,37 @@ class BaseAccuracyEstimator(BaseQuantifier):
)
self.classifier = classifier
def __get_confidence(self):
if self.confidence is None:
return None
__confs = {
"max_conf": lambda probas: np.max(probas, axis=-1).reshape((len(probas), 1))
}
return __confs.get(self.confidence, None)
def __get_ext(self, pred_proba):
_ext = pred_proba
_f_conf = self.__get_confidence()
if _f_conf is not None:
_confs = _f_conf(pred_proba)
_ext = np.concatenate((_confs, pred_proba), axis=1)
return _ext
def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
if not pred_proba:
if pred_proba is None:
pred_proba = self.classifier.predict_proba(coll.X)
return ExtendedCollection.extend_collection(coll, pred_proba)
_ext = self.__get_ext(pred_proba)
return ExtendedCollection.extend_collection(coll, pred_proba=_ext)
def _extend_instances(self, instances: np.ndarray | csr_matrix, pred_proba=None):
if pred_proba is None:
pred_proba = self.classifier.predict_proba(instances)
_ext = self.__get_ext(pred_proba)
return ExtendedCollection.extend_instances(instances, _ext)
@abstractmethod
def fit(self, train: LabelledCollection | ExtendedCollection):
@ -47,23 +76,24 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
self,
classifier: BaseEstimator,
quantifier: BaseQuantifier,
confidence: str = None,
):
super().__init__(classifier, quantifier)
super().__init__(
classifier=classifier,
quantifier=quantifier,
confidence=confidence,
)
self.e_train = None
def fit(self, train: LabelledCollection):
pred_probs = self.classifier.predict_proba(train.X)
self.e_train = ExtendedCollection.extend_collection(train, pred_probs)
self.e_train = self.extend(train)
self.quantifier.fit(self.e_train)
return self
def estimate(self, instances, ext=False) -> np.ndarray:
e_inst = instances
if not ext:
pred_prob = self.classifier.predict_proba(instances)
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
e_inst = instances if ext else self._extend_instances(instances)
estim_prev = self.quantifier.quantify(e_inst)
return self._check_prevalence_classes(estim_prev)
@ -78,18 +108,25 @@ class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
def __init__(self, classifier: BaseEstimator, quantifier: BaseAccuracyEstimator):
super().__init__(classifier, quantifier)
def __init__(
self,
classifier: BaseEstimator,
quantifier: BaseAccuracyEstimator,
confidence: str = None,
):
super().__init__(
classifier=classifier,
quantifier=quantifier,
confidence=confidence,
)
self.quantifiers = []
self.e_trains = []
def fit(self, train: LabelledCollection | ExtendedCollection):
pred_probs = self.classifier.predict_proba(train.X)
self.e_train = ExtendedCollection.extend_collection(train, pred_probs)
self.e_train = self.extend(train)
self.n_classes = self.e_train.n_classes
self.e_trains = self.e_train.split_by_pred()
self.quantifiers = [deepcopy(self.quantifier) for _ in self.e_trains]
self.quantifiers = []
for train in self.e_trains:
@ -97,12 +134,11 @@ class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
quant.fit(train)
self.quantifiers.append(quant)
return self
def estimate(self, instances, ext=False):
# TODO: test
e_inst = instances
if not ext:
pred_prob = self.classifier.predict_proba(instances)
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
e_inst = instances if ext else self._extend_instances(instances)
_ncl = int(math.sqrt(self.n_classes))
s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)

View File

@ -3,14 +3,22 @@ from copy import deepcopy
from time import time
from typing import Callable, Union
import quapy as qp
from quapy.data import LabelledCollection
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol
from sklearn.base import BaseEstimator
import quacc as qc
import quacc.error
from quacc.data import ExtendedCollection
from quacc.evaluation import evaluate
from quacc.method.base import BaseAccuracyEstimator
from quacc.logger import SubLogger
from quacc.method.base import (
BaseAccuracyEstimator,
BinaryQuantifierAccuracyEstimator,
MultiClassAccuracyEstimator,
)
class GridSearchAE(BaseAccuracyEstimator):
@ -106,6 +114,12 @@ class GridSearchAE(BaseAccuracyEstimator):
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
f"[took {tend:.4f}s]"
)
log = SubLogger.logger()
log.debug(
f"[{self.model.__class__.__name__}] "
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
f"[took {tend:.4f}s]"
)
if self.refit:
if isinstance(protocol, OnLabelledCollectionProtocol):
@ -203,3 +217,84 @@ class GridSearchAE(BaseAccuracyEstimator):
if hasattr(self, "best_model_"):
return self.best_model_
raise ValueError("best_model called before fit")
class MCAEgsq(MultiClassAccuracyEstimator):
def __init__(
self,
classifier: BaseEstimator,
quantifier: BaseAccuracyEstimator,
param_grid: dict,
error: Union[Callable, str] = qp.error.mae,
refit=True,
timeout=-1,
n_jobs=None,
verbose=False,
):
self.param_grid = param_grid
self.refit = refit
self.timeout = timeout
self.n_jobs = n_jobs
self.verbose = verbose
self.error = error
super().__init__(classifier, quantifier)
def fit(self, train: LabelledCollection):
self.e_train = self.extend(train)
t_train, t_val = self.e_train.split_stratified(0.6, random_state=0)
self.quantifier = GridSearchQ(
deepcopy(self.quantifier),
param_grid=self.param_grid,
protocol=UPP(t_val, repeats=100),
error=self.error,
refit=self.refit,
timeout=self.timeout,
n_jobs=self.n_jobs,
verbose=self.verbose,
).fit(self.e_train)
return self
class BQAEgsq(BinaryQuantifierAccuracyEstimator):
def __init__(
self,
classifier: BaseEstimator,
quantifier: BaseAccuracyEstimator,
param_grid: dict,
error: Union[Callable, str] = qp.error.mae,
refit=True,
timeout=-1,
n_jobs=None,
verbose=False,
):
self.param_grid = param_grid
self.refit = refit
self.timeout = timeout
self.n_jobs = n_jobs
self.verbose = verbose
self.error = error
super().__init__(classifier=classifier, quantifier=quantifier)
def fit(self, train: LabelledCollection):
self.e_train = self.extend(train)
self.n_classes = self.e_train.n_classes
self.e_trains = self.e_train.split_by_pred()
self.quantifiers = []
for e_train in self.e_trains:
t_train, t_val = e_train.split_stratified(0.6, random_state=0)
quantifier = GridSearchQ(
model=deepcopy(self.quantifier),
param_grid=self.param_grid,
protocol=UPP(t_val, repeats=100),
error=self.error,
refit=self.refit,
timeout=self.timeout,
n_jobs=self.n_jobs,
verbose=self.verbose,
).fit(t_train)
self.quantifiers.append(quantifier)
return self