From 7ad5311fac7fe0ec529b2eef0e1e17899749f5a4 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Fri, 14 Nov 2025 12:07:32 +0100 Subject: [PATCH] merging from devel --- BayesianKDEy/bayesian_kdey.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/BayesianKDEy/bayesian_kdey.py b/BayesianKDEy/bayesian_kdey.py index 419b974..11aba80 100644 --- a/BayesianKDEy/bayesian_kdey.py +++ b/BayesianKDEy/bayesian_kdey.py @@ -1,6 +1,7 @@ from sklearn.linear_model import LogisticRegression import quapy as qp from BayesianKDEy.plot_simplex import plot_prev_points, plot_prev_points_matplot +from method.confidence import ConfidenceIntervals from quapy.functional import strprev from quapy.method.aggregative import KDEyML from quapy.protocol import UPP @@ -80,14 +81,14 @@ if __name__ == '__main__': cls = LogisticRegression() kdey = KDEyML(cls) - train, test = qp.datasets.fetch_UCIMulticlassDataset('waveform-v1', standardize=True).train_test - # train, test = qp.datasets.fetch_UCIMulticlassDataset('phishing', standardize=True).train_test + train, test = qp.datasets.fetch_UCIMulticlassDataset('dry-bean', standardize=True).train_test with qp.util.temp_seed(2): print('fitting KDEy') kdey.fit(*train.Xy) - shifted = test.sampling(500, *[0.7, 0.1, 0.2]) + # shifted = test.sampling(500, *[0.7, 0.1, 0.2]) + shifted = test.sampling(500, *test.prevalence()[::-1]) prev_hat = kdey.predict(shifted.X) mae = qp.error.mae(shifted.prevalence(), prev_hat) print(f'true_prev={strprev(shifted.prevalence())}, prev_hat={strprev(prev_hat)}, {mae=:.4f}') @@ -97,9 +98,12 @@ if __name__ == '__main__': samples = bayesian(kdes, shifted.X, h, init=None, MAX_ITER=5_000, warmup=1_000) print(f'mean posterior {strprev(samples.mean(axis=0))}') + conf_interval = ConfidenceIntervals(samples, confidence_level=0.95) + print() - plot_prev_points(samples, true_prev=shifted.prevalence(), point_estim=prev_hat, train_prev=train.prevalence()) - # plot_prev_points_matplot(samples) + if train.n_classes == 3: + plot_prev_points(samples, true_prev=shifted.prevalence(), point_estim=prev_hat, train_prev=train.prevalence()) + # plot_prev_points_matplot(samples) # report = qp.evaluation.evaluation_report(kdey, protocol=UPP(test), verbose=True)