1
0
Fork 0

everything working; need to clean prints though

This commit is contained in:
Alejandro Moreo Fernandez 2023-10-17 18:11:25 +02:00
parent 7593fde2e0
commit dca955819c
2 changed files with 8 additions and 1 deletions

View File

@ -8,7 +8,7 @@ from distribution_matching.method_dirichlety import DIRy
from sklearn.linear_model import LogisticRegression
from method_kdey_closed_efficient import KDEyclosed_efficient
METHODS = ['KDEy-closed++', 'KDEy-closed+', 'KDEy-closed', 'ACC', 'PACC', 'HDy-OvA', 'DIR', 'DM', 'KDEy-DMhd3', 'EMQ', 'KDEy-ML'] #, 'KDEy-DMhd2'] #, 'KDEy-DMhd2', 'DM-HD'] 'KDEy-DMjs', 'KDEy-DM', 'KDEy-ML+', 'KDEy-DMhd3+',
METHODS = ['ACC', 'PACC', 'HDy-OvA', 'DIR', 'DM', 'KDEy-DMhd3', 'KDEy-closed++', 'EMQ', 'KDEy-ML'] #, 'KDEy-DMhd2'] #, 'KDEy-DMhd2', 'DM-HD'] 'KDEy-DMjs', 'KDEy-DM', 'KDEy-ML+', 'KDEy-DMhd3+',
BIN_METHODS = [x.replace('-OvA', '') for x in METHODS]

View File

@ -61,6 +61,8 @@ class KDEyclosed_efficient_corr(AggregativeProbabilisticQuantifier):
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
)
print('training over')
assert all(sorted(np.unique(y)) == np.arange(data.n_classes)), \
'label name gaps not allowed in current implementation'
@ -94,11 +96,14 @@ class KDEyclosed_efficient_corr(AggregativeProbabilisticQuantifier):
self.tr_tr_sums = tr_tr_sums
self.counts_inv = counts_inv
print('fit over')
return self
def aggregate(self, posteriors: np.ndarray):
# print('aggregating')
Ptr = self.Ptr
Pte = posteriors
@ -121,6 +126,8 @@ class KDEyclosed_efficient_corr(AggregativeProbabilisticQuantifier):
partB = 0.5 * np.log((alpha_l[:,np.newaxis] * tr_tr_sums * alpha_l).sum())
return partA + partB + partC
# print('starting search')
# the initial point is set as the uniform distribution
uniform_distribution = np.full(fill_value=1 / n, shape=(n,))