From dca955819c7a56895c95dcf33a1b805705449920 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Tue, 17 Oct 2023 18:11:25 +0200 Subject: [PATCH] everything working; need to clean prints though --- distribution_matching/commons.py | 2 +- .../method_kdey_closed_efficient_correct.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/distribution_matching/commons.py b/distribution_matching/commons.py index 5a8ee1d..2d9ba68 100644 --- a/distribution_matching/commons.py +++ b/distribution_matching/commons.py @@ -8,7 +8,7 @@ from distribution_matching.method_dirichlety import DIRy from sklearn.linear_model import LogisticRegression from method_kdey_closed_efficient import KDEyclosed_efficient -METHODS = ['KDEy-closed++', 'KDEy-closed+', 'KDEy-closed', 'ACC', 'PACC', 'HDy-OvA', 'DIR', 'DM', 'KDEy-DMhd3', 'EMQ', 'KDEy-ML'] #, 'KDEy-DMhd2'] #, 'KDEy-DMhd2', 'DM-HD'] 'KDEy-DMjs', 'KDEy-DM', 'KDEy-ML+', 'KDEy-DMhd3+', +METHODS = ['ACC', 'PACC', 'HDy-OvA', 'DIR', 'DM', 'KDEy-DMhd3', 'KDEy-closed++', 'EMQ', 'KDEy-ML'] #, 'KDEy-DMhd2'] #, 'KDEy-DMhd2', 'DM-HD'] 'KDEy-DMjs', 'KDEy-DM', 'KDEy-ML+', 'KDEy-DMhd3+', BIN_METHODS = [x.replace('-OvA', '') for x in METHODS] diff --git a/distribution_matching/method_kdey_closed_efficient_correct.py b/distribution_matching/method_kdey_closed_efficient_correct.py index 6dbe886..bb14ee7 100644 --- a/distribution_matching/method_kdey_closed_efficient_correct.py +++ b/distribution_matching/method_kdey_closed_efficient_correct.py @@ -61,6 +61,8 @@ class KDEyclosed_efficient_corr(AggregativeProbabilisticQuantifier): data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs ) + print('training over') + assert all(sorted(np.unique(y)) == np.arange(data.n_classes)), \ 'label name gaps not allowed in current implementation' @@ -94,11 +96,14 @@ class KDEyclosed_efficient_corr(AggregativeProbabilisticQuantifier): self.tr_tr_sums = tr_tr_sums self.counts_inv = counts_inv + print('fit over') + return self def aggregate(self, posteriors: np.ndarray): + # print('aggregating') Ptr = self.Ptr Pte = posteriors @@ -121,6 +126,8 @@ class KDEyclosed_efficient_corr(AggregativeProbabilisticQuantifier): partB = 0.5 * np.log((alpha_l[:,np.newaxis] * tr_tr_sums * alpha_l).sum()) return partA + partB + partC + # print('starting search') + # the initial point is set as the uniform distribution uniform_distribution = np.full(fill_value=1 / n, shape=(n,))