Merge pull request #4 from lorenzovolpi/grid_search

Grid search implemented for AccuracyEstimator
This commit is contained in:
Lorenzo Volpi 2023-11-07 08:28:32 +01:00 committed by GitHub
commit c52dc6498f
50 changed files with 8376 additions and 626 deletions

15
.gitignore vendored
View File

@ -1,17 +1,20 @@
*.code-workspace
quavenv/*
*.pdf
__pycache__/*
baselines/__pycache__/*
baselines/densratio/__pycache__/*
quacc/__pycache__/*
quacc/evaluation/__pycache__/*
quacc/method/__pycache__/*
tests/__pycache__/*
garg22_ATC/__pycache__/*
guillory21_doc/__pycache__/*
jiang18_trustscore/__pycache__/*
lipton_bbse/__pycache__/*
elsahar19_rca/__pycache__/*
*.coverage
.coverage
scp_sync.py
out/*
output/*
*.log
!output/main/

9
.vscode/launch.json vendored
View File

@ -5,7 +5,6 @@
"version": "0.2.0",
"configurations": [
{
"name": "main",
"type": "python",
@ -15,12 +14,12 @@
"justMyCode": true
},
{
"name": "models",
"name": "main_test",
"type": "python",
"request": "launch",
"program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\baselines\\models.py",
"program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\quacc\\main_test.py",
"console": "integratedTerminal",
"justMyCode": true
}
"justMyCode": false
},
]
}

View File

@ -1,14 +1,5 @@
{
"todo": [
{
"assignedTo": {
"name": "Lorenzo Volpi"
},
"creation_time": "2023-10-28T14:34:46.226Z",
"id": "4",
"references": [],
"title": "Aggingere estimator basati su PACC (quantificatore)"
},
{
"assignedTo": {
"name": "Lorenzo Volpi"
@ -18,15 +9,6 @@
"references": [],
"title": "Creare plot avg con training prevalence sull'asse x e media rispetto a test prevalence"
},
{
"assignedTo": {
"name": "Lorenzo Volpi"
},
"creation_time": "2023-10-28T14:34:23.217Z",
"id": "3",
"references": [],
"title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
},
{
"assignedTo": {
"name": "Lorenzo Volpi"
@ -38,6 +20,27 @@
}
],
"in-progress": [
{
"assignedTo": {
"name": "Lorenzo Volpi"
},
"creation_time": "2023-10-28T14:34:23.217Z",
"id": "3",
"references": [],
"title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
},
{
"assignedTo": {
"name": "Lorenzo Volpi"
},
"creation_time": "2023-10-28T14:34:46.226Z",
"id": "4",
"references": [],
"title": "Aggingere estimator basati su PACC (quantificatore)"
}
],
"testing": [],
"done": [
{
"assignedTo": {
"name": "Lorenzo Volpi"
@ -47,7 +50,5 @@
"references": [],
"title": "Rework rappresentazione dati di report"
}
],
"testing": [],
"done": []
]
}

View File

@ -103,17 +103,35 @@ verbose=True).fit(V_tr)</li>
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> import baselines</p>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox"type="checkbox"> importare mandoline</p>
<ul>
<li>mandoline può essere importato, ma richiedere uno slicing delle features a priori che devere essere realizzato ad hoc</li>
</ul>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox"type="checkbox"> sistemare vecchie iw baselines</p>
<ul>
<li>non possono essere fixate perché dipendono da numpy</li>
</ul>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> plot avg con train prevalence sull'asse x e media su test prevalecne</p>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> realizzare grid search per task specifico partendo da GridSearchQ</p>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> provare PACC come quantificatore</p>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> aggiungere etichette in shift plot</p>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox" checked=""type="checkbox"> sistemare exact_train quapy</p>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox"type="checkbox"> testare anche su imbd</p>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox"type="checkbox"> plot avg con train prevalence sull'asse x e media su test prevalecne</p>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox"type="checkbox"> realizzare grid search per task specifico partendo da GridSearchQ</p>
</li>
<li class="task-list-item enabled">
<p><input class="task-list-item-checkbox"type="checkbox"> provare PACC come quantificatore</p>
</li>
</ul>

14
TODO.md
View File

@ -30,7 +30,13 @@
- nel caso di bin fare media dei due best score
- [x] import baselines
- [ ] testare anche su imbd
- [ ] plot avg con train prevalence sull'asse x e media su test prevalecne
- [ ] realizzare grid search per task specifico partendo da GridSearchQ
- [ ] provare PACC come quantificatore
- [ ] importare mandoline
- mandoline può essere importato, ma richiedere uno slicing delle features a priori che devere essere realizzato ad hoc
- [ ] sistemare vecchie iw baselines
- non possono essere fixate perché dipendono da numpy
- [x] plot avg con train prevalence sull'asse x e media su test prevalecne
- [x] realizzare grid search per task specifico partendo da GridSearchQ
- [x] provare PACC come quantificatore
- [x] aggiungere etichette in shift plot
- [x] sistemare exact_train quapy
- [ ] testare anche su imbd

View File

@ -1,41 +1,44 @@
import numpy as np
import numpy as np
from sklearn.metrics import f1_score
def get_entropy(probs):
return np.sum( np.multiply(probs, np.log(probs + 1e-20)) , axis=1)
def get_entropy(probs):
return np.sum(np.multiply(probs, np.log(probs + 1e-20)), axis=1)
def get_max_conf(probs):
return np.max(probs, axis=-1)
def find_ATC_threshold(scores, labels):
return np.max(probs, axis=-1)
def find_ATC_threshold(scores, labels):
sorted_idx = np.argsort(scores)
sorted_scores = scores[sorted_idx]
sorted_labels = labels[sorted_idx]
fp = np.sum(labels==0)
fp = np.sum(labels == 0)
fn = 0.0
min_fp_fn = np.abs(fp - fn)
thres = 0.0
for i in range(len(labels)):
if sorted_labels[i] == 0:
for i in range(len(labels)):
if sorted_labels[i] == 0:
fp -= 1
else:
else:
fn += 1
if np.abs(fp - fn) < min_fp_fn:
if np.abs(fp - fn) < min_fp_fn:
min_fp_fn = np.abs(fp - fn)
thres = sorted_scores[i]
return min_fp_fn, thres
def get_ATC_acc(thres, scores):
return np.mean(scores>=thres)
def get_ATC_acc(thres, scores):
return np.mean(scores >= thres)
def get_ATC_f1(thres, scores, probs):
preds = np.argmax(probs, axis=-1)
estim_y = abs(1 - (scores>=thres)^preds)
estim_y = np.abs(1 - (scores >= thres) ^ preds)
return f1_score(estim_y, preds)

View File

@ -4,6 +4,20 @@ from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity
from baselines import densratio
from baselines.pykliep import DensityRatioEstimator
def kliep(Xtr, ytr, Xte):
kliep = DensityRatioEstimator()
kliep.fit(Xtr, Xte)
return kliep.predict(Xtr)
def usilf(Xtr, ytr, Xte, alpha=0.0):
dense_ratio_obj = densratio(Xtr, Xte, alpha=alpha, verbose=False)
return dense_ratio_obj.compute_density_ratio(Xtr)
def logreg(Xtr, ytr, Xte):
# check "Direct Density Ratio Estimation for

View File

@ -123,7 +123,7 @@ if __name__ == "__main__":
results = []
for sample in protocol():
wx = iw.logreg(d.validation.X, d.validation.y, sample.X)
wx = iw.kliep(d.validation.X, d.validation.y, sample.X)
test_preds = lr.predict(sample.X)
estim_acc = np.sum((1.0 * (val_preds == d.validation.y)) * wx) / np.sum(wx)
true_acc = metrics.accuracy_score(sample.y, test_preds)

View File

@ -74,7 +74,9 @@ class DensityRatioEstimator:
# X_test_shuffled = X_test.copy()
X_test_shuffled = X_test.copy()
np.random.shuffle(X_test_shuffled)
X_test_index = np.arange(X_test_shuffled.shape[0])
np.random.shuffle(X_test_index)
X_test_shuffled = X_test_shuffled[X_test_index, :]
j_scores = {}

193
conf.yaml
View File

@ -4,56 +4,46 @@ debug_conf: &debug_conf
- acc
DATASET_N_PREVS: 5
DATASET_PREVS:
# - 0.2
- 0.5
- 0.1
# - 0.8
confs:
- DATASET_NAME: imdb
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
plot_confs:
debug:
PLOT_ESTIMATORS:
- ref
- mulmc_sld
- atc_mc
- atc_ne
PLOT_STDEV: true
debug_plus:
PLOT_ESTIMATORS:
- mul_sld_bcts
- mul_sld
- ref
- atc_mc
- atc_ne
test_conf: &test_conf
mc_conf: &mc_conf
global:
METRICS:
- acc
- f1
DATASET_N_PREVS: 2
DATASET_PREVS:
- 0.5
- 0.1
DATASET_N_PREVS: 9
DATASET_DIR_UPDATE: true
confs:
# - DATASET_NAME: rcv1
# DATASET_TARGET: CCAT
- DATASET_NAME: imdb
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
# - DATASET_NAME: imdb
plot_confs:
best_vs_atc:
debug3:
PLOT_ESTIMATORS:
- bin_sld
- bin_sld_bcts
- bin_sld_gs
- mul_sld
- mul_sld_bcts
- mul_sld_gs
- ref
- atc_mc
- atc_ne
- binmc_sld
- mulmc_sld
- binne_sld
- mulne_sld
- bin_sld_gs
- mul_sld_gs
- atc_mc
PLOT_STDEV: true
main_conf: &main_conf
test_conf: &test_conf
global:
METRICS:
- acc
@ -63,29 +53,160 @@ main_conf: &main_conf
confs:
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
confs_bck:
# - DATASET_NAME: imdb
plot_confs:
gs_vs_gsq:
PLOT_ESTIMATORS:
- bin_sld
- bin_sld_gs
- bin_sld_gsq
- mul_sld
- mul_sld_gs
- mul_sld_gsq
gs_vs_atc:
PLOT_ESTIMATORS:
- bin_sld
- bin_sld_gs
- mul_sld
- mul_sld_gs
- atc_mc
- atc_ne
sld_vs_pacc:
PLOT_ESTIMATORS:
- bin_sld
- bin_sld_gs
- mul_sld
- mul_sld_gs
- bin_pacc
- bin_pacc_gs
- mul_pacc
- mul_pacc_gs
- atc_mc
- atc_ne
pacc_vs_atc:
PLOT_ESTIMATORS:
- bin_pacc
- bin_pacc_gs
- mul_pacc
- mul_pacc_gs
- atc_mc
- atc_ne
main_conf: &main_conf
global:
METRICS:
- acc
- f1
DATASET_N_PREVS: 9
DATASET_DIR_UPDATE: true
confs:
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
- DATASET_NAME: imdb
confs_next:
- DATASET_NAME: rcv1
DATASET_TARGET: GCAT
- DATASET_NAME: rcv1
DATASET_TARGET: MCAT
plot_confs:
gs_vs_qgs:
PLOT_ESTIMATORS:
- mul_sld_gs
- bin_sld_gs
- mul_sld_gsq
- bin_sld_gsq
- atc_mc
- atc_ne
PLOT_STDEV: true
plot_confs_completed:
max_conf_vs_atc_pacc:
PLOT_ESTIMATORS:
- bin_pacc
- binmc_pacc
- mul_pacc
- mulmc_pacc
- atc_mc
PLOT_STDEV: true
max_conf_vs_entropy_pacc:
PLOT_ESTIMATORS:
- binmc_pacc
- binne_pacc
- mulmc_pacc
- mulne_pacc
- atc_mc
PLOT_STDEV: true
gs_vs_atc:
PLOT_ESTIMATORS:
- mul_sld_gs
- bin_sld_gs
- ref
- mul_pacc_gs
- bin_pacc_gs
- atc_mc
- atc_ne
PLOT_STDEV: true
gs_vs_all:
PLOT_ESTIMATORS:
- mul_sld_gs
- bin_sld_gs
- mul_pacc_gs
- bin_pacc_gs
- atc_mc
- doc_feat
- kfcv
PLOT_STDEV: true
gs_vs_qgs:
PLOT_ESTIMATORS:
- mul_sld_gs
- bin_sld_gs
- mul_sld_gsq
- bin_sld_gsq
- atc_mc
- atc_ne
PLOT_STDEV: true
cc_vs_other:
PLOT_ESTIMATORS:
- mul_cc
- bin_cc
- mul_sld
- bin_sld
- mul_pacc
- bin_pacc
PLOT_STDEV: true
max_conf_vs_atc:
PLOT_ESTIMATORS:
- bin_sld
- binmc_sld
- mul_sld
- mulmc_sld
- atc_mc
PLOT_STDEV: true
max_conf_vs_entropy:
PLOT_ESTIMATORS:
- binmc_sld
- binne_sld
- mulmc_sld
- mulne_sld
- atc_mc
PLOT_STDEV: true
sld_vs_pacc:
PLOT_ESTIMATORS:
- bin_sld
- mul_sld
- bin_pacc
- mul_pacc
- atc_mc
PLOT_STDEV: true
plot_confs_other:
best_vs_atc:
PLOT_ESTIMATORS:
- mul_sld_bcts
- mul_sld_gs
- bin_sld_bcts
- bin_sld_gs
- ref
- atc_mc
- atc_ne
all_vs_atc:
@ -96,7 +217,6 @@ main_conf: &main_conf
- mul_sld
- mul_sld_bcts
- mul_sld_gs
- ref
- atc_mc
- atc_ne
best_vs_all:
@ -105,10 +225,9 @@ main_conf: &main_conf
- bin_sld_gs
- mul_sld_bcts
- mul_sld_gs
- ref
- kfcv
- atc_mc
- atc_ne
- doc_feat
exec: *debug_conf
exec: *main_conf

217
poetry.lock generated
View File

@ -15,6 +15,104 @@ numpy = ">=1.9"
scikit-learn = ">=0.20.0"
scipy = ">=1.1.0"
[[package]]
name = "bcrypt"
version = "4.0.1"
description = "Modern password hashing for your software and your servers"
optional = false
python-versions = ">=3.6"
files = [
{file = "bcrypt-4.0.1-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:b1023030aec778185a6c16cf70f359cbb6e0c289fd564a7cfa29e727a1c38f8f"},
{file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:08d2947c490093a11416df18043c27abe3921558d2c03e2076ccb28a116cb6d0"},
{file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0eaa47d4661c326bfc9d08d16debbc4edf78778e6aaba29c1bc7ce67214d4410"},
{file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae88eca3024bb34bb3430f964beab71226e761f51b912de5133470b649d82344"},
{file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:a522427293d77e1c29e303fc282e2d71864579527a04ddcfda6d4f8396c6c36a"},
{file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:fbdaec13c5105f0c4e5c52614d04f0bca5f5af007910daa8b6b12095edaa67b3"},
{file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:ca3204d00d3cb2dfed07f2d74a25f12fc12f73e606fcaa6975d1f7ae69cacbb2"},
{file = "bcrypt-4.0.1-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:089098effa1bc35dc055366740a067a2fc76987e8ec75349eb9484061c54f535"},
{file = "bcrypt-4.0.1-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:e9a51bbfe7e9802b5f3508687758b564069ba937748ad7b9e890086290d2f79e"},
{file = "bcrypt-4.0.1-cp36-abi3-win32.whl", hash = "sha256:2caffdae059e06ac23fce178d31b4a702f2a3264c20bfb5ff541b338194d8fab"},
{file = "bcrypt-4.0.1-cp36-abi3-win_amd64.whl", hash = "sha256:8a68f4341daf7522fe8d73874de8906f3a339048ba406be6ddc1b3ccb16fc0d9"},
{file = "bcrypt-4.0.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf4fa8b2ca74381bb5442c089350f09a3f17797829d958fad058d6e44d9eb83c"},
{file = "bcrypt-4.0.1-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:67a97e1c405b24f19d08890e7ae0c4f7ce1e56a712a016746c8b2d7732d65d4b"},
{file = "bcrypt-4.0.1-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:b3b85202d95dd568efcb35b53936c5e3b3600c7cdcc6115ba461df3a8e89f38d"},
{file = "bcrypt-4.0.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbb03eec97496166b704ed663a53680ab57c5084b2fc98ef23291987b525cb7d"},
{file = "bcrypt-4.0.1-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:5ad4d32a28b80c5fa6671ccfb43676e8c1cc232887759d1cd7b6f56ea4355215"},
{file = "bcrypt-4.0.1-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:b57adba8a1444faf784394de3436233728a1ecaeb6e07e8c22c8848f179b893c"},
{file = "bcrypt-4.0.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:705b2cea8a9ed3d55b4491887ceadb0106acf7c6387699fca771af56b1cdeeda"},
{file = "bcrypt-4.0.1-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:2b3ac11cf45161628f1f3733263e63194f22664bf4d0c0f3ab34099c02134665"},
{file = "bcrypt-4.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3100851841186c25f127731b9fa11909ab7b1df6fc4b9f8353f4f1fd952fbf71"},
{file = "bcrypt-4.0.1.tar.gz", hash = "sha256:27d375903ac8261cfe4047f6709d16f7d18d39b1ec92aaf72af989552a650ebd"},
]
[package.extras]
tests = ["pytest (>=3.2.1,!=3.3.0)"]
typecheck = ["mypy"]
[[package]]
name = "cffi"
version = "1.16.0"
description = "Foreign Function Interface for Python calling C code."
optional = false
python-versions = ">=3.8"
files = [
{file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"},
{file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"},
{file = "cffi-1.16.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673"},
{file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896"},
{file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684"},
{file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7"},
{file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614"},
{file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743"},
{file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d"},
{file = "cffi-1.16.0-cp310-cp310-win32.whl", hash = "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a"},
{file = "cffi-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1"},
{file = "cffi-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b84834d0cf97e7d27dd5b7f3aca7b6e9263c56308ab9dc8aae9784abb774d404"},
{file = "cffi-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b8ebc27c014c59692bb2664c7d13ce7a6e9a629be20e54e7271fa696ff2b417"},
{file = "cffi-1.16.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ee07e47c12890ef248766a6e55bd38ebfb2bb8edd4142d56db91b21ea68b7627"},
{file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8a9d3ebe49f084ad71f9269834ceccbf398253c9fac910c4fd7053ff1386936"},
{file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e70f54f1796669ef691ca07d046cd81a29cb4deb1e5f942003f401c0c4a2695d"},
{file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5bf44d66cdf9e893637896c7faa22298baebcd18d1ddb6d2626a6e39793a1d56"},
{file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b78010e7b97fef4bee1e896df8a4bbb6712b7f05b7ef630f9d1da00f6444d2e"},
{file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c6a164aa47843fb1b01e941d385aab7215563bb8816d80ff3a363a9f8448a8dc"},
{file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e09f3ff613345df5e8c3667da1d918f9149bd623cd9070c983c013792a9a62eb"},
{file = "cffi-1.16.0-cp311-cp311-win32.whl", hash = "sha256:2c56b361916f390cd758a57f2e16233eb4f64bcbeee88a4881ea90fca14dc6ab"},
{file = "cffi-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:db8e577c19c0fda0beb7e0d4e09e0ba74b1e4c092e0e40bfa12fe05b6f6d75ba"},
{file = "cffi-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956"},
{file = "cffi-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e"},
{file = "cffi-1.16.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e"},
{file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2"},
{file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357"},
{file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6"},
{file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969"},
{file = "cffi-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520"},
{file = "cffi-1.16.0-cp312-cp312-win32.whl", hash = "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b"},
{file = "cffi-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235"},
{file = "cffi-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc"},
{file = "cffi-1.16.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a09582f178759ee8128d9270cd1344154fd473bb77d94ce0aeb2a93ebf0feaf0"},
{file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e760191dd42581e023a68b758769e2da259b5d52e3103c6060ddc02c9edb8d7b"},
{file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80876338e19c951fdfed6198e70bc88f1c9758b94578d5a7c4c91a87af3cf31c"},
{file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6a14b17d7e17fa0d207ac08642c8820f84f25ce17a442fd15e27ea18d67c59b"},
{file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6602bc8dc6f3a9e02b6c22c4fc1e47aa50f8f8e6d3f78a5e16ac33ef5fefa324"},
{file = "cffi-1.16.0-cp38-cp38-win32.whl", hash = "sha256:131fd094d1065b19540c3d72594260f118b231090295d8c34e19a7bbcf2e860a"},
{file = "cffi-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:31d13b0f99e0836b7ff893d37af07366ebc90b678b6664c955b54561fc36ef36"},
{file = "cffi-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:582215a0e9adbe0e379761260553ba11c58943e4bbe9c36430c4ca6ac74b15ed"},
{file = "cffi-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b29ebffcf550f9da55bec9e02ad430c992a87e5f512cd63388abb76f1036d8d2"},
{file = "cffi-1.16.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc9b18bf40cc75f66f40a7379f6a9513244fe33c0e8aa72e2d56b0196a7ef872"},
{file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cb4a35b3642fc5c005a6755a5d17c6c8b6bcb6981baf81cea8bfbc8903e8ba8"},
{file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b86851a328eedc692acf81fb05444bdf1891747c25af7529e39ddafaf68a4f3f"},
{file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0f31130ebc2d37cdd8e44605fb5fa7ad59049298b3f745c74fa74c62fbfcfc4"},
{file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098"},
{file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:748dcd1e3d3d7cd5443ef03ce8685043294ad6bd7c02a38d1bd367cfd968e000"},
{file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8895613bcc094d4a1b2dbe179d88d7fb4a15cee43c052e8885783fac397d91fe"},
{file = "cffi-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed86a35631f7bfbb28e108dd96773b9d5a6ce4811cf6ea468bb6a359b256b1e4"},
{file = "cffi-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:3686dffb02459559c74dd3d81748269ffb0eb027c39a6fc99502de37d501faa8"},
{file = "cffi-1.16.0.tar.gz", hash = "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0"},
]
[package.dependencies]
pycparser = "*"
[[package]]
name = "colorama"
version = "0.4.6"
@ -223,6 +321,51 @@ files = [
[package.extras]
toml = ["tomli"]
[[package]]
name = "cryptography"
version = "41.0.5"
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
optional = false
python-versions = ">=3.7"
files = [
{file = "cryptography-41.0.5-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:da6a0ff8f1016ccc7477e6339e1d50ce5f59b88905585f77193ebd5068f1e797"},
{file = "cryptography-41.0.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:b948e09fe5fb18517d99994184854ebd50b57248736fd4c720ad540560174ec5"},
{file = "cryptography-41.0.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d38e6031e113b7421db1de0c1b1f7739564a88f1684c6b89234fbf6c11b75147"},
{file = "cryptography-41.0.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e270c04f4d9b5671ebcc792b3ba5d4488bf7c42c3c241a3748e2599776f29696"},
{file = "cryptography-41.0.5-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ec3b055ff8f1dce8e6ef28f626e0972981475173d7973d63f271b29c8a2897da"},
{file = "cryptography-41.0.5-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:7d208c21e47940369accfc9e85f0de7693d9a5d843c2509b3846b2db170dfd20"},
{file = "cryptography-41.0.5-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:8254962e6ba1f4d2090c44daf50a547cd5f0bf446dc658a8e5f8156cae0d8548"},
{file = "cryptography-41.0.5-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:a48e74dad1fb349f3dc1d449ed88e0017d792997a7ad2ec9587ed17405667e6d"},
{file = "cryptography-41.0.5-cp37-abi3-win32.whl", hash = "sha256:d3977f0e276f6f5bf245c403156673db103283266601405376f075c849a0b936"},
{file = "cryptography-41.0.5-cp37-abi3-win_amd64.whl", hash = "sha256:73801ac9736741f220e20435f84ecec75ed70eda90f781a148f1bad546963d81"},
{file = "cryptography-41.0.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3be3ca726e1572517d2bef99a818378bbcf7d7799d5372a46c79c29eb8d166c1"},
{file = "cryptography-41.0.5-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:e886098619d3815e0ad5790c973afeee2c0e6e04b4da90b88e6bd06e2a0b1b72"},
{file = "cryptography-41.0.5-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:573eb7128cbca75f9157dcde974781209463ce56b5804983e11a1c462f0f4e88"},
{file = "cryptography-41.0.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:0c327cac00f082013c7c9fb6c46b7cc9fa3c288ca702c74773968173bda421bf"},
{file = "cryptography-41.0.5-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:227ec057cd32a41c6651701abc0328135e472ed450f47c2766f23267b792a88e"},
{file = "cryptography-41.0.5-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:22892cc830d8b2c89ea60148227631bb96a7da0c1b722f2aac8824b1b7c0b6b8"},
{file = "cryptography-41.0.5-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:5a70187954ba7292c7876734183e810b728b4f3965fbe571421cb2434d279179"},
{file = "cryptography-41.0.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:88417bff20162f635f24f849ab182b092697922088b477a7abd6664ddd82291d"},
{file = "cryptography-41.0.5-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c707f7afd813478e2019ae32a7c49cd932dd60ab2d2a93e796f68236b7e1fbf1"},
{file = "cryptography-41.0.5-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:580afc7b7216deeb87a098ef0674d6ee34ab55993140838b14c9b83312b37b86"},
{file = "cryptography-41.0.5-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:fba1e91467c65fe64a82c689dc6cf58151158993b13eb7a7f3f4b7f395636723"},
{file = "cryptography-41.0.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:0d2a6a598847c46e3e321a7aef8af1436f11c27f1254933746304ff014664d84"},
{file = "cryptography-41.0.5.tar.gz", hash = "sha256:392cb88b597247177172e02da6b7a63deeff1937fa6fec3bbf902ebd75d97ec7"},
]
[package.dependencies]
cffi = ">=1.12"
[package.extras]
docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"]
docstest = ["pyenchant (>=1.6.11)", "sphinxcontrib-spelling (>=4.0.1)", "twine (>=1.12.0)"]
nox = ["nox"]
pep8test = ["black", "check-sdist", "mypy", "ruff"]
sdist = ["build"]
ssh = ["bcrypt (>=3.1.5)"]
test = ["pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"]
test-randomorder = ["pytest-randomly"]
[[package]]
name = "cycler"
version = "0.11.0"
@ -728,6 +871,27 @@ sql-other = ["SQLAlchemy (>=1.4.36)"]
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.8.0)"]
[[package]]
name = "paramiko"
version = "3.3.1"
description = "SSH2 protocol library"
optional = false
python-versions = ">=3.6"
files = [
{file = "paramiko-3.3.1-py3-none-any.whl", hash = "sha256:b7bc5340a43de4287bbe22fe6de728aa2c22468b2a849615498dd944c2f275eb"},
{file = "paramiko-3.3.1.tar.gz", hash = "sha256:6a3777a961ac86dbef375c5f5b8d50014a1a96d0fd7f054a43bc880134b0ff77"},
]
[package.dependencies]
bcrypt = ">=3.2"
cryptography = ">=3.3"
pynacl = ">=1.5"
[package.extras]
all = ["gssapi (>=1.4.1)", "invoke (>=2.0)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"]
gssapi = ["gssapi (>=1.4.1)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"]
invoke = ["invoke (>=2.0)"]
[[package]]
name = "pillow"
version = "10.0.1"
@ -851,6 +1015,17 @@ files = [
[package.dependencies]
numpy = ">=1.16.6"
[[package]]
name = "pycparser"
version = "2.21"
description = "C parser in Python"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
{file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
]
[[package]]
name = "pylance"
version = "0.5.10"
@ -872,6 +1047,32 @@ pyarrow = ">=10"
[package.extras]
tests = ["duckdb", "ml_dtypes", "pandas (>=1.4)", "polars[pandas,pyarrow]", "pytest", "tensorflow"]
[[package]]
name = "pynacl"
version = "1.5.0"
description = "Python binding to the Networking and Cryptography (NaCl) library"
optional = false
python-versions = ">=3.6"
files = [
{file = "PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1"},
{file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:52cb72a79269189d4e0dc537556f4740f7f0a9ec41c1322598799b0bdad4ef92"},
{file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a36d4a9dda1f19ce6e03c9a784a2921a4b726b02e1c736600ca9c22029474394"},
{file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:0c84947a22519e013607c9be43706dd42513f9e6ae5d39d3613ca1e142fba44d"},
{file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06b8f6fa7f5de8d5d2f7573fe8c863c051225a27b61e6860fd047b1775807858"},
{file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a422368fc821589c228f4c49438a368831cb5bbc0eab5ebe1d7fac9dded6567b"},
{file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:61f642bf2378713e2c2e1de73444a3778e5f0a38be6fee0fe532fe30060282ff"},
{file = "PyNaCl-1.5.0-cp36-abi3-win32.whl", hash = "sha256:e46dae94e34b085175f8abb3b0aaa7da40767865ac82c928eeb9e57e1ea8a543"},
{file = "PyNaCl-1.5.0-cp36-abi3-win_amd64.whl", hash = "sha256:20f42270d27e1b6a29f54032090b972d97f0a1b0948cc52392041ef7831fee93"},
{file = "PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba"},
]
[package.dependencies]
cffi = ">=1.4.1"
[package.extras]
docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"]
tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
[[package]]
name = "pyparsing"
version = "3.1.1"
@ -1172,6 +1373,20 @@ files = [
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]
[[package]]
name = "tabulate"
version = "0.9.0"
description = "Pretty-print tabular data"
optional = false
python-versions = ">=3.7"
files = [
{file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
{file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
]
[package.extras]
widechars = ["wcwidth"]
[[package]]
name = "threadpoolctl"
version = "3.2.0"
@ -1271,4 +1486,4 @@ test = ["pytest", "pytest-cov"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "54d9922f6d48a46f554a6b350ce09d668a88755efb1fbf295f8f8a0a411bdef2"
content-hash = "8ce63f546a9858b6678a3eb3925d6f629cbf0c95e4ee8bdeeb3415ce184ffbc9"

View File

@ -25,6 +25,8 @@ pylance = "^0.5.9"
pytest-mock = "^3.11.1"
pytest-cov = "^4.1.0"
win11toast = "^0.32"
tabulate = "^0.9.0"
paramiko = "^3.3.1"
[tool.pytest.ini_options]
addopts = "--cov=quacc --capture=tee-sys"

4671
quacc.log Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,7 @@
import math
from typing import List, Optional
import numpy as np
import math
import scipy.sparse as sp
from quapy.data import LabelledCollection
@ -128,7 +128,9 @@ class ExtendedCollection(LabelledCollection):
@classmethod
def extend_collection(
cls, base: LabelledCollection, pred_proba: np.ndarray
cls,
base: LabelledCollection,
pred_proba: np.ndarray,
):
n_classes = base.n_classes
@ -136,13 +138,13 @@ class ExtendedCollection(LabelledCollection):
n_x = cls.extend_instances(base.X, pred_proba)
# n_y = (exptected y, predicted y)
pred = np.asarray([prob.argmax(axis=0) for prob in pred_proba])
pred_proba = pred_proba[:, -n_classes:]
preds = np.argmax(pred_proba, axis=-1)
n_y = np.asarray(
[
ExClassManager.get_ex(n_classes, true_class, pred_class)
for (true_class, pred_class) in zip(base.y, pred)
for (true_class, pred_class) in zip(base.y, preds)
]
)
return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])

View File

@ -71,18 +71,16 @@ class Dataset:
return all_train, test
def get_raw(self, validation=True) -> DatasetSample:
def get_raw(self) -> DatasetSample:
all_train, test = {
"spambase": self.__spambase,
"imdb": self.__imdb,
"rcv1": self.__rcv1,
}[self._name]()
train, val = all_train, None
if validation:
train, val = all_train.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=0
)
train, val = all_train.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=0
)
return DatasetSample(train, val, test)

View File

@ -1,13 +1,10 @@
import quapy as qp
import numpy as np
def from_name(err_name):
if err_name == "f1e":
return f1e
elif err_name == "f1":
return f1
else:
return qp.error.from_name(err_name)
assert err_name in ERROR_NAMES, f"unknown error {err_name}"
callable_error = globals()[err_name]
return callable_error
# def f1(prev):
@ -36,5 +33,23 @@ def f1e(prev):
return 1 - f1(prev)
def acc(prev):
return (prev[0] + prev[3]) / sum(prev)
def acc(prev: np.ndarray) -> float:
    """Classifier accuracy read off an extended 2x2 prevalence vector:
    cells 0 and 3 hold the correctly-classified mass (TN and TP)."""
    correct = prev[0] + prev[3]
    return correct / np.sum(prev)


def accd(true_prevs: np.ndarray, estim_prevs: np.ndarray) -> np.ndarray:
    """Per-sample absolute difference between true and estimated accuracy.

    Both inputs are arrays of extended prevalence vectors; the accuracy is
    computed row-wise on the last axis.
    """
    row_acc = np.vectorize(acc, signature="(m)->()")
    return np.abs(row_acc(true_prevs) - row_acc(estim_prevs))


def maccd(true_prevs: np.ndarray, estim_prevs: np.ndarray) -> float:
    """Mean absolute accuracy difference over all samples."""
    return accd(true_prevs, estim_prevs).mean()


# registries used by from_name(): aggregate metrics vs. per-sample metrics
ACCURACY_ERROR = {maccd}
ACCURACY_ERROR_SINGLE = {accd}
ACCURACY_ERROR_NAMES = {f.__name__ for f in ACCURACY_ERROR}
ACCURACY_ERROR_SINGLE_NAMES = {f.__name__ for f in ACCURACY_ERROR_SINGLE}
ERROR_NAMES = ACCURACY_ERROR_NAMES.union(ACCURACY_ERROR_SINGLE_NAMES)

View File

@ -1,192 +0,0 @@
import math
from abc import abstractmethod
import numpy as np
import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.aggregative import CC, SLD
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from quacc.data import ExtendedCollection
class AccuracyEstimator:
    """Base class for quantification-based classifier-accuracy estimators.

    Subclasses are expected to set ``self.c_model`` (the classifier under
    evaluation) and implement :meth:`fit` and :meth:`estimate`.
    """

    def __init__(self):
        # Best model-selection score; populated by subclasses when gs=True.
        self.fit_score = None

    def _gs_params(self, t_val: LabelledCollection):
        """Default GridSearchQ keyword arguments, validating on ``t_val``."""
        return {
            "param_grid": {
                "classifier__C": np.logspace(-3, 3, 7),
                "classifier__class_weight": [None, "balanced"],
                "recalib": [None, "bcts"],
            },
            "protocol": UPP(t_val, repeats=1000),
            "error": qp.error.mae,
            "refit": False,
            "timeout": -1,
            "n_jobs": None,
            "verbose": True,
        }

    def extend(self, base: LabelledCollection, pred_proba=None):
        """Extend ``base`` with classifier posteriors.

        :param base: collection to extend
        :param pred_proba: optional precomputed posteriors; computed with
            ``self.c_model.predict_proba`` when omitted
        :return: tuple ``(extended_collection, pred_proba)``
        """
        # FIX: the previous `if not pred_proba:` raises ValueError on a
        # non-empty numpy array (ambiguous truth value) and wrongly
        # recomputes posteriors for an empty one; compare against the
        # None sentinel explicitly instead.
        if pred_proba is None:
            pred_proba = self.c_model.predict_proba(base.X)
        return ExtendedCollection.extend_collection(base, pred_proba), pred_proba

    @abstractmethod
    def fit(self, train: LabelledCollection | ExtendedCollection):
        ...

    @abstractmethod
    def estimate(self, instances, ext=False):
        ...


AE = AccuracyEstimator
class MulticlassAccuracyEstimator(AccuracyEstimator):
    """Accuracy estimator backed by a single multiclass quantifier trained
    on the extended collection (instances augmented with classifier
    posteriors, labels recoded as (true class, predicted class) pairs).

    :param c_model: classifier exposing ``predict_proba``
    :param q_model: quantifier name, "SLD" or "CC"
    :param gs: if True, tune the SLD quantifier with GridSearchQ
    :param recalib: recalibration strategy forwarded to SLD (e.g. "bcts")
    """

    def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
        super().__init__()
        self.c_model = c_model
        self._q_model_name = q_model.upper()
        self.e_train = None  # extended training collection, set by fit()
        self.gs = gs
        self.recalib = recalib

    def fit(self, train: LabelledCollection | ExtendedCollection):
        """Fit the quantifier on the (extended) training data.

        A plain LabelledCollection is first extended with out-of-fold
        posteriors from cross_val_predict; an ExtendedCollection is used
        as-is (assumed already extended).
        """
        # check if model is fit
        # self.model.fit(*train.Xy)
        if isinstance(train, LabelledCollection):
            pred_prob_train = cross_val_predict(
                self.c_model, *train.Xy, method="predict_proba"
            )
            self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
        else:
            self.e_train = train

        if self._q_model_name == "SLD":
            if self.gs:
                # hold out part of the extended training set for model selection
                t_train, t_val = self.e_train.split_stratified(0.6, random_state=0)
                gs_params = self._gs_params(t_val)
                self.q_model = GridSearchQ(
                    SLD(LogisticRegression()),
                    **gs_params,
                )
                self.q_model.fit(t_train)
                self.fit_score = self.q_model.best_score_
            else:
                self.q_model = SLD(LogisticRegression(), recalib=self.recalib)
                self.q_model.fit(self.e_train)
        elif self._q_model_name == "CC":
            self.q_model = CC(LogisticRegression())
            self.q_model.fit(self.e_train)

    def estimate(self, instances, ext=False):
        """Quantify the extended-class prevalence of ``instances``.

        :param instances: raw instances (``ext=False``) or already-extended
            instances (``ext=True``)
        :return: prevalence vector covering every training class
        """
        if not ext:
            pred_prob = self.c_model.predict_proba(instances)
            e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
        else:
            e_inst = instances

        estim_prev = self.q_model.quantify(e_inst)

        # pad the estimate so it covers classes absent from the quantifier
        return self._check_prevalence_classes(
            self.e_train.classes_, self.q_model, estim_prev
        )

    def _check_prevalence_classes(self, true_classes, q_model, estim_prev):
        # Insert a zero prevalence for any class the quantifier never saw,
        # so the output aligns with the full set of training classes.
        if isinstance(q_model, GridSearchQ):
            estim_classes = q_model.best_model().classes_
        else:
            estim_classes = q_model.classes_
        for _cls in true_classes:
            if _cls not in estim_classes:
                estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
        return estim_prev
class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
    """Accuracy estimator that trains one quantifier per predicted class,
    each on the slice of the extended collection sharing that prediction,
    then concatenates the (normalized) per-slice prevalence estimates.

    :param c_model: classifier exposing ``predict_proba``
    :param q_model: quantifier name, "SLD" or "CC"
    :param gs: if True, tune each SLD quantifier with GridSearchQ
    :param recalib: recalibration strategy forwarded to SLD (e.g. "bcts")
    """

    def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
        super().__init__()
        self.c_model = c_model
        self._q_model_name = q_model.upper()
        self.q_models = []  # one quantifier per predicted-class slice
        self.gs = gs
        self.recalib = recalib
        self.e_train = None

    def fit(self, train: LabelledCollection | ExtendedCollection):
        """Fit one quantifier per predicted-class slice of the training data."""
        # check if model is fit
        # self.model.fit(*train.Xy)
        if isinstance(train, LabelledCollection):
            pred_prob_train = cross_val_predict(
                self.c_model, *train.Xy, method="predict_proba"
            )
            self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
        elif isinstance(train, ExtendedCollection):
            self.e_train = train

        self.n_classes = self.e_train.n_classes
        # one sub-collection per predicted class
        e_trains = self.e_train.split_by_pred()

        if self._q_model_name == "SLD":
            fit_scores = []
            for e_train in e_trains:
                if self.gs:
                    t_train, t_val = e_train.split_stratified(0.6, random_state=0)
                    gs_params = self._gs_params(t_val)
                    q_model = GridSearchQ(
                        SLD(LogisticRegression()),
                        **gs_params,
                    )
                    q_model.fit(t_train)
                    fit_scores.append(q_model.best_score_)
                    self.q_models.append(q_model)
                else:
                    q_model = SLD(LogisticRegression(), recalib=self.recalib)
                    q_model.fit(e_train)
                    self.q_models.append(q_model)

            if self.gs:
                # aggregate score = mean of the per-slice grid-search scores
                self.fit_score = np.mean(fit_scores)
        elif self._q_model_name == "CC":
            for e_train in e_trains:
                q_model = CC(LogisticRegression())
                q_model.fit(e_train)
                self.q_models.append(q_model)

    def estimate(self, instances, ext=False):
        """Estimate the extended-class prevalence by quantifying each
        predicted-class slice separately and interleaving the results."""
        # TODO: test
        if not ext:
            pred_prob = self.c_model.predict_proba(instances)
            e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
        else:
            e_inst = instances

        # n_classes is the square of the base number of classes
        _ncl = int(math.sqrt(self.n_classes))
        s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
        estim_prevs = [
            self._quantify_helper(inst, norm, q_model)
            for (inst, norm, q_model) in zip(s_inst, norms, self.q_models)
        ]

        # interleave the per-slice estimates into a single flat vector
        estim_prev = []
        for prev_row in zip(*estim_prevs):
            for prev in prev_row:
                estim_prev.append(prev)

        return np.asarray(estim_prev)

    def _quantify_helper(self, inst, norm, q_model):
        # Quantify a slice and rescale by its share of the sample; an
        # empty slice contributes zero prevalence (binary case assumed).
        if inst.shape[0] > 0:
            return np.asarray(list(map(lambda p: p * norm, q_model.quantify(inst))))
        else:
            return np.asarray([0.0, 0.0])

View File

@ -0,0 +1,34 @@
from typing import Callable, Union
import numpy as np
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
import quacc as qc
from ..method.base import BaseAccuracyEstimator
def evaluate(
    estimator: BaseAccuracyEstimator,
    protocol: AbstractProtocol,
    error_metric: Union[Callable, str],
) -> float:
    """Evaluate ``estimator`` over every sample generated by ``protocol``.

    :param estimator: fitted accuracy estimator
    :param protocol: sample-generation protocol; its collator is temporarily
        switched so samples arrive as LabelledCollection objects
    :param error_metric: callable ``(true_prevs, estim_prevs) -> float`` or
        the name of a metric registered in ``qc.error``
    :return: the error metric computed over all generated samples
    """
    if isinstance(error_metric, str):
        error_metric = qc.error.from_name(error_metric)

    collator_bck_ = protocol.collator
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    estim_prevs, true_prevs = [], []
    try:
        for sample in protocol():
            e_sample = estimator.extend(sample)
            estim_prev = estimator.estimate(e_sample.X, ext=True)
            estim_prevs.append(estim_prev)
            true_prevs.append(e_sample.prevalence())
    finally:
        # FIX: restore the original collator even if extension or
        # estimation raises, so the shared protocol object is never left
        # in a modified state.
        protocol.collator = collator_bck_

    true_prevs = np.array(true_prevs)
    estim_prevs = np.array(estim_prevs)

    return error_metric(true_prevs, estim_prevs)

View File

@ -65,11 +65,10 @@ def ref(
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
):
c_model_predict = getattr(c_model, "predict_proba")
c_model_predict = getattr(c_model, "predict")
report = EvaluationReport(name="ref")
for test in protocol():
test_probs = c_model_predict(test.X)
test_preds = np.argmax(test_probs, axis=-1)
test_preds = c_model_predict(test.X)
report.append_row(
test.prevalence(),
acc_score=metrics.accuracy_score(test.y, test_preds),

View File

@ -3,6 +3,7 @@ import time
from traceback import print_exception as traceback
from typing import List
import numpy as np
import pandas as pd
import quapy as qp
@ -17,31 +18,63 @@ pd.set_option("display.float_format", "{:.4f}".format)
qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
class CompEstimatorName_:
    """Indexer view over a CompEstimator that resolves estimator *names*."""

    def __init__(self, ce):
        self.ce = ce  # the CompEstimator instance being wrapped

    def __getitem__(self, e: str | List[str]):
        # Delegates to the name-mangled CompEstimator.__get lookup: a str
        # key yields the single matched name, a list yields all matched
        # names. NOTE(review): any other key type silently returns None.
        if isinstance(e, str):
            return self.ce._CompEstimator__get(e)[0]
        elif isinstance(e, list):
            return list(self.ce._CompEstimator__get(e).keys())
class CompEstimatorFunc_:
    """Indexer view over a CompEstimator that resolves estimator *functions*."""

    def __init__(self, ce):
        self.ce = ce  # the CompEstimator instance being wrapped

    def __getitem__(self, e: str | List[str]):
        # Delegates to the name-mangled CompEstimator.__get lookup: a str
        # key yields the single matched callable, a list yields all matched
        # callables. NOTE(review): any other key type silently returns None.
        if isinstance(e, str):
            return self.ce._CompEstimator__get(e)[1]
        elif isinstance(e, list):
            return list(self.ce._CompEstimator__get(e).values())
class CompEstimator:
__dict = method._methods | baseline._baselines
def __class_getitem__(cls, e: str | List[str]):
def __get(cls, e: str | List[str]):
if isinstance(e, str):
try:
return cls.__dict[e]
return (e, cls.__dict[e])
except KeyError:
raise KeyError(f"Invalid estimator: estimator {e} does not exist")
elif isinstance(e, list):
_subtr = [k for k in e if k not in cls.__dict]
_subtr = np.setdiff1d(e, list(cls.__dict.keys()))
if len(_subtr) > 0:
raise KeyError(
f"Invalid estimator: estimator {_subtr[0]} does not exist"
)
return [fun for k, fun in cls.__dict.items() if k in e]
e_fun = {k: fun for k, fun in cls.__dict.items() if k in e}
if "ref" not in e:
e_fun["ref"] = cls.__dict["ref"]
return e_fun
@property
def name(self):
return CompEstimatorName_(self)
@property
def func(self):
return CompEstimatorFunc_(self)
CE = CompEstimator
CE = CompEstimator()
def evaluate_comparison(
dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"]
) -> EvaluationReport:
def evaluate_comparison(dataset: Dataset, estimators=None) -> EvaluationReport:
log = Logger.logger()
# with multiprocessing.Pool(1) as pool:
with multiprocessing.Pool(len(estimators)) as pool:
@ -52,7 +85,9 @@ def evaluate_comparison(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started"
)
tstart = time.time()
tasks = [(estim, d.train, d.validation, d.test) for estim in CE[estimators]]
tasks = [
(estim, d.train, d.validation, d.test) for estim in CE.func[estimators]
]
results = [
pool.apply_async(estimate_worker, t, {"_env": env, "q": Logger.queue()})
for t in tasks

View File

@ -1,23 +1,30 @@
import inspect
from functools import wraps
import numpy as np
import sklearn.metrics as metrics
from quapy.data import LabelledCollection
from quapy.protocol import AbstractStochasticSeededProtocol
from sklearn.base import BaseEstimator
from quapy.method.aggregative import PACC, SLD, CC
from quapy.protocol import UPP, AbstractProtocol
from sklearn.linear_model import LogisticRegression
import quacc.error as error
import quacc as qc
from quacc.evaluation.report import EvaluationReport
from quacc.method.model_selection import BQAEgsq, GridSearchAE, MCAEgsq
from ..estimator import (
AccuracyEstimator,
BinaryQuantifierAccuracyEstimator,
MulticlassAccuracyEstimator,
)
from ..method.base import BQAE, MCAE, BaseAccuracyEstimator
_methods = {}
_sld_param_grid = {
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"q__recalib": [None, "bcts"],
"q__exact_train_prev": [True],
"confidence": [None, "max_conf", "entropy"],
}
_pacc_param_grid = {
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"confidence": [None, "max_conf", "entropy"],
}
def method(func):
@wraps(func)
def wrapper(c_model, validation, protocol):
@ -28,108 +35,271 @@ def method(func):
return wrapper
def estimate(
estimator: AccuracyEstimator,
protocol: AbstractStochasticSeededProtocol,
):
base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], []
for sample in protocol():
e_sample, pred_proba = estimator.extend(sample)
estim_prev = estimator.estimate(e_sample.X, ext=True)
base_prevs.append(sample.prevalence())
true_prevs.append(e_sample.prevalence())
estim_prevs.append(estim_prev)
pred_probas.append(pred_proba)
labels.append(sample.y)
return base_prevs, true_prevs, estim_prevs, pred_probas, labels
def evaluation_report(
estimator: AccuracyEstimator,
protocol: AbstractStochasticSeededProtocol,
method: str,
estimator: BaseAccuracyEstimator,
protocol: AbstractProtocol,
) -> EvaluationReport:
base_prevs, true_prevs, estim_prevs, pred_probas, labels = estimate(
estimator, protocol
)
report = EvaluationReport(name=method)
for base_prev, true_prev, estim_prev, pred_proba, label in zip(
base_prevs, true_prevs, estim_prevs, pred_probas, labels
):
pred = np.argmax(pred_proba, axis=-1)
acc_score = error.acc(estim_prev)
f1_score = error.f1(estim_prev)
method_name = inspect.stack()[1].function
report = EvaluationReport(name=method_name)
for sample in protocol():
e_sample = estimator.extend(sample)
estim_prev = estimator.estimate(e_sample.X, ext=True)
acc_score = qc.error.acc(estim_prev)
f1_score = qc.error.f1(estim_prev)
report.append_row(
base_prev,
sample.prevalence(),
acc_score=acc_score,
acc=abs(metrics.accuracy_score(label, pred) - acc_score),
acc=abs(qc.error.acc(e_sample.prevalence()) - acc_score),
f1_score=f1_score,
f1=abs(error.f1(true_prev) - f1_score),
f1=abs(qc.error.f1(e_sample.prevalence()) - f1_score),
)
report.fit_score = estimator.fit_score
return report
def evaluate(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
method: str,
q_model: str,
**kwargs,
):
estimator: AccuracyEstimator = {
"bin": BinaryQuantifierAccuracyEstimator,
"mul": MulticlassAccuracyEstimator,
}[method](c_model, q_model=q_model.upper(), **kwargs)
estimator.fit(validation)
_method = f"{method}_{q_model}"
if "recalib" in kwargs:
_method += f"_{kwargs['recalib']}"
if ("gs", True) in kwargs.items():
_method += "_gs"
return evaluation_report(estimator, protocol, _method)
@method
def bin_sld(c_model, validation, protocol) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "bin", "sld")
est = BQAE(c_model, SLD(LogisticRegression())).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def mul_sld(c_model, validation, protocol) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "mul", "sld")
est = MCAE(c_model, SLD(LogisticRegression())).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "bin", "sld", recalib="bcts")
def binmc_sld(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate a binary-quantifier SLD estimator whose extended
    representation includes the max-confidence feature."""
    quantifier = SLD(LogisticRegression())
    estimator = BQAE(c_model, quantifier, confidence="max_conf").fit(validation)
    return evaluation_report(estimator=estimator, protocol=protocol)
@method
def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "mul", "sld", recalib="bcts")
def mulmc_sld(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate a multiclass SLD estimator whose extended representation
    includes the max-confidence feature."""
    quantifier = SLD(LogisticRegression())
    estimator = MCAE(c_model, quantifier, confidence="max_conf").fit(validation)
    return evaluation_report(estimator=estimator, protocol=protocol)
@method
def binne_sld(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate a binary-quantifier SLD estimator whose extended
    representation includes the entropy-based confidence feature."""
    est = BQAE(
        c_model,
        SLD(LogisticRegression()),
        confidence="entropy",
    ).fit(validation)
    return evaluation_report(
        estimator=est,
        protocol=protocol,
    )
@method
def mulne_sld(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate a multiclass SLD estimator whose extended representation
    includes the entropy-based confidence feature."""
    est = MCAE(
        c_model,
        SLD(LogisticRegression()),
        confidence="entropy",
    ).fit(validation)
    return evaluation_report(
        estimator=est,
        protocol=protocol,
    )
@method
def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "bin", "sld", gs=True)
v_train, v_val = validation.split_stratified(0.6, random_state=0)
model = BQAE(c_model, SLD(LogisticRegression()))
est = GridSearchAE(
model=model,
param_grid=_sld_param_grid,
refit=False,
protocol=UPP(v_val, repeats=100),
verbose=True,
).fit(v_train)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "mul", "sld", gs=True)
v_train, v_val = validation.split_stratified(0.6, random_state=0)
model = MCAE(c_model, SLD(LogisticRegression()))
est = GridSearchAE(
model=model,
param_grid=_sld_param_grid,
refit=False,
protocol=UPP(v_val, repeats=100),
verbose=True,
).fit(v_train)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def bin_sld_gsq(c_model, validation, protocol) -> EvaluationReport:
est = BQAEgsq(
c_model,
SLD(LogisticRegression()),
param_grid={
"classifier__C": np.logspace(-3, 3, 7),
"classifier__class_weight": [None, "balanced"],
"recalib": [None, "bcts", "vs"],
},
refit=False,
verbose=False,
).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def mul_sld_gsq(c_model, validation, protocol) -> EvaluationReport:
est = MCAEgsq(
c_model,
SLD(LogisticRegression()),
param_grid={
"classifier__C": np.logspace(-3, 3, 7),
"classifier__class_weight": [None, "balanced"],
"recalib": [None, "bcts", "vs"],
},
refit=False,
verbose=False,
).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def bin_pacc(c_model, validation, protocol) -> EvaluationReport:
est = BQAE(c_model, PACC(LogisticRegression())).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def mul_pacc(c_model, validation, protocol) -> EvaluationReport:
est = MCAE(c_model, PACC(LogisticRegression())).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def binmc_pacc(c_model, validation, protocol) -> EvaluationReport:
est = BQAE(c_model, PACC(LogisticRegression()), confidence="max_conf").fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def mulmc_pacc(c_model, validation, protocol) -> EvaluationReport:
est = MCAE(c_model, PACC(LogisticRegression()), confidence="max_conf").fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def binne_pacc(c_model, validation, protocol) -> EvaluationReport:
est = BQAE(c_model, PACC(LogisticRegression()), confidence="entropy").fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def mulne_pacc(c_model, validation, protocol) -> EvaluationReport:
est = MCAE(c_model, PACC(LogisticRegression()), confidence="entropy").fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def bin_pacc_gs(c_model, validation, protocol) -> EvaluationReport:
v_train, v_val = validation.split_stratified(0.6, random_state=0)
model = BQAE(c_model, PACC(LogisticRegression()))
est = GridSearchAE(
model=model,
param_grid=_pacc_param_grid,
refit=False,
protocol=UPP(v_val, repeats=100),
verbose=False,
).fit(v_train)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def mul_pacc_gs(c_model, validation, protocol) -> EvaluationReport:
v_train, v_val = validation.split_stratified(0.6, random_state=0)
model = MCAE(c_model, PACC(LogisticRegression()))
est = GridSearchAE(
model=model,
param_grid=_pacc_param_grid,
refit=False,
protocol=UPP(v_val, repeats=100),
verbose=False,
).fit(v_train)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def bin_cc(c_model, validation, protocol) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "bin", "cc")
est = BQAE(c_model, CC(LogisticRegression())).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def mul_cc(c_model, validation, protocol) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "mul", "cc")
est = MCAE(c_model, CC(LogisticRegression())).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)

View File

@ -115,7 +115,7 @@ class CompReport:
shift_data = self._data.copy()
shift_data.index = pd.MultiIndex.from_arrays([shift_idx_0, shift_idx_1])
shift_data.sort_index(axis=0, level=0)
shift_data = shift_data.sort_index(axis=0, level=0)
_metric = _get_metric(metric)
_estimators = _get_estimators(estimators, shift_data.columns.unique(1))
@ -182,22 +182,21 @@ class CompReport:
train_prev=self.train_prev,
)
elif mode == "shift":
shift_data = (
self.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
_shift_data = self.shift_data(metric=metric, estimators=estimators)
shift_avg = _shift_data.groupby(level=0).mean()
shift_counts = _shift_data.groupby(level=0).count()
shift_prevs = np.around(
[(1.0 - p, p) for p in np.sort(shift_data.index.unique(0))],
[(1.0 - p, p) for p in np.sort(shift_avg.index.unique(0))],
decimals=2,
)
return plot.plot_shift(
shift_prevs=shift_prevs,
columns=shift_data.columns.to_numpy(),
data=shift_data.T.to_numpy(),
columns=shift_avg.columns.to_numpy(),
data=shift_avg.T.to_numpy(),
metric=metric,
name=conf,
train_prev=self.train_prev,
counts=shift_counts.T.to_numpy(),
)
def to_md(self, conf="default", metric="acc", estimators=None, stdev=False) -> str:
@ -246,7 +245,7 @@ class DatasetReport:
)
_crs_train, _crs_data = zip(*_crs_sorted)
_data = pd.concat(_crs_data, axis=0, keys=_crs_train)
_data = pd.concat(_crs_data, axis=0, keys=np.around(_crs_train, decimals=2))
_data = _data.sort_index(axis=0, level=0)
return _data
@ -296,47 +295,95 @@ class DatasetReport:
_data = self.data(metric=metric, estimators=estimators)
_shift_data = self.shift_data(metric=metric, estimators=estimators)
avg_x_test = _data.groupby(level=1).mean()
prevs_x_test = np.sort(avg_x_test.index.unique(0))
stdev_x_test = _data.groupby(level=1).std() if stdev else None
avg_x_test_tbl = _data.groupby(level=1).mean()
avg_x_test_tbl.loc["avg", :] = _data.mean()
avg_x_shift = _shift_data.groupby(level=0).mean()
prevs_x_shift = np.sort(avg_x_shift.index.unique(0))
res += "## avg\n"
res += avg_x_test_tbl.to_html() + "\n\n"
######################## avg on train ########################
res += "### avg on train\n"
avg_on_train = _data.groupby(level=1).mean()
prevs_on_train = np.sort(avg_on_train.index.unique(0))
stdev_on_train = _data.groupby(level=1).std() if stdev else None
avg_on_train_tbl = _data.groupby(level=1).mean()
avg_on_train_tbl.loc["avg", :] = _data.mean()
res += avg_on_train_tbl.to_html() + "\n\n"
delta_op = plot.plot_delta(
base_prevs=np.around([(1.0 - p, p) for p in prevs_x_test], decimals=2),
columns=avg_x_test.columns.to_numpy(),
data=avg_x_test.T.to_numpy(),
base_prevs=np.around([(1.0 - p, p) for p in prevs_on_train], decimals=2),
columns=avg_on_train.columns.to_numpy(),
data=avg_on_train.T.to_numpy(),
metric=metric,
name=conf,
train_prev=None,
avg="train",
)
res += f"![plot_delta]({delta_op.relative_to(env.OUT_DIR).as_posix()})\n"
if stdev:
delta_stdev_op = plot.plot_delta(
base_prevs=np.around([(1.0 - p, p) for p in prevs_x_test], decimals=2),
columns=avg_x_test.columns.to_numpy(),
data=avg_x_test.T.to_numpy(),
base_prevs=np.around(
[(1.0 - p, p) for p in prevs_on_train], decimals=2
),
columns=avg_on_train.columns.to_numpy(),
data=avg_on_train.T.to_numpy(),
metric=metric,
name=conf,
train_prev=None,
stdevs=stdev_x_test.T.to_numpy(),
stdevs=stdev_on_train.T.to_numpy(),
avg="train",
)
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(env.OUT_DIR).as_posix()})\n"
shift_op = plot.plot_shift(
shift_prevs=np.around([(1.0 - p, p) for p in prevs_x_shift], decimals=2),
columns=avg_x_shift.columns.to_numpy(),
data=avg_x_shift.T.to_numpy(),
######################## avg on test ########################
res += "### avg on test\n"
avg_on_test = _data.groupby(level=0).mean()
prevs_on_test = np.sort(avg_on_test.index.unique(0))
stdev_on_test = _data.groupby(level=0).std() if stdev else None
avg_on_test_tbl = _data.groupby(level=0).mean()
avg_on_test_tbl.loc["avg", :] = _data.mean()
res += avg_on_test_tbl.to_html() + "\n\n"
delta_op = plot.plot_delta(
base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
columns=avg_on_test.columns.to_numpy(),
data=avg_on_test.T.to_numpy(),
metric=metric,
name=conf,
train_prev=None,
avg="test",
)
res += f"![plot_delta]({delta_op.relative_to(env.OUT_DIR).as_posix()})\n"
if stdev:
delta_stdev_op = plot.plot_delta(
base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
columns=avg_on_test.columns.to_numpy(),
data=avg_on_test.T.to_numpy(),
metric=metric,
name=conf,
train_prev=None,
stdevs=stdev_on_test.T.to_numpy(),
avg="test",
)
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(env.OUT_DIR).as_posix()})\n"
######################## avg shift ########################
res += "### avg dataset shift\n"
avg_shift = _shift_data.groupby(level=0).mean()
count_shift = _shift_data.groupby(level=0).count()
prevs_shift = np.sort(avg_shift.index.unique(0))
shift_op = plot.plot_shift(
shift_prevs=np.around([(1.0 - p, p) for p in prevs_shift], decimals=2),
columns=avg_shift.columns.to_numpy(),
data=avg_shift.T.to_numpy(),
metric=metric,
name=conf,
train_prev=None,
counts=count_shift.T.to_numpy(),
)
res += f"![plot_shift]({shift_op.relative_to(env.OUT_DIR).as_posix()})\n"

View File

@ -27,7 +27,7 @@ def estimate_worker(_estimate, train, validation, test, _env=None, q=None):
result = _estimate(model, validation, protocol)
except Exception as e:
log.warning(f"Method {_estimate.__name__} failed. Exception: {e}")
# traceback(e)
traceback(e)
return {
"name": _estimate.__name__,
"result": None,

View File

@ -2,6 +2,7 @@ import logging
import logging.handlers
import multiprocessing
import threading
from pathlib import Path
class Logger:
@ -11,6 +12,7 @@ class Logger:
__queue = None
__thread = None
__setup = False
__handlers = []
@classmethod
def __logger_listener(cls, q):
@ -32,7 +34,6 @@ class Logger:
rh = logging.FileHandler(cls.__logger_file, mode="a")
rh.setLevel(logging.DEBUG)
root.addHandler(rh)
root.info("-" * 100)
# setup logger
if cls.__manager is None:
@ -62,6 +63,21 @@ class Logger:
cls.__setup = True
@classmethod
def add_handler(cls, path: Path):
root = logging.getLogger("listener")
rh = logging.FileHandler(path, mode="a")
rh.setLevel(logging.DEBUG)
cls.__handlers.append(rh)
root.addHandler(rh)
@classmethod
def clear_handlers(cls):
root = logging.getLogger("listener")
for h in cls.__handlers:
root.removeHandler(h)
cls.__handlers.clear()
@classmethod
def queue(cls):
if not cls.__setup:
@ -79,6 +95,8 @@ class Logger:
@classmethod
def close(cls):
if cls.__setup and cls.__thread is not None:
root = logging.getLogger("listener")
root.info("-" * 100)
cls.__queue.put(None)
cls.__thread.join()
# cls.__manager.close()
@ -102,7 +120,7 @@ class SubLogger:
rh.setLevel(logging.DEBUG)
rh.setFormatter(
logging.Formatter(
fmt="%(asctime)s| %(levelname)-12s\t%(message)s",
fmt="%(asctime)s| %(levelname)-12s%(message)s",
datefmt="%d/%m/%y %H:%M:%S",
)
)

View File

@ -7,6 +7,8 @@ from quacc.environment import env
from quacc.logger import Logger
from quacc.utils import create_dataser_dir
CE = comp.CompEstimator()
def toast():
if platform == "win32":
@ -25,8 +27,12 @@ def estimate_comparison():
prevs=env.DATASET_PREVS,
)
create_dataser_dir(dataset.name, update=env.DATASET_DIR_UPDATE)
Logger.add_handler(env.OUT_DIR / f"{dataset.name}.log")
try:
dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
dr = comp.evaluate_comparison(
dataset,
estimators=CE.name[env.COMP_ESTIMATORS],
)
except Exception as e:
log.error(f"Evaluation over {dataset.name} failed. Exception: {e}")
traceback(e)
@ -37,7 +43,7 @@ def estimate_comparison():
_repr = dr.to_md(
conf=plot_conf,
metric=m,
estimators=env.PLOT_ESTIMATORS,
estimators=CE.name[env.PLOT_ESTIMATORS],
stdev=env.PLOT_STDEV,
)
with open(output_path, "w") as f:
@ -47,6 +53,7 @@ def estimate_comparison():
f"Failed while saving configuration {plot_conf} of {dataset.name}. Exception: {e}"
)
traceback(e)
Logger.clear_handlers()
# print(df.to_latex(float_format="{:.4f}".format))
# print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))

120
quacc/main_test.py Normal file
View File

@ -0,0 +1,120 @@
from copy import deepcopy
from time import time
import numpy as np
import win11toast
from quapy.method.aggregative import SLD
from quapy.protocol import APP, UPP
from sklearn.linear_model import LogisticRegression
import quacc as qc
from quacc.dataset import Dataset
from quacc.error import acc
from quacc.evaluation.baseline import ref
from quacc.evaluation.method import mulmc_sld
from quacc.evaluation.report import CompReport, EvaluationReport
from quacc.method.base import MCAE, BinaryQuantifierAccuracyEstimator
from quacc.method.model_selection import GridSearchAE
def test_gs():
    """Manual smoke test: compare a plain BinaryQuantifierAccuracyEstimator
    against a GridSearchAE-tuned copy on rcv1/CCAT and print a report.

    Fits both estimators, evaluates them over an APP protocol on the test
    set and reports the absolute accuracy-estimation error of each.
    """
    d = Dataset(name="rcv1", target="CCAT", n_prevalences=1).get_raw()

    classifier = LogisticRegression()
    classifier.fit(*d.train.Xy)

    quantifier = SLD(LogisticRegression())
    # estimator = MultiClassAccuracyEstimator(classifier, quantifier)
    estimator = BinaryQuantifierAccuracyEstimator(classifier, quantifier)

    # hold out part of the validation set to drive the grid search
    v_train, v_val = d.validation.split_stratified(0.6, random_state=0)
    gs_protocol = UPP(v_val, sample_size=1000, repeats=100)
    gs_estimator = GridSearchAE(
        model=deepcopy(estimator),
        param_grid={
            "q__classifier__C": np.logspace(-3, 3, 7),
            "q__classifier__class_weight": [None, "balanced"],
            "q__recalib": [None, "bcts", "ts"],
        },
        refit=False,
        protocol=gs_protocol,
        verbose=True,
    ).fit(v_train)

    # baseline estimator is fit on the full validation set
    estimator.fit(d.validation)

    tstart = time()
    erb, ergs = EvaluationReport("base"), EvaluationReport("gs")
    protocol = APP(
        d.test,
        sample_size=1000,
        n_prevalences=21,
        repeats=100,
        return_type="labelled_collection",
    )
    for sample in protocol():
        e_sample = gs_estimator.extend(sample)
        estim_prev_b = estimator.estimate(e_sample.X, ext=True)
        estim_prev_gs = gs_estimator.estimate(e_sample.X, ext=True)
        # absolute error of the estimated accuracy w.r.t. the true one
        erb.append_row(
            sample.prevalence(),
            acc=abs(acc(e_sample.prevalence()) - acc(estim_prev_b)),
        )
        ergs.append_row(
            sample.prevalence(),
            acc=abs(acc(e_sample.prevalence()) - acc(estim_prev_gs)),
        )
    cr = CompReport(
        [erb, ergs],
        "test",
        train_prev=d.train_prev,
        valid_prev=d.validation_prev,
    )

    print(cr.table())
    print(f"[took {time() - tstart:.3f}s]")
    win11toast.notify("Test", "completed")
def test_mc():
    """Manual smoke test: run mulmc_sld against the `ref` baseline on
    rcv1/CCAT and dump the merged report to test_mc.md."""
    d = Dataset(name="rcv1", target="CCAT", prevs=[0.9]).get()[0]
    classifier = LogisticRegression().fit(*d.train.Xy)
    protocol = APP(
        d.test,
        sample_size=1000,
        repeats=100,
        n_prevalences=21,
        return_type="labelled_collection",
    )
    ref_er = ref(classifier, d.validation, protocol)
    mulmc_er = mulmc_sld(classifier, d.validation, protocol)
    cr = CompReport(
        [mulmc_er, ref_er],
        name="test_mc",
        train_prev=d.train_prev,
        valid_prev=d.validation_prev,
    )
    with open("test_mc.md", "w") as f:
        f.write(cr.data().to_markdown())
def test_et():
    """Quick check of MCAE with a max-confidence extended feature on imdb.

    Prints the estimated vs. true accuracy for the whole (extended) test set.
    """
    d = Dataset(name="imdb", prevs=[0.5]).get()[0]
    classifier = LogisticRegression().fit(*d.train.Xy)
    estimator = MCAE(
        classifier,
        SLD(LogisticRegression(), exact_train_prev=False),
        confidence="max_conf",
    ).fit(d.validation)
    e_test = estimator.extend(d.test)
    ep = estimator.estimate(e_test.X, ext=True)
    print(f"{qc.error.acc(ep) = }")
    print(f"{qc.error.acc(e_test.prevalence()) = }")
# Script entry point: only the extended-estimator smoke test runs by default.
if __name__ == "__main__":
    test_et()

177
quacc/method/base.py Normal file
View File

@ -0,0 +1,177 @@
import math
from abc import abstractmethod
from copy import deepcopy
from typing import List
import numpy as np
from quapy.data import LabelledCollection
from quapy.method.aggregative import BaseQuantifier
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator
from quacc.data import ExtendedCollection
class BaseAccuracyEstimator(BaseQuantifier):
    """Base class for classifier-accuracy estimators built on quantification.

    Pairs a probabilistic classifier with a quantifier: collections are
    "extended" by appending the classifier's posterior probabilities (and,
    optionally, a confidence feature) to the instances, and the quantifier
    is then used over the extended representation.
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseQuantifier,
        confidence=None,
    ):
        """
        :param classifier: probabilistic classifier; must expose predict_proba
        :param quantifier: quantifier applied to the extended collections
        :param confidence: optional name of a confidence feature to prepend
            to the posteriors ("max_conf" or "entropy"); None disables it
        """
        self.__check_classifier(classifier)
        self.quantifier = quantifier
        self.confidence = confidence

    def __check_classifier(self, classifier):
        # validates AND stores the classifier (note the side effect)
        if not hasattr(classifier, "predict_proba"):
            raise ValueError(
                f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities."
            )
        self.classifier = classifier

    def __get_confidence(self):
        """Return the confidence function selected by `self.confidence`, or None."""

        def max_conf(probas):
            # max posterior, rescaled from [1/n_classes, 1] to [0, 1]
            _mc = np.max(probas, axis=-1)
            _min = 1.0 / probas.shape[1]
            _norm_mc = (_mc - _min) / (1.0 - _min)
            return _norm_mc

        def entropy(probas):
            # sum of p*log(p): the *negative* Shannon entropy, so larger
            # values mean more confident predictions (same direction as
            # max_conf); 1e-20 guards against log(0)
            _ent = np.sum(np.multiply(probas, np.log(probas + 1e-20)), axis=1)
            return _ent

        if self.confidence is None:
            return None

        __confs = {
            "max_conf": max_conf,
            "entropy": entropy,
        }
        # unknown names silently disable the confidence feature
        return __confs.get(self.confidence, None)

    def __get_ext(self, pred_proba):
        # Build the extra-feature matrix: posteriors, optionally with the
        # confidence value prepended as the first column.
        _ext = pred_proba
        _f_conf = self.__get_confidence()
        if _f_conf is not None:
            _confs = _f_conf(pred_proba).reshape((len(pred_proba), 1))
            _ext = np.concatenate((_confs, pred_proba), axis=1)

        return _ext

    def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
        """Extend `coll` with classifier posteriors (computed unless provided).

        :param coll: the labelled collection to extend
        :param pred_proba: optional precomputed posteriors for coll.X
        :return: the extended collection
        """
        if pred_proba is None:
            pred_proba = self.classifier.predict_proba(coll.X)

        _ext = self.__get_ext(pred_proba)
        return ExtendedCollection.extend_collection(coll, pred_proba=_ext)

    def _extend_instances(self, instances: np.ndarray | csr_matrix, pred_proba=None):
        # Same as extend(), but for unlabelled instance matrices.
        if pred_proba is None:
            pred_proba = self.classifier.predict_proba(instances)

        _ext = self.__get_ext(pred_proba)
        return ExtendedCollection.extend_instances(instances, _ext)

    @abstractmethod
    def fit(self, train: LabelledCollection | ExtendedCollection):
        """Fit the estimator on the (possibly already extended) training set."""
        ...

    @abstractmethod
    def estimate(self, instances, ext=False) -> np.ndarray:
        """Estimate the extended-class prevalence for `instances`.

        :param ext: if True, `instances` are already extended
        """
        ...
class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
    """Accuracy estimator backed by a single multiclass quantifier.

    The training collection is extended with the classifier's posteriors and
    one quantifier is fitted over the whole extended label space.
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseQuantifier,
        confidence: str = None,
    ):
        super().__init__(
            classifier=classifier,
            quantifier=quantifier,
            confidence=confidence,
        )
        self.e_train = None

    def fit(self, train: LabelledCollection):
        """Extend `train` and fit the wrapped quantifier on it.

        :param train: the training collection
        :return: self
        """
        self.e_train = self.extend(train)
        self.quantifier.fit(self.e_train)
        return self

    def estimate(self, instances, ext=False) -> np.ndarray:
        """Quantify `instances`, extending them first unless `ext` is True."""
        if ext:
            e_inst = instances
        else:
            e_inst = self._extend_instances(instances)
        raw_prev = self.quantifier.quantify(e_inst)
        return self._check_prevalence_classes(raw_prev, self.quantifier.classes_)

    def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray:
        # Re-align the quantifier output with the classes seen at training
        # time: any class missing from the quantifier gets prevalence 0.0.
        for missing in (c for c in self.e_train.classes_ if c not in estim_classes):
            estim_prev = np.insert(estim_prev, missing, [0.0], axis=0)
        return estim_prev
class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
    """Accuracy estimator that fits one binary quantifier per predicted class.

    The extended training set is split by the classifier's predicted class and
    an independent deep copy of the prototype quantifier is fitted on each
    split; at estimation time the per-split prevalences are quantified
    separately, rescaled by the relative size of each split, and interleaved
    back into a single extended-class prevalence vector.
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        # fixed annotation: the prototype is a quantifier, not an estimator
        quantifier: BaseQuantifier,
        confidence: str = None,
    ):
        """
        :param classifier: probabilistic classifier (must expose predict_proba)
        :param quantifier: prototype quantifier, deep-copied once per predicted class
        :param confidence: optional confidence feature ("max_conf" or "entropy")
        """
        super().__init__(
            classifier=classifier,
            quantifier=quantifier,
            confidence=confidence,
        )
        self.quantifiers = []
        self.e_trains = []

    def fit(self, train: LabelledCollection | ExtendedCollection):
        """Extend `train`, split by predicted class, fit one quantifier per split.

        :param train: the training collection
        :return: self
        """
        self.e_train = self.extend(train)
        self.n_classes = self.e_train.n_classes
        self.e_trains = self.e_train.split_by_pred()

        self.quantifiers = []
        # renamed loop variable: the original shadowed the `train` parameter
        for e_sub_train in self.e_trains:
            quant = deepcopy(self.quantifier)
            quant.fit(e_sub_train)
            self.quantifiers.append(quant)

        return self

    def estimate(self, instances, ext=False):
        # TODO: test
        e_inst = instances if ext else self._extend_instances(instances)

        # number of original classes; the extended space has (orig classes)**2
        _ncl = int(math.sqrt(self.n_classes))
        s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
        estim_prevs = self._quantify_helper(s_inst, norms)

        # interleave per-quantifier prevalences back into extended-class order
        estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten()
        return estim_prev

    def _quantify_helper(
        self,
        s_inst: List[np.ndarray | csr_matrix],
        norms: List[float],
    ):
        # Quantify each split and weight by its relative size; empty splits
        # contribute zero prevalence.
        estim_prevs = []
        for quant, inst, norm in zip(self.quantifiers, s_inst, norms):
            if inst.shape[0] > 0:
                estim_prevs.append(quant.quantify(inst) * norm)
            else:
                estim_prevs.append(np.asarray([0.0, 0.0]))

        return estim_prevs
# Short aliases for the estimator classes defined above.
BAE = BaseAccuracyEstimator
MCAE = MultiClassAccuracyEstimator
BQAE = BinaryQuantifierAccuracyEstimator

View File

@ -0,0 +1,307 @@
import itertools
from copy import deepcopy
from time import time
from typing import Callable, Union
import numpy as np
import quapy as qp
from quapy.data import LabelledCollection
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol
from sklearn.base import BaseEstimator
import quacc as qc
import quacc.error
from quacc.data import ExtendedCollection
from quacc.evaluation import evaluate
from quacc.logger import SubLogger
from quacc.method.base import (
BaseAccuracyEstimator,
BinaryQuantifierAccuracyEstimator,
MultiClassAccuracyEstimator,
)
class GridSearchAE(BaseAccuracyEstimator):
    """Grid-search model selection for a :class:`BaseAccuracyEstimator`.

    Explores the cartesian product of the values in `param_grid`, fitting a
    deep copy of `model` for each hyper-parameter combination and scoring it
    with `error` over the samples generated by `protocol`; the combination
    with the lowest score is retained as `best_model_`.
    """

    def __init__(
        self,
        model: BaseAccuracyEstimator,
        param_grid: dict,
        protocol: AbstractProtocol,
        error: Union[Callable, str] = qc.error.maccd,
        refit=True,
        verbose=False,
    ):
        """
        :param model: the accuracy estimator whose hyper-parameters are explored
        :param param_grid: maps parameter names to lists of candidate values;
            keys with a "q__" prefix are remapped to "quantifier__" so they
            reach the wrapped quantifier
        :param protocol: sample-generation protocol used to score candidates
        :param error: error metric to minimize, either a callable or the name
            of a function in quacc.error
        :param refit: whether to refit the best model on training data plus
            the protocol's labelled collection once selection is done
        :param verbose: if True, print progress information
        """
        self.model = model
        self.param_grid = self.__normalize_params(param_grid)
        self.protocol = protocol
        self.refit = refit
        self.verbose = verbose
        self.__check_error(error)
        assert isinstance(protocol, AbstractProtocol), "unknown protocol"

    def _sout(self, msg):
        # verbose-only progress output
        if self.verbose:
            print(f"[{self.__class__.__name__}]: {msg}")

    def __normalize_params(self, params):
        # Remap the "q__<name>" shorthand to the real attribute path
        # "quantifier__<name>" expected by the estimator's set_params.
        __remap = {}
        for key in params.keys():
            k, delim, sub_key = key.partition("__")
            if delim and k == "q":
                __remap[key] = f"quantifier__{sub_key}"

        return {(__remap[k] if k in __remap else k): v for k, v in params.items()}

    def __check_error(self, error):
        # Accept a known error object, the name of one, or any callable.
        if error in qc.error.ACCURACY_ERROR:
            self.error = error
        elif isinstance(error, str):
            self.error = qc.error.from_name(error)
        elif hasattr(error, "__call__"):
            self.error = error
        else:
            raise ValueError(
                f"unexpected error type; must either be a callable function or a str representing\n"
                f"the name of an error function in {qc.error.ACCURACY_ERROR_NAMES}"
            )

    def fit(self, training: LabelledCollection):
        """Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
            the error metric.

        :param training: the training set on which to optimize the hyperparameters
        :return: self
        """
        params_keys = list(self.param_grid.keys())
        params_values = list(self.param_grid.values())

        protocol = self.protocol

        self.param_scores_ = {}
        self.best_score_ = None

        tinit = time()

        hyper = [
            dict(zip(params_keys, val)) for val in itertools.product(*params_values)
        ]
        self._sout("starting model selection")

        # evaluate every combination sequentially; failed configs score None
        scores = [self.__params_eval(params, training) for params in hyper]

        for params, score, model in scores:
            if score is not None:
                if self.best_score_ is None or score < self.best_score_:
                    self.best_score_ = score
                    self.best_params_ = params
                    self.best_model_ = model
                self.param_scores_[str(params)] = score
            else:
                self.param_scores_[str(params)] = "timeout"

        tend = time() - tinit

        if self.best_score_ is None:
            raise TimeoutError("no combination of hyperparameters seem to work")

        self._sout(
            f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
            f"[took {tend:.4f}s]"
        )

        log = SubLogger.logger()
        log.debug(
            f"[{self.model.__class__.__name__}] "
            f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
            f"[took {tend:.4f}s]"
        )

        if self.refit:
            if isinstance(protocol, OnLabelledCollectionProtocol):
                self._sout("refitting on the whole development set")
                self.best_model_.fit(training + protocol.get_labelled_collection())
            else:
                # NOTE: raised (not warned), mirroring quapy's GridSearchQ
                raise RuntimeWarning(
                    f'"refit" was requested, but the protocol does not '
                    f"implement the {OnLabelledCollectionProtocol.__name__} interface"
                )

        return self

    def __params_eval(self, params, training):
        """Fit and score a single hyper-parameter combination.

        :return: (params, score, model); score is None when the config failed
        """
        protocol = self.protocol
        error = self.error

        tinit = time()

        # initialized up-front so the final return cannot hit an unbound
        # local if deepcopy/set_params raise before `model` is assigned
        model = None
        try:
            model = deepcopy(self.model)
            # overrides default parameters with the parameters being explored at this iteration
            model.set_params(**params)
            model.fit(training)
            score = evaluate(model, protocol=protocol, error_metric=error)

            ttime = time() - tinit
            self._sout(
                f"hyperparams={params}\t got score {score:.5f} [took {ttime:.4f}s]"
            )
        except ValueError as e:
            # invalid hyper-parameter combinations are programming errors
            self._sout(f"the combination of hyperparameters {params} is invalid")
            raise e
        except Exception as e:
            # any other failure just disqualifies this combination
            self._sout(f"something went wrong for config {params}; skipping:")
            self._sout(f"\tException: {e}")
            score = None

        return params, score, model

    def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
        """Extend `coll` using the best model found by :meth:`fit`."""
        assert hasattr(self, "best_model_"), "quantify called before fit"
        return self.best_model().extend(coll, pred_proba=pred_proba)

    def estimate(self, instances, ext=False):
        """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.

        :param instances: sample contanining the instances
        :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
        by the model selection process.
        """
        assert hasattr(self, "best_model_"), "estimate called before fit"
        return self.best_model().estimate(instances, ext=ext)

    def set_params(self, **parameters):
        """Sets the hyper-parameters to explore.

        :param parameters: a dictionary with keys the parameter names and values the list of values to explore
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """Returns the dictionary of hyper-parameters to explore (`param_grid`)

        :param deep: Unused
        :return: the dictionary `param_grid`
        """
        return self.param_grid

    def best_model(self):
        """
        Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
        of hyper-parameters that minimized the error function.

        :return: a trained quantifier
        """
        if hasattr(self, "best_model_"):
            return self.best_model_
        raise ValueError("best_model called before fit")
class MCAEgsq(MultiClassAccuracyEstimator):
    """MultiClassAccuracyEstimator whose quantifier is tuned via quapy's GridSearchQ."""

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseAccuracyEstimator,  # NOTE(review): annotation looks wrong — this receives a quantifier; confirm and import BaseQuantifier
        param_grid: dict,
        error: Union[Callable, str] = qp.error.mae,
        refit=True,
        timeout=-1,
        n_jobs=None,
        verbose=False,
    ):
        """
        :param classifier: probabilistic classifier wrapped by the estimator
        :param quantifier: prototype quantifier passed to GridSearchQ
        :param param_grid: hyper-parameter grid forwarded to GridSearchQ
        :param error: error metric minimized by GridSearchQ
        :param refit: forwarded to GridSearchQ
        :param timeout: forwarded to GridSearchQ (-1 disables it)
        :param n_jobs: forwarded to GridSearchQ
        :param verbose: forwarded to GridSearchQ
        """
        self.param_grid = param_grid
        self.refit = refit
        self.timeout = timeout
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.error = error
        super().__init__(classifier, quantifier)

    def fit(self, train: LabelledCollection):
        """Extend `train` and replace the quantifier with a fitted GridSearchQ.

        :param train: the training collection
        :return: self
        """
        self.e_train = self.extend(train)
        # NOTE(review): t_train is unused and GridSearchQ is fitted on the full
        # e_train, so t_val overlaps the model-selection samples — confirm intended
        t_train, t_val = self.e_train.split_stratified(0.6, random_state=0)
        self.quantifier = GridSearchQ(
            deepcopy(self.quantifier),
            param_grid=self.param_grid,
            protocol=UPP(t_val, repeats=100),
            error=self.error,
            refit=self.refit,
            timeout=self.timeout,
            n_jobs=self.n_jobs,
            verbose=self.verbose,
        ).fit(self.e_train)
        return self

    def estimate(self, instances, ext=False) -> np.ndarray:
        """Quantify `instances` (extending them unless `ext`), re-aligned to training classes."""
        e_inst = instances if ext else self._extend_instances(instances)
        estim_prev = self.quantifier.quantify(e_inst)
        # classes_ must come from the selected model inside GridSearchQ
        return self._check_prevalence_classes(estim_prev, self.quantifier.best_model().classes_)
class BQAEgsq(BinaryQuantifierAccuracyEstimator):
    """BinaryQuantifierAccuracyEstimator with per-split GridSearchQ tuning.

    Each per-predicted-class quantifier is independently optimized with
    quapy's GridSearchQ; estimation is inherited from the parent class.
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseAccuracyEstimator,  # NOTE(review): annotation looks wrong — this receives a quantifier; confirm and import BaseQuantifier
        param_grid: dict,
        error: Union[Callable, str] = qp.error.mae,
        refit=True,
        timeout=-1,
        n_jobs=None,
        verbose=False,
    ):
        """
        :param classifier: probabilistic classifier wrapped by the estimator
        :param quantifier: prototype quantifier, grid-searched once per split
        :param param_grid: hyper-parameter grid forwarded to GridSearchQ
        :param error: error metric minimized by GridSearchQ
        :param refit: forwarded to GridSearchQ
        :param timeout: forwarded to GridSearchQ (-1 disables it)
        :param n_jobs: forwarded to GridSearchQ
        :param verbose: forwarded to GridSearchQ
        """
        self.param_grid = param_grid
        self.refit = refit
        self.timeout = timeout
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.error = error
        super().__init__(classifier=classifier, quantifier=quantifier)

    def fit(self, train: LabelledCollection):
        """Extend `train`, split by predicted class, grid-search one quantifier per split.

        :param train: the training collection
        :return: self
        """
        self.e_train = self.extend(train)
        self.n_classes = self.e_train.n_classes
        self.e_trains = self.e_train.split_by_pred()

        self.quantifiers = []
        # unlike MCAEgsq, each grid search here is fitted on t_train only,
        # with model selection over UPP samples from the held-out t_val
        for e_train in self.e_trains:
            t_train, t_val = e_train.split_stratified(0.6, random_state=0)
            quantifier = GridSearchQ(
                model=deepcopy(self.quantifier),
                param_grid=self.param_grid,
                protocol=UPP(t_val, repeats=100),
                error=self.error,
                refit=self.refit,
                timeout=self.timeout,
                n_jobs=self.n_jobs,
                verbose=self.verbose,
            ).fit(t_train)
            self.quantifiers.append(quantifier)

        return self

View File

@ -1,138 +0,0 @@
import numpy as np
import scipy as sp
import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.aggregative import SLD
from quapy.protocol import APP, AbstractStochasticSeededProtocol
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from .data import get_dataset
# Extended classes
#
# 0 ~ True 0
# 1 ~ False 1
# 2 ~ False 0
# 3 ~ True 1
# _____________________
# | | |
# | True 0 | False 1 |
# |__________|__________|
# | | |
# | False 0 | True 1 |
# |__________|__________|
#
def get_ex_class(classes, true_class, pred_class):
    """Map a (true, predicted) label pair onto its extended-class index.

    The extended label space enumerates all (true, predicted) combinations
    row-major: index = true_class * classes + pred_class.

    :param classes: number of original classes
    :param true_class: the true label
    :param pred_class: the predicted label
    :return: the extended class index
    """
    return pred_class + (classes * true_class)
def extend_collection(coll, pred_prob):
    """Extend a LabelledCollection with the classifier's predicted probabilities.

    The features become [X | pred_prob] and every label becomes the extended
    class index combining its true class with the argmax predicted class
    (see :func:`get_ex_class`).

    :param coll: the collection to extend
    :param pred_prob: array of shape (n_instances, n_classes) with posteriors
    :return: a LabelledCollection over n_classes**2 extended classes
    :raises ValueError: if coll.X is neither a csr_matrix nor an ndarray
    """
    n_classes = coll.n_classes

    # n_X = [ X | predicted probs. ]
    if isinstance(coll.X, sp.csr_matrix):
        pred_prob_csr = sp.csr_matrix(pred_prob)
        n_x = sp.hstack([coll.X, pred_prob_csr])
    elif isinstance(coll.X, np.ndarray):
        n_x = np.concatenate((coll.X, pred_prob), axis=1)
    else:
        raise ValueError("Unsupported matrix format")

    # n_y[i] = extended class of (true y[i], argmax-predicted class);
    # vectorized argmax replaces the original per-row loop
    pred_classes = pred_prob.argmax(axis=1)
    n_y = [
        get_ex_class(n_classes, true_class, pred_class)
        for true_class, pred_class in zip(coll.y, pred_classes)
    ]

    return LabelledCollection(n_x, np.asarray(n_y), [*range(0, n_classes * n_classes)])
def qf1e_binary(prev):
    """Return the F1 *error* (1 - F1) of class 0 from an extended prevalence.

    `prev` follows the extended-class layout documented above:
    prev[0] = true 0, prev[1] = false 1 (FN for class 0),
    prev[2] = false 0 (FP for class 0), prev[3] = true 1.

    :param prev: extended prevalence vector of length 4
    :return: 1 - F1 for class 0; 1.0 when there are no true positives
        (the F1 → 0 limit, which the bare formula would turn into a
        division by zero / NaN)
    """
    if prev[0] == 0:
        # no true positives: precision and recall are 0 (or undefined), F1 = 0
        return 1.0
    recall = prev[0] / (prev[0] + prev[1])
    precision = prev[0] / (prev[0] + prev[2])
    return 1 - 2 * (precision * recall) / (precision + recall)
def compute_errors(true_prev, estim_prev, n_instances):
    """Compute a battery of quantification error metrics for one sample.

    :param true_prev: true (extended) prevalence vector
    :param estim_prev: estimated (extended) prevalence vector
    :param n_instances: sample size, used for the smoothing factor
        eps = 1 / (2 * n_instances) of the relative-error metrics
    :return: dict mapping metric name to its value
    """
    _eps = 1 / (2 * n_instances)
    # the original code initialized `errors = {}` twice; the dead first
    # assignment has been removed and the dict is returned directly
    return {
        "mae": qp.error.mae(true_prev, estim_prev),
        "rae": qp.error.rae(true_prev, estim_prev, eps=_eps),
        "mrae": qp.error.mrae(true_prev, estim_prev, eps=_eps),
        "kld": qp.error.kld(true_prev, estim_prev, eps=_eps),
        "nkld": qp.error.nkld(true_prev, estim_prev, eps=_eps),
        "true_f1e": qf1e_binary(true_prev),
        "estim_f1e": qf1e_binary(estim_prev),
    }
def extend_and_quantify(
    model,
    q_model,
    train,
    test: LabelledCollection | AbstractStochasticSeededProtocol,
):
    """Fit classifier and quantifier on the extended train set, quantify `test`.

    :param model: probabilistic classifier (sklearn-style)
    :param q_model: quantifier fitted on the extended training collection
    :param train: training LabelledCollection
    :param test: either a single LabelledCollection or a protocol producing samples
    :return: tuple (orig_prevs, true_prevs, estim_prevs, errors), each a list
        with one entry per evaluated sample (a single entry if `test` is a
        LabelledCollection)
    """
    model.fit(*train.Xy)

    # out-of-fold posteriors keep the extended training labels honest
    pred_prob_train = cross_val_predict(model, *train.Xy, method="predict_proba")
    _train = extend_collection(train, pred_prob_train)
    q_model.fit(_train)

    def quantify_extended(test):
        # extend the sample with the (now fitted) classifier's posteriors
        pred_prob_test = model.predict_proba(test.X)
        _test = extend_collection(test, pred_prob_test)
        _estim_prev = q_model.quantify(_test.instances)
        # check that _estim_prev has all the classes and eventually fill the missing
        # ones with 0
        for _cls in _test.classes_:
            if _cls not in q_model.classes_:
                _estim_prev = np.insert(_estim_prev, _cls, [0.0], axis=0)

        print(_estim_prev)
        return _test.prevalence(), _estim_prev

    if isinstance(test, LabelledCollection):
        # single sample: wrap results in one-element lists for a uniform API
        _true_prev, _estim_prev = quantify_extended(test)
        _errors = compute_errors(_true_prev, _estim_prev, test.X.shape[0])
        return ([test.prevalence()], [_true_prev], [_estim_prev], [_errors])

    elif isinstance(test, AbstractStochasticSeededProtocol):
        orig_prevs, true_prevs, estim_prevs, errors = [], [], [], []
        for index in test.samples_parameters():
            sample = test.sample(index)
            _true_prev, _estim_prev = quantify_extended(sample)

            orig_prevs.append(sample.prevalence())
            true_prevs.append(_true_prev)
            estim_prevs.append(_estim_prev)
            errors.append(compute_errors(_true_prev, _estim_prev, sample.X.shape[0]))

        return orig_prevs, true_prevs, estim_prevs, errors
def test_1(dataset_name):
    """Run the extend-and-quantify pipeline on `dataset_name` and print results.

    :param dataset_name: name resolved by get_dataset into (train, test)
    """
    train, test = get_dataset(dataset_name)

    orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify(
        LogisticRegression(),
        SLD(LogisticRegression()),
        train,
        APP(test, n_prevalences=11, repeats=1),
    )

    # one report block per evaluated sample
    for orig_prev, true_prev, estim_prev, _errors in zip(
        orig_prevs, true_prevs, estim_prevs, errors
    ):
        print(f"original prevalence:\t{orig_prev}")
        print(f"true prevalence:\t{true_prev}")
        print(f"estimated prevalence:\t{estim_prev}")
        for name, err in _errors.items():
            print(f"{name}={err:.3f}")
        print()

View File

@ -27,15 +27,15 @@ def plot_delta(
metric="acc",
name="default",
train_prev=None,
fit_scores=None,
legend=True,
avg=None,
) -> Path:
_base_title = "delta_stdev" if stdevs is not None else "delta"
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
else:
title = f"{_base_title}_{name}_avg_{metric}"
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
fig, ax = plt.subplots()
ax.set_aspect("auto")
@ -74,16 +74,13 @@ def plot_delta(
color=_cy["color"],
alpha=0.25,
)
if fit_scores is not None and method in fit_scores:
ax.plot(
base_prevs,
np.repeat(fit_scores[method], base_prevs.shape[0]),
color=_cy["color"],
linestyle="--",
markersize=0,
)
ax.set(xlabel="test prevalence", ylabel=metric, title=title)
x_label = "test" if avg is None or avg == "train" else "train"
ax.set(
xlabel=f"{x_label} prevalence",
ylabel=metric,
title=title,
)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
@ -182,11 +179,11 @@ def plot_shift(
columns,
data,
*,
counts=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
fit_scores=None,
legend=True,
) -> Path:
if train_prev is not None:
@ -217,15 +214,20 @@ def plot_shift(
markersize=3,
zorder=2,
)
if fit_scores is not None and method in fit_scores:
ax.plot(
shift_prevs,
np.repeat(fit_scores[method], shift_prevs.shape[0]),
color=_cy["color"],
linestyle="--",
markersize=0,
)
if counts is not None:
_col_idx = np.where(columns == method)[0]
count = counts[_col_idx].flatten()
for prev, shift, cnt in zip(shift_prevs, shifts, count):
label = f"{cnt}"
plt.annotate(
label,
(prev, shift),
textcoords="offset points",
xytext=(0, 10),
ha="center",
color=_cy["color"],
fontsize=12.0,
)
ax.set(xlabel="dataset shift", ylabel=metric, title=title)

2102
test_mc.md Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,20 +0,0 @@
from sklearn.linear_model import LogisticRegression
from quacc.evaluation.baseline import kfcv, trust_score
from quacc.dataset import get_spambase
class TestBaseline:
    """Smoke tests for the baseline evaluation helpers."""

    def test_kfcv(self):
        """kfcv on spambase should report an f1_score entry."""
        train, validation, _ = get_spambase()
        c_model = LogisticRegression()
        c_model.fit(train.X, train.y)
        assert "f1_score" in kfcv(c_model, validation)

    def test_trust_score(self):
        """trust_score should yield one score per test instance."""
        train, validation, test = get_spambase()
        c_model = LogisticRegression()
        c_model.fit(train.X, train.y)
        trustscore = trust_score(c_model, train, test)
        assert len(trustscore) == len(test.y)

View File

@ -0,0 +1,12 @@
from sklearn.linear_model import LogisticRegression
from quacc.dataset import Dataset
from quacc.evaluation.baseline import kfcv
class TestBaseline:
    """Smoke tests for the baseline evaluation helpers (Dataset-based API)."""

    def test_kfcv(self):
        """kfcv on spambase should report an f1_score entry."""
        spambase = Dataset("spambase", n_prevalences=1).get_raw()
        c_model = LogisticRegression()
        c_model.fit(spambase.train.X, spambase.train.y)
        assert "f1_score" in kfcv(c_model, spambase.validation)

View File

@ -1,12 +1,12 @@
import pytest
import numpy as np
import pytest
import scipy.sparse as sp
from sklearn.linear_model import LogisticRegression
from quacc.estimator import BinaryQuantifierAccuracyEstimator
from quacc.method.base import BinaryQuantifierAccuracyEstimator
class TestBinaryQuantifierAccuracyEstimator:
class TestBQAE:
@pytest.mark.parametrize(
"instances,preds0,preds1,result",
[

View File

@ -0,0 +1,2 @@
class TestMCAE:
    """Placeholder for MultiClassAccuracyEstimator tests.

    TODO: add coverage mirroring TestBQAE.
    """

    pass