diff --git a/TODO.html b/TODO.html index 31b1c20..2eaf501 100644 --- a/TODO.html +++ b/TODO.html @@ -103,15 +103,6 @@ verbose=True).fit(V_tr)

import baselines

  • -

    plot avg con train prevalence sull'asse x e media su test prevalecne

    -
  • -
  • -

    realizzare grid search per task specifico partendo da GridSearchQ

    -
  • -
  • -

    provare PACC come quantificatore

    -
  • -
  • importare mandoline

  • +

    plot avg con train prevalence sull'asse x e media su test prevalecne

    +
  • +
  • +

    realizzare grid search per task specifico partendo da GridSearchQ

    +
  • +
  • +

    provare PACC come quantificatore

    +
  • +
  • aggiungere etichette in shift plot

  • diff --git a/TODO.md b/TODO.md index 028e10e..f78d3ad 100644 --- a/TODO.md +++ b/TODO.md @@ -30,13 +30,13 @@ - nel caso di bin fare media dei due best score - [x] import baselines -- [x] plot avg con train prevalence sull'asse x e media su test prevalecne -- [x] realizzare grid search per task specifico partendo da GridSearchQ -- [x] provare PACC come quantificatore - [ ] importare mandoline - mandoline può essere importato, ma richiedere uno slicing delle features a priori che devere essere realizzato ad hoc - [ ] sistemare vecchie iw baselines - non possono essere fixate perché dipendono da numpy +- [x] plot avg con train prevalence sull'asse x e media su test prevalecne +- [x] realizzare grid search per task specifico partendo da GridSearchQ +- [x] provare PACC come quantificatore - [ ] aggiungere etichette in shift plot - [ ] sistemare exact_train quapy - [ ] testare anche su imbd \ No newline at end of file diff --git a/conf.yaml b/conf.yaml index 73fe841..073e0e7 100644 --- a/conf.yaml +++ b/conf.yaml @@ -151,4 +151,4 @@ main_conf: &main_conf - atc_ne - doc_feat -exec: *mc_conf \ No newline at end of file +exec: *debug_conf \ No newline at end of file diff --git a/quacc.log b/quacc.log index c45692f..62b3787 100644 --- a/quacc.log +++ b/quacc.log @@ -2850,3 +2850,154 @@ 05/11/23 14:16:15| INFO atc_mc finished [took 49.6779s] 05/11/23 14:16:19| INFO mulmc_sld finished [took 61.0610s] 05/11/23 14:16:22| INFO mulne_sld finished [took 62.2089s] +05/11/23 14:19:02| INFO binmc_sld finished [took 225.5737s] +05/11/23 14:19:03| INFO binne_sld finished [took 223.9017s] +05/11/23 14:28:50| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00756) [took 806.7930s] +05/11/23 14:29:32| INFO mul_sld_gs finished [took 848.7630s] +05/11/23 14:36:02| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00781) [took 1240.9138s] +05/11/23 14:39:04| INFO bin_sld_gs finished [took 1422.5520s] +05/11/23 14:39:04| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 1428.8824s] +05/11/23 14:39:04| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started +05/11/23 14:39:58| INFO ref finished [took 45.7514s] +05/11/23 14:40:02| INFO atc_mc finished [took 48.3888s] +05/11/23 14:40:05| INFO mulmc_sld finished [took 59.0537s] +05/11/23 14:40:09| INFO mulne_sld finished [took 60.9189s] +05/11/23 14:42:42| INFO binne_sld finished [took 214.5464s] +05/11/23 14:42:44| INFO binmc_sld finished [took 218.8429s] +05/11/23 14:52:23| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1000.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00984) [took 792.5474s] +05/11/23 14:53:05| INFO mul_sld_gs finished [took 834.1824s] +05/11/23 14:59:56| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.01112) [took 1247.0092s] +05/11/23 15:02:57| INFO bin_sld_gs finished [took 1427.5051s] +05/11/23 15:02:57| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 1432.9172s] +05/11/23 15:02:57| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started +05/11/23 15:03:49| INFO ref finished [took 44.4148s] +05/11/23 15:03:54| INFO atc_mc finished [took 47.7566s] +05/11/23 15:04:00| INFO mulmc_sld finished [took 60.5480s] +05/11/23 15:04:03| INFO mulne_sld finished [took 61.2226s] +05/11/23 15:06:30| INFO binmc_sld finished [took 211.9647s] +05/11/23 15:06:32| INFO binne_sld finished [took 211.4312s] +05/11/23 15:16:00| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'max_conf'} (score=0.00571) [took 776.6085s] +05/11/23 15:16:42| INFO mul_sld_gs finished [took 817.9358s] +05/11/23 15:23:24| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00653) [took 1221.6531s] +05/11/23 15:26:23| INFO bin_sld_gs finished [took 1400.9688s] +05/11/23 15:26:23| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 1406.4620s] +05/11/23 15:26:23| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started +05/11/23 15:27:16| INFO ref finished [took 44.3988s] +05/11/23 15:27:21| INFO atc_mc finished [took 48.5589s] +05/11/23 15:27:27| INFO mulmc_sld finished [took 61.4269s] +05/11/23 15:27:29| INFO mulne_sld finished [took 61.8292s] +05/11/23 15:29:55| INFO binmc_sld finished [took 210.1585s] +05/11/23 15:29:59| INFO binne_sld finished [took 212.0930s] +05/11/23 15:39:22| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'max_conf'} (score=0.00616) [took 771.6071s] +05/11/23 15:40:03| INFO mul_sld_gs finished [took 813.2905s] +05/11/23 15:47:04| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': None, 'confidence': None} (score=0.00544) [took 1234.9832s] +05/11/23 15:50:10| INFO bin_sld_gs finished [took 1421.7775s] +05/11/23 15:50:10| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 1427.0062s] +05/11/23 15:50:10| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started +05/11/23 15:51:11| INFO ref finished [took 49.7682s] +05/11/23 15:51:19| INFO atc_mc finished [took 54.2855s] +05/11/23 15:51:22| INFO mulmc_sld finished [took 68.7688s] +05/11/23 15:51:26| INFO mulne_sld finished [took 69.3711s] +05/11/23 15:54:07| INFO binmc_sld finished [took 234.7962s] +05/11/23 15:54:09| INFO binne_sld finished [took 234.6444s] +05/11/23 16:03:51| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 0.1, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'bcts', 'confidence': 'entropy'} (score=0.00765) [took 811.6704s] +05/11/23 16:04:34| INFO mul_sld_gs finished [took 854.8196s] +05/11/23 16:11:10| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 0.1, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'max_conf'} (score=0.01234) [took 1252.4784s] +05/11/23 16:14:10| INFO bin_sld_gs finished [took 1431.7446s] +05/11/23 16:14:10| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 1439.1145s] +05/11/23 16:14:10| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started +05/11/23 16:15:02| INFO ref finished [took 44.0970s] +05/11/23 16:15:07| INFO atc_mc finished [took 48.2871s] +05/11/23 16:15:13| INFO mulmc_sld finished [took 61.0461s] +05/11/23 16:15:15| INFO mulne_sld finished [took 60.6375s] +05/11/23 16:17:46| INFO binmc_sld finished [took 215.1734s] +05/11/23 16:17:49| INFO binne_sld finished [took 215.7846s] +05/11/23 16:27:15| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': None} (score=0.00822) [took 778.5688s] +05/11/23 16:27:56| INFO mul_sld_gs finished [took 819.2615s] +05/11/23 16:34:16| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'entropy'} (score=0.00894) [took 1200.6639s] +05/11/23 16:37:21| INFO bin_sld_gs finished [took 1385.9035s] +05/11/23 16:37:21| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 1391.5055s] +05/11/23 16:37:21| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started +05/11/23 16:38:13| INFO ref finished [took 44.7046s] +05/11/23 16:38:18| INFO atc_mc finished [took 48.7802s] +05/11/23 16:38:21| INFO mulmc_sld finished [took 57.4163s] +05/11/23 16:38:24| INFO mulne_sld finished [took 58.9847s] +05/11/23 16:40:59| INFO binmc_sld finished [took 216.7311s] +05/11/23 16:41:01| INFO binne_sld finished [took 216.5312s] +05/11/23 16:50:06| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.00808) [took 758.6896s] +05/11/23 16:50:46| INFO mul_sld_gs finished [took 798.8038s] +05/11/23 16:56:41| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'entropy'} (score=0.00604) [took 1154.7043s] +05/11/23 16:59:39| INFO bin_sld_gs finished [took 1332.5521s] +05/11/23 16:59:39| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 1337.7947s] +---------------------------------------------------------------------------------------------------- +05/11/23 20:08:46| ERROR estimate comparison failed. Exceprion: 'environ' object has no attribute 'OUT_PATH' +---------------------------------------------------------------------------------------------------- +05/11/23 20:09:08| ERROR estimate comparison failed. Exceprion: 'environ' object has no attribute 'OUT_PATH' +---------------------------------------------------------------------------------------------------- +05/11/23 20:09:27| INFO dataset imdb_3prevs +05/11/23 20:09:34| INFO Dataset sample 0.20 of dataset imdb_3prevs started +05/11/23 20:09:44| INFO ref finished [took 8.9550s] +05/11/23 20:09:47| INFO atc_mc finished [took 11.8923s] +05/11/23 20:09:56| INFO mulmc_sld finished [took 21.3196s] +05/11/23 20:09:56| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 21.7709s] +05/11/23 20:09:56| INFO Dataset sample 0.50 of dataset imdb_3prevs started +05/11/23 20:10:05| INFO ref finished [took 8.6116s] +05/11/23 20:10:08| INFO atc_mc finished [took 11.6880s] +05/11/23 20:10:16| INFO mulmc_sld finished [took 19.7793s] +05/11/23 20:10:16| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.3246s] +05/11/23 20:10:16| INFO Dataset sample 0.80 of dataset imdb_3prevs started +05/11/23 20:10:26| INFO ref finished [took 8.6654s] +05/11/23 20:10:29| INFO atc_mc finished [took 11.6975s] +05/11/23 20:10:35| INFO mulmc_sld finished [took 18.1478s] +05/11/23 20:10:35| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.7200s] +---------------------------------------------------------------------------------------------------- +05/11/23 20:11:42| INFO dataset imdb_3prevs +05/11/23 20:11:49| INFO Dataset sample 0.20 of dataset imdb_3prevs started +05/11/23 20:11:58| INFO ref finished [took 8.7146s] +05/11/23 20:12:02| INFO atc_mc finished [took 11.9672s] +05/11/23 20:12:10| INFO mulmc_sld finished [took 20.7824s] +05/11/23 20:12:10| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 21.2293s] +05/11/23 20:12:10| INFO Dataset sample 0.50 of dataset imdb_3prevs started +05/11/23 20:12:19| INFO ref finished [took 8.5867s] +05/11/23 20:12:23| INFO atc_mc finished [took 11.6542s] +05/11/23 20:12:30| INFO mulmc_sld finished [took 19.6709s] +05/11/23 20:12:30| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.1802s] +05/11/23 20:12:30| INFO Dataset sample 0.80 of dataset imdb_3prevs started +05/11/23 20:12:40| INFO ref finished [took 8.7231s] +05/11/23 20:12:43| INFO atc_mc finished [took 11.8244s] +05/11/23 20:12:49| INFO mulmc_sld finished [took 18.0420s] +05/11/23 20:12:49| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.6102s] +---------------------------------------------------------------------------------------------------- +05/11/23 20:14:32| INFO dataset imdb_3prevs +05/11/23 20:14:39| INFO Dataset sample 0.20 of dataset imdb_3prevs started +05/11/23 20:14:48| INFO ref finished [took 8.6247s] +05/11/23 20:14:51| INFO atc_mc finished [took 11.6363s] +05/11/23 20:15:00| INFO mulmc_sld finished [took 20.4634s] +05/11/23 20:15:00| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 20.9026s] +05/11/23 20:15:00| INFO Dataset sample 0.50 of dataset imdb_3prevs started +05/11/23 20:15:09| INFO ref finished [took 8.5219s] +05/11/23 20:15:12| INFO atc_mc finished [took 11.6739s] +05/11/23 20:15:20| INFO mulmc_sld finished [took 19.8454s] +05/11/23 20:15:20| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.3705s] +05/11/23 20:15:20| INFO Dataset sample 0.80 of dataset imdb_3prevs started +05/11/23 20:15:29| INFO ref finished [took 8.5948s] +05/11/23 20:15:32| INFO atc_mc finished [took 11.7465s] +05/11/23 20:15:39| INFO mulmc_sld finished [took 17.9276s] +05/11/23 20:15:39| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.4893s] +---------------------------------------------------------------------------------------------------- +05/11/23 20:16:10| INFO dataset imdb_3prevs +05/11/23 20:16:17| INFO Dataset sample 0.20 of dataset imdb_3prevs started +05/11/23 20:16:26| INFO ref finished [took 8.3736s] +05/11/23 20:16:29| INFO atc_mc finished [took 11.3995s] +05/11/23 20:16:38| INFO mulmc_sld finished [took 20.4916s] +05/11/23 20:16:38| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 20.9187s] +05/11/23 20:16:38| INFO Dataset sample 0.50 of dataset imdb_3prevs started +05/11/23 20:16:47| INFO ref finished [took 8.4368s] +05/11/23 20:16:50| INFO atc_mc finished [took 11.4889s] +05/11/23 20:16:58| INFO mulmc_sld finished [took 19.6803s] +05/11/23 20:16:58| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.2091s] +05/11/23 20:16:58| INFO Dataset sample 0.80 of dataset imdb_3prevs started +05/11/23 20:17:08| INFO ref finished [took 8.9281s] +05/11/23 20:17:11| INFO atc_mc finished [took 11.9333s] +05/11/23 20:17:17| INFO mulmc_sld finished [took 18.2367s] +05/11/23 20:17:17| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.8309s] diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index 8b74071..7421a2b 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -182,22 +182,21 @@ class CompReport: train_prev=self.train_prev, ) elif mode == "shift": - shift_data = ( - self.shift_data(metric=metric, estimators=estimators) - .groupby(level=0) - .mean() - ) + _shift_data = self.shift_data(metric=metric, estimators=estimators) + shift_avg = _shift_data.groupby(level=0).mean() + shift_counts = _shift_data.groupby(level=0).count() shift_prevs = np.around( - [(1.0 - p, p) for p in np.sort(shift_data.index.unique(0))], + [(1.0 - p, p) for p in np.sort(shift_avg.index.unique(0))], decimals=2, ) return plot.plot_shift( shift_prevs=shift_prevs, - columns=shift_data.columns.to_numpy(), - data=shift_data.T.to_numpy(), + columns=shift_avg.columns.to_numpy(), + data=shift_avg.T.to_numpy(), metric=metric, name=conf, train_prev=self.train_prev, + counts=shift_counts.T.to_numpy(), ) def to_md(self, conf="default", metric="acc", estimators=None, stdev=False) -> str: @@ -374,6 +373,7 @@ class DatasetReport: res += "### avg dataset shift\n" avg_shift = _shift_data.groupby(level=0).mean() + count_shift = _shift_data.groupby(level=0).count() prevs_shift = np.sort(avg_shift.index.unique(0)) shift_op = plot.plot_shift( @@ -383,6 +383,7 @@ class DatasetReport: metric=metric, name=conf, train_prev=None, + counts=count_shift.T.to_numpy(), ) res += f"![plot_shift]({shift_op.relative_to(env.OUT_DIR).as_posix()})\n" diff --git a/quacc/plot.py b/quacc/plot.py index f34fb04..4ccb3fe 100644 --- a/quacc/plot.py +++ b/quacc/plot.py @@ -27,7 +27,6 @@ def plot_delta( metric="acc", name="default", train_prev=None, - fit_scores=None, legend=True, avg=None, ) -> Path: @@ -75,14 +74,6 @@ def plot_delta( color=_cy["color"], alpha=0.25, ) - if fit_scores is not None and method in fit_scores: - ax.plot( - base_prevs, - np.repeat(fit_scores[method], base_prevs.shape[0]), - color=_cy["color"], - linestyle="--", - markersize=0, - ) x_label = "test" if avg is None or avg == "train" else "train" ax.set( @@ -188,11 +179,11 @@ def plot_shift( columns, data, *, + counts=None, pos_class=1, metric="acc", name="default", train_prev=None, - fit_scores=None, legend=True, ) -> Path: if train_prev is not None: @@ -223,15 +214,20 @@ def plot_shift( markersize=3, zorder=2, ) - - if fit_scores is not None and method in fit_scores: - ax.plot( - shift_prevs, - np.repeat(fit_scores[method], shift_prevs.shape[0]), - color=_cy["color"], - linestyle="--", - markersize=0, - ) + if counts is not None: + _col_idx = np.where(columns == method)[0] + count = counts[_col_idx].flatten() + for prev, shift, cnt in zip(shift_prevs, shifts, count): + label = f"{cnt}" + plt.annotate( + label, + (prev, shift), + textcoords="offset points", + xytext=(0, 10), + ha="center", + color=_cy["color"], + fontsize=12.0, + ) ax.set(xlabel="dataset shift", ylabel=metric, title=title)