unifying n_bins in PQ and DMy
This commit is contained in:
parent
db49cd31be
commit
6db659e3c4
|
|
@ -624,7 +624,7 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
|||
classifier: BaseEstimator=None,
|
||||
fit_classifier=True,
|
||||
val_split: int = 5,
|
||||
n_bins: int = 4,
|
||||
nbins: int = 4,
|
||||
fixed_bins: bool = False,
|
||||
num_warmup: int = 500,
|
||||
num_samples: int = 1_000,
|
||||
|
|
@ -642,7 +642,7 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
|||
"Run `$ pip install quapy[bayes]` to install them.")
|
||||
|
||||
super().__init__(classifier, fit_classifier, val_split)
|
||||
self.n_bins = n_bins
|
||||
self.nbins = nbins
|
||||
self.fixed_bins = fixed_bins
|
||||
self.num_warmup = num_warmup
|
||||
self.num_samples = num_samples
|
||||
|
|
@ -657,10 +657,10 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
|||
# Compute bin limits
|
||||
if self.fixed_bins:
|
||||
# Uniform bins in [0,1]
|
||||
self.bin_limits = np.linspace(0, 1, self.n_bins + 1)
|
||||
self.bin_limits = np.linspace(0, 1, self.nbins + 1)
|
||||
else:
|
||||
# Quantile bins
|
||||
self.bin_limits = np.quantile(y_pred, np.linspace(0, 1, self.n_bins + 1))
|
||||
self.bin_limits = np.quantile(y_pred, np.linspace(0, 1, self.nbins + 1))
|
||||
|
||||
# Assign each prediction to a bin
|
||||
bin_indices = np.digitize(y_pred, self.bin_limits[1:-1], right=True)
|
||||
|
|
@ -670,14 +670,14 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
|||
neg_mask = ~pos_mask
|
||||
|
||||
# Count positives and negatives per bin
|
||||
self.pos_hist = np.bincount(bin_indices[pos_mask], minlength=self.n_bins)
|
||||
self.neg_hist = np.bincount(bin_indices[neg_mask], minlength=self.n_bins)
|
||||
self.pos_hist = np.bincount(bin_indices[pos_mask], minlength=self.nbins)
|
||||
self.neg_hist = np.bincount(bin_indices[neg_mask], minlength=self.nbins)
|
||||
|
||||
def aggregate(self, classif_predictions):
|
||||
Px_test = classif_predictions[:, self.pos_label]
|
||||
test_hist, _ = np.histogram(Px_test, bins=self.bin_limits)
|
||||
prevs = _bayesian.pq_stan(
|
||||
self.stan_code, self.n_bins, self.pos_hist, self.neg_hist, test_hist,
|
||||
self.stan_code, self.nbins, self.pos_hist, self.neg_hist, test_hist,
|
||||
self.num_samples, self.num_warmup, self.stan_seed
|
||||
).flatten()
|
||||
self.prev_distribution = np.vstack([1-prevs, prevs]).T
|
||||
|
|
|
|||
Loading…
Reference in New Issue