unifying n_bins in PQ and DMy
This commit is contained in:
parent
db49cd31be
commit
6db659e3c4
|
|
@ -624,7 +624,7 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
classifier: BaseEstimator=None,
|
classifier: BaseEstimator=None,
|
||||||
fit_classifier=True,
|
fit_classifier=True,
|
||||||
val_split: int = 5,
|
val_split: int = 5,
|
||||||
n_bins: int = 4,
|
nbins: int = 4,
|
||||||
fixed_bins: bool = False,
|
fixed_bins: bool = False,
|
||||||
num_warmup: int = 500,
|
num_warmup: int = 500,
|
||||||
num_samples: int = 1_000,
|
num_samples: int = 1_000,
|
||||||
|
|
@ -642,7 +642,7 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
"Run `$ pip install quapy[bayes]` to install them.")
|
"Run `$ pip install quapy[bayes]` to install them.")
|
||||||
|
|
||||||
super().__init__(classifier, fit_classifier, val_split)
|
super().__init__(classifier, fit_classifier, val_split)
|
||||||
self.n_bins = n_bins
|
self.nbins = nbins
|
||||||
self.fixed_bins = fixed_bins
|
self.fixed_bins = fixed_bins
|
||||||
self.num_warmup = num_warmup
|
self.num_warmup = num_warmup
|
||||||
self.num_samples = num_samples
|
self.num_samples = num_samples
|
||||||
|
|
@ -657,10 +657,10 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
# Compute bin limits
|
# Compute bin limits
|
||||||
if self.fixed_bins:
|
if self.fixed_bins:
|
||||||
# Uniform bins in [0,1]
|
# Uniform bins in [0,1]
|
||||||
self.bin_limits = np.linspace(0, 1, self.n_bins + 1)
|
self.bin_limits = np.linspace(0, 1, self.nbins + 1)
|
||||||
else:
|
else:
|
||||||
# Quantile bins
|
# Quantile bins
|
||||||
self.bin_limits = np.quantile(y_pred, np.linspace(0, 1, self.n_bins + 1))
|
self.bin_limits = np.quantile(y_pred, np.linspace(0, 1, self.nbins + 1))
|
||||||
|
|
||||||
# Assign each prediction to a bin
|
# Assign each prediction to a bin
|
||||||
bin_indices = np.digitize(y_pred, self.bin_limits[1:-1], right=True)
|
bin_indices = np.digitize(y_pred, self.bin_limits[1:-1], right=True)
|
||||||
|
|
@ -670,14 +670,14 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
neg_mask = ~pos_mask
|
neg_mask = ~pos_mask
|
||||||
|
|
||||||
# Count positives and negatives per bin
|
# Count positives and negatives per bin
|
||||||
self.pos_hist = np.bincount(bin_indices[pos_mask], minlength=self.n_bins)
|
self.pos_hist = np.bincount(bin_indices[pos_mask], minlength=self.nbins)
|
||||||
self.neg_hist = np.bincount(bin_indices[neg_mask], minlength=self.n_bins)
|
self.neg_hist = np.bincount(bin_indices[neg_mask], minlength=self.nbins)
|
||||||
|
|
||||||
def aggregate(self, classif_predictions):
|
def aggregate(self, classif_predictions):
|
||||||
Px_test = classif_predictions[:, self.pos_label]
|
Px_test = classif_predictions[:, self.pos_label]
|
||||||
test_hist, _ = np.histogram(Px_test, bins=self.bin_limits)
|
test_hist, _ = np.histogram(Px_test, bins=self.bin_limits)
|
||||||
prevs = _bayesian.pq_stan(
|
prevs = _bayesian.pq_stan(
|
||||||
self.stan_code, self.n_bins, self.pos_hist, self.neg_hist, test_hist,
|
self.stan_code, self.nbins, self.pos_hist, self.neg_hist, test_hist,
|
||||||
self.num_samples, self.num_warmup, self.stan_seed
|
self.num_samples, self.num_warmup, self.stan_seed
|
||||||
).flatten()
|
).flatten()
|
||||||
self.prev_distribution = np.vstack([1-prevs, prevs]).T
|
self.prev_distribution = np.vstack([1-prevs, prevs]).T
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue