diff --git a/TODO.html b/TODO.html
index 31b1c20..2eaf501 100644
--- a/TODO.html
+++ b/TODO.html
@@ -103,15 +103,6 @@ verbose=True).fit(V_tr)
import baselines
- plot avg con train prevalence sull'asse x e media su test prevalecne
-
-
- realizzare grid search per task specifico partendo da GridSearchQ
-
-
- provare PACC come quantificatore
-
-
importare mandoline
- mandoline può essere importato, ma richiedere uno slicing delle features a priori che devere essere realizzato ad hoc
@@ -124,6 +115,15 @@ verbose=True).fit(V_tr)
+ plot avg con train prevalence sull'asse x e media su test prevalecne
+
+
+ realizzare grid search per task specifico partendo da GridSearchQ
+
+
+ provare PACC come quantificatore
+
+
aggiungere etichette in shift plot
diff --git a/TODO.md b/TODO.md
index 028e10e..f78d3ad 100644
--- a/TODO.md
+++ b/TODO.md
@@ -30,13 +30,13 @@
- nel caso di bin fare media dei due best score
- [x] import baselines
-- [x] plot avg con train prevalence sull'asse x e media su test prevalecne
-- [x] realizzare grid search per task specifico partendo da GridSearchQ
-- [x] provare PACC come quantificatore
- [ ] importare mandoline
- mandoline può essere importato, ma richiedere uno slicing delle features a priori che devere essere realizzato ad hoc
- [ ] sistemare vecchie iw baselines
- non possono essere fixate perché dipendono da numpy
+- [x] plot avg con train prevalence sull'asse x e media su test prevalecne
+- [x] realizzare grid search per task specifico partendo da GridSearchQ
+- [x] provare PACC come quantificatore
- [ ] aggiungere etichette in shift plot
- [ ] sistemare exact_train quapy
- [ ] testare anche su imbd
\ No newline at end of file
diff --git a/conf.yaml b/conf.yaml
index 73fe841..073e0e7 100644
--- a/conf.yaml
+++ b/conf.yaml
@@ -151,4 +151,4 @@ main_conf: &main_conf
- atc_ne
- doc_feat
-exec: *mc_conf
\ No newline at end of file
+exec: *debug_conf
\ No newline at end of file
diff --git a/quacc.log b/quacc.log
index c45692f..62b3787 100644
--- a/quacc.log
+++ b/quacc.log
@@ -2850,3 +2850,154 @@
05/11/23 14:16:15| INFO atc_mc finished [took 49.6779s]
05/11/23 14:16:19| INFO mulmc_sld finished [took 61.0610s]
05/11/23 14:16:22| INFO mulne_sld finished [took 62.2089s]
+05/11/23 14:19:02| INFO binmc_sld finished [took 225.5737s]
+05/11/23 14:19:03| INFO binne_sld finished [took 223.9017s]
+05/11/23 14:28:50| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00756) [took 806.7930s]
+05/11/23 14:29:32| INFO mul_sld_gs finished [took 848.7630s]
+05/11/23 14:36:02| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00781) [took 1240.9138s]
+05/11/23 14:39:04| INFO bin_sld_gs finished [took 1422.5520s]
+05/11/23 14:39:04| INFO Dataset sample 0.30 of dataset rcv1_CCAT_9prevs finished [took 1428.8824s]
+05/11/23 14:39:04| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs started
+05/11/23 14:39:58| INFO ref finished [took 45.7514s]
+05/11/23 14:40:02| INFO atc_mc finished [took 48.3888s]
+05/11/23 14:40:05| INFO mulmc_sld finished [took 59.0537s]
+05/11/23 14:40:09| INFO mulne_sld finished [took 60.9189s]
+05/11/23 14:42:42| INFO binne_sld finished [took 214.5464s]
+05/11/23 14:42:44| INFO binmc_sld finished [took 218.8429s]
+05/11/23 14:52:23| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1000.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00984) [took 792.5474s]
+05/11/23 14:53:05| INFO mul_sld_gs finished [took 834.1824s]
+05/11/23 14:59:56| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.01112) [took 1247.0092s]
+05/11/23 15:02:57| INFO bin_sld_gs finished [took 1427.5051s]
+05/11/23 15:02:57| INFO Dataset sample 0.40 of dataset rcv1_CCAT_9prevs finished [took 1432.9172s]
+05/11/23 15:02:57| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs started
+05/11/23 15:03:49| INFO ref finished [took 44.4148s]
+05/11/23 15:03:54| INFO atc_mc finished [took 47.7566s]
+05/11/23 15:04:00| INFO mulmc_sld finished [took 60.5480s]
+05/11/23 15:04:03| INFO mulne_sld finished [took 61.2226s]
+05/11/23 15:06:30| INFO binmc_sld finished [took 211.9647s]
+05/11/23 15:06:32| INFO binne_sld finished [took 211.4312s]
+05/11/23 15:16:00| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'max_conf'} (score=0.00571) [took 776.6085s]
+05/11/23 15:16:42| INFO mul_sld_gs finished [took 817.9358s]
+05/11/23 15:23:24| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': 'vs', 'confidence': 'entropy'} (score=0.00653) [took 1221.6531s]
+05/11/23 15:26:23| INFO bin_sld_gs finished [took 1400.9688s]
+05/11/23 15:26:23| INFO Dataset sample 0.50 of dataset rcv1_CCAT_9prevs finished [took 1406.4620s]
+05/11/23 15:26:23| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs started
+05/11/23 15:27:16| INFO ref finished [took 44.3988s]
+05/11/23 15:27:21| INFO atc_mc finished [took 48.5589s]
+05/11/23 15:27:27| INFO mulmc_sld finished [took 61.4269s]
+05/11/23 15:27:29| INFO mulne_sld finished [took 61.8292s]
+05/11/23 15:29:55| INFO binmc_sld finished [took 210.1585s]
+05/11/23 15:29:59| INFO binne_sld finished [took 212.0930s]
+05/11/23 15:39:22| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'max_conf'} (score=0.00616) [took 771.6071s]
+05/11/23 15:40:03| INFO mul_sld_gs finished [took 813.2905s]
+05/11/23 15:47:04| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': None, 'confidence': None} (score=0.00544) [took 1234.9832s]
+05/11/23 15:50:10| INFO bin_sld_gs finished [took 1421.7775s]
+05/11/23 15:50:10| INFO Dataset sample 0.60 of dataset rcv1_CCAT_9prevs finished [took 1427.0062s]
+05/11/23 15:50:10| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs started
+05/11/23 15:51:11| INFO ref finished [took 49.7682s]
+05/11/23 15:51:19| INFO atc_mc finished [took 54.2855s]
+05/11/23 15:51:22| INFO mulmc_sld finished [took 68.7688s]
+05/11/23 15:51:26| INFO mulne_sld finished [took 69.3711s]
+05/11/23 15:54:07| INFO binmc_sld finished [took 234.7962s]
+05/11/23 15:54:09| INFO binne_sld finished [took 234.6444s]
+05/11/23 16:03:51| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 0.1, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'bcts', 'confidence': 'entropy'} (score=0.00765) [took 811.6704s]
+05/11/23 16:04:34| INFO mul_sld_gs finished [took 854.8196s]
+05/11/23 16:11:10| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 0.1, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': 'max_conf'} (score=0.01234) [took 1252.4784s]
+05/11/23 16:14:10| INFO bin_sld_gs finished [took 1431.7446s]
+05/11/23 16:14:10| INFO Dataset sample 0.70 of dataset rcv1_CCAT_9prevs finished [took 1439.1145s]
+05/11/23 16:14:10| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs started
+05/11/23 16:15:02| INFO ref finished [took 44.0970s]
+05/11/23 16:15:07| INFO atc_mc finished [took 48.2871s]
+05/11/23 16:15:13| INFO mulmc_sld finished [took 61.0461s]
+05/11/23 16:15:15| INFO mulne_sld finished [took 60.6375s]
+05/11/23 16:17:46| INFO binmc_sld finished [took 215.1734s]
+05/11/23 16:17:49| INFO binne_sld finished [took 215.7846s]
+05/11/23 16:27:15| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': 'vs', 'confidence': None} (score=0.00822) [took 778.5688s]
+05/11/23 16:27:56| INFO mul_sld_gs finished [took 819.2615s]
+05/11/23 16:34:16| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 1.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'entropy'} (score=0.00894) [took 1200.6639s]
+05/11/23 16:37:21| INFO bin_sld_gs finished [took 1385.9035s]
+05/11/23 16:37:21| INFO Dataset sample 0.80 of dataset rcv1_CCAT_9prevs finished [took 1391.5055s]
+05/11/23 16:37:21| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs started
+05/11/23 16:38:13| INFO ref finished [took 44.7046s]
+05/11/23 16:38:18| INFO atc_mc finished [took 48.7802s]
+05/11/23 16:38:21| INFO mulmc_sld finished [took 57.4163s]
+05/11/23 16:38:24| INFO mulne_sld finished [took 58.9847s]
+05/11/23 16:40:59| INFO binmc_sld finished [took 216.7311s]
+05/11/23 16:41:01| INFO binne_sld finished [took 216.5312s]
+05/11/23 16:50:06| DEBUG [MultiClassAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 100.0, 'quantifier__classifier__class_weight': 'balanced', 'quantifier__recalib': None, 'confidence': 'max_conf'} (score=0.00808) [took 758.6896s]
+05/11/23 16:50:46| INFO mul_sld_gs finished [took 798.8038s]
+05/11/23 16:56:41| DEBUG [BinaryQuantifierAccuracyEstimator] optimization finished: best params {'quantifier__classifier__C': 10.0, 'quantifier__classifier__class_weight': None, 'quantifier__recalib': None, 'confidence': 'entropy'} (score=0.00604) [took 1154.7043s]
+05/11/23 16:59:39| INFO bin_sld_gs finished [took 1332.5521s]
+05/11/23 16:59:39| INFO Dataset sample 0.90 of dataset rcv1_CCAT_9prevs finished [took 1337.7947s]
+----------------------------------------------------------------------------------------------------
+05/11/23 20:08:46| ERROR estimate comparison failed. Exceprion: 'environ' object has no attribute 'OUT_PATH'
+----------------------------------------------------------------------------------------------------
+05/11/23 20:09:08| ERROR estimate comparison failed. Exceprion: 'environ' object has no attribute 'OUT_PATH'
+----------------------------------------------------------------------------------------------------
+05/11/23 20:09:27| INFO dataset imdb_3prevs
+05/11/23 20:09:34| INFO Dataset sample 0.20 of dataset imdb_3prevs started
+05/11/23 20:09:44| INFO ref finished [took 8.9550s]
+05/11/23 20:09:47| INFO atc_mc finished [took 11.8923s]
+05/11/23 20:09:56| INFO mulmc_sld finished [took 21.3196s]
+05/11/23 20:09:56| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 21.7709s]
+05/11/23 20:09:56| INFO Dataset sample 0.50 of dataset imdb_3prevs started
+05/11/23 20:10:05| INFO ref finished [took 8.6116s]
+05/11/23 20:10:08| INFO atc_mc finished [took 11.6880s]
+05/11/23 20:10:16| INFO mulmc_sld finished [took 19.7793s]
+05/11/23 20:10:16| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.3246s]
+05/11/23 20:10:16| INFO Dataset sample 0.80 of dataset imdb_3prevs started
+05/11/23 20:10:26| INFO ref finished [took 8.6654s]
+05/11/23 20:10:29| INFO atc_mc finished [took 11.6975s]
+05/11/23 20:10:35| INFO mulmc_sld finished [took 18.1478s]
+05/11/23 20:10:35| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.7200s]
+----------------------------------------------------------------------------------------------------
+05/11/23 20:11:42| INFO dataset imdb_3prevs
+05/11/23 20:11:49| INFO Dataset sample 0.20 of dataset imdb_3prevs started
+05/11/23 20:11:58| INFO ref finished [took 8.7146s]
+05/11/23 20:12:02| INFO atc_mc finished [took 11.9672s]
+05/11/23 20:12:10| INFO mulmc_sld finished [took 20.7824s]
+05/11/23 20:12:10| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 21.2293s]
+05/11/23 20:12:10| INFO Dataset sample 0.50 of dataset imdb_3prevs started
+05/11/23 20:12:19| INFO ref finished [took 8.5867s]
+05/11/23 20:12:23| INFO atc_mc finished [took 11.6542s]
+05/11/23 20:12:30| INFO mulmc_sld finished [took 19.6709s]
+05/11/23 20:12:30| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.1802s]
+05/11/23 20:12:30| INFO Dataset sample 0.80 of dataset imdb_3prevs started
+05/11/23 20:12:40| INFO ref finished [took 8.7231s]
+05/11/23 20:12:43| INFO atc_mc finished [took 11.8244s]
+05/11/23 20:12:49| INFO mulmc_sld finished [took 18.0420s]
+05/11/23 20:12:49| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.6102s]
+----------------------------------------------------------------------------------------------------
+05/11/23 20:14:32| INFO dataset imdb_3prevs
+05/11/23 20:14:39| INFO Dataset sample 0.20 of dataset imdb_3prevs started
+05/11/23 20:14:48| INFO ref finished [took 8.6247s]
+05/11/23 20:14:51| INFO atc_mc finished [took 11.6363s]
+05/11/23 20:15:00| INFO mulmc_sld finished [took 20.4634s]
+05/11/23 20:15:00| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 20.9026s]
+05/11/23 20:15:00| INFO Dataset sample 0.50 of dataset imdb_3prevs started
+05/11/23 20:15:09| INFO ref finished [took 8.5219s]
+05/11/23 20:15:12| INFO atc_mc finished [took 11.6739s]
+05/11/23 20:15:20| INFO mulmc_sld finished [took 19.8454s]
+05/11/23 20:15:20| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.3705s]
+05/11/23 20:15:20| INFO Dataset sample 0.80 of dataset imdb_3prevs started
+05/11/23 20:15:29| INFO ref finished [took 8.5948s]
+05/11/23 20:15:32| INFO atc_mc finished [took 11.7465s]
+05/11/23 20:15:39| INFO mulmc_sld finished [took 17.9276s]
+05/11/23 20:15:39| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.4893s]
+----------------------------------------------------------------------------------------------------
+05/11/23 20:16:10| INFO dataset imdb_3prevs
+05/11/23 20:16:17| INFO Dataset sample 0.20 of dataset imdb_3prevs started
+05/11/23 20:16:26| INFO ref finished [took 8.3736s]
+05/11/23 20:16:29| INFO atc_mc finished [took 11.3995s]
+05/11/23 20:16:38| INFO mulmc_sld finished [took 20.4916s]
+05/11/23 20:16:38| INFO Dataset sample 0.20 of dataset imdb_3prevs finished [took 20.9187s]
+05/11/23 20:16:38| INFO Dataset sample 0.50 of dataset imdb_3prevs started
+05/11/23 20:16:47| INFO ref finished [took 8.4368s]
+05/11/23 20:16:50| INFO atc_mc finished [took 11.4889s]
+05/11/23 20:16:58| INFO mulmc_sld finished [took 19.6803s]
+05/11/23 20:16:58| INFO Dataset sample 0.50 of dataset imdb_3prevs finished [took 20.2091s]
+05/11/23 20:16:58| INFO Dataset sample 0.80 of dataset imdb_3prevs started
+05/11/23 20:17:08| INFO ref finished [took 8.9281s]
+05/11/23 20:17:11| INFO atc_mc finished [took 11.9333s]
+05/11/23 20:17:17| INFO mulmc_sld finished [took 18.2367s]
+05/11/23 20:17:17| INFO Dataset sample 0.80 of dataset imdb_3prevs finished [took 18.8309s]
diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py
index 8b74071..7421a2b 100644
--- a/quacc/evaluation/report.py
+++ b/quacc/evaluation/report.py
@@ -182,22 +182,21 @@ class CompReport:
train_prev=self.train_prev,
)
elif mode == "shift":
- shift_data = (
- self.shift_data(metric=metric, estimators=estimators)
- .groupby(level=0)
- .mean()
- )
+ _shift_data = self.shift_data(metric=metric, estimators=estimators)
+ shift_avg = _shift_data.groupby(level=0).mean()
+ shift_counts = _shift_data.groupby(level=0).count()
shift_prevs = np.around(
- [(1.0 - p, p) for p in np.sort(shift_data.index.unique(0))],
+ [(1.0 - p, p) for p in np.sort(shift_avg.index.unique(0))],
decimals=2,
)
return plot.plot_shift(
shift_prevs=shift_prevs,
- columns=shift_data.columns.to_numpy(),
- data=shift_data.T.to_numpy(),
+ columns=shift_avg.columns.to_numpy(),
+ data=shift_avg.T.to_numpy(),
metric=metric,
name=conf,
train_prev=self.train_prev,
+ counts=shift_counts.T.to_numpy(),
)
def to_md(self, conf="default", metric="acc", estimators=None, stdev=False) -> str:
@@ -374,6 +373,7 @@ class DatasetReport:
res += "### avg dataset shift\n"
avg_shift = _shift_data.groupby(level=0).mean()
+ count_shift = _shift_data.groupby(level=0).count()
prevs_shift = np.sort(avg_shift.index.unique(0))
shift_op = plot.plot_shift(
@@ -383,6 +383,7 @@ class DatasetReport:
metric=metric,
name=conf,
train_prev=None,
+ counts=count_shift.T.to_numpy(),
)
res += f"![plot_shift]({shift_op.relative_to(env.OUT_DIR).as_posix()})\n"
diff --git a/quacc/plot.py b/quacc/plot.py
index f34fb04..4ccb3fe 100644
--- a/quacc/plot.py
+++ b/quacc/plot.py
@@ -27,7 +27,6 @@ def plot_delta(
metric="acc",
name="default",
train_prev=None,
- fit_scores=None,
legend=True,
avg=None,
) -> Path:
@@ -75,14 +74,6 @@ def plot_delta(
color=_cy["color"],
alpha=0.25,
)
- if fit_scores is not None and method in fit_scores:
- ax.plot(
- base_prevs,
- np.repeat(fit_scores[method], base_prevs.shape[0]),
- color=_cy["color"],
- linestyle="--",
- markersize=0,
- )
x_label = "test" if avg is None or avg == "train" else "train"
ax.set(
@@ -188,11 +179,11 @@ def plot_shift(
columns,
data,
*,
+ counts=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
- fit_scores=None,
legend=True,
) -> Path:
if train_prev is not None:
@@ -223,15 +214,20 @@ def plot_shift(
markersize=3,
zorder=2,
)
-
- if fit_scores is not None and method in fit_scores:
- ax.plot(
- shift_prevs,
- np.repeat(fit_scores[method], shift_prevs.shape[0]),
- color=_cy["color"],
- linestyle="--",
- markersize=0,
- )
+ if counts is not None:
+ _col_idx = np.where(columns == method)[0]
+ count = counts[_col_idx].flatten()
+ for prev, shift, cnt in zip(shift_prevs, shifts, count):
+ label = f"{cnt}"
+ plt.annotate(
+ label,
+ (prev, shift),
+ textcoords="offset points",
+ xytext=(0, 10),
+ ha="center",
+ color=_cy["color"],
+ fontsize=12.0,
+ )
ax.set(xlabel="dataset shift", ylabel=metric, title=title)