merging from devel

This commit is contained in:
Alejandro Moreo Fernandez 2025-11-14 12:07:32 +01:00
parent 400edfdb63
commit 7ad5311fac
1 changed file with 9 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
import quapy as qp import quapy as qp
from BayesianKDEy.plot_simplex import plot_prev_points, plot_prev_points_matplot from BayesianKDEy.plot_simplex import plot_prev_points, plot_prev_points_matplot
from method.confidence import ConfidenceIntervals
from quapy.functional import strprev from quapy.functional import strprev
from quapy.method.aggregative import KDEyML from quapy.method.aggregative import KDEyML
from quapy.protocol import UPP from quapy.protocol import UPP
@@ -80,14 +81,14 @@ if __name__ == '__main__':
cls = LogisticRegression() cls = LogisticRegression()
kdey = KDEyML(cls) kdey = KDEyML(cls)
train, test = qp.datasets.fetch_UCIMulticlassDataset('waveform-v1', standardize=True).train_test train, test = qp.datasets.fetch_UCIMulticlassDataset('dry-bean', standardize=True).train_test
# train, test = qp.datasets.fetch_UCIMulticlassDataset('phishing', standardize=True).train_test
with qp.util.temp_seed(2): with qp.util.temp_seed(2):
print('fitting KDEy') print('fitting KDEy')
kdey.fit(*train.Xy) kdey.fit(*train.Xy)
shifted = test.sampling(500, *[0.7, 0.1, 0.2]) # shifted = test.sampling(500, *[0.7, 0.1, 0.2])
shifted = test.sampling(500, *test.prevalence()[::-1])
prev_hat = kdey.predict(shifted.X) prev_hat = kdey.predict(shifted.X)
mae = qp.error.mae(shifted.prevalence(), prev_hat) mae = qp.error.mae(shifted.prevalence(), prev_hat)
print(f'true_prev={strprev(shifted.prevalence())}, prev_hat={strprev(prev_hat)}, {mae=:.4f}') print(f'true_prev={strprev(shifted.prevalence())}, prev_hat={strprev(prev_hat)}, {mae=:.4f}')
@@ -97,7 +98,10 @@ if __name__ == '__main__':
samples = bayesian(kdes, shifted.X, h, init=None, MAX_ITER=5_000, warmup=1_000) samples = bayesian(kdes, shifted.X, h, init=None, MAX_ITER=5_000, warmup=1_000)
print(f'mean posterior {strprev(samples.mean(axis=0))}') print(f'mean posterior {strprev(samples.mean(axis=0))}')
conf_interval = ConfidenceIntervals(samples, confidence_level=0.95)
print()
if train.n_classes == 3:
plot_prev_points(samples, true_prev=shifted.prevalence(), point_estim=prev_hat, train_prev=train.prevalence()) plot_prev_points(samples, true_prev=shifted.prevalence(), point_estim=prev_hat, train_prev=train.prevalence())
# plot_prev_points_matplot(samples) # plot_prev_points_matplot(samples)