merged
This commit is contained in:
parent
6388d9b549
commit
047cb9e533
|
|
@ -565,10 +565,12 @@ class BayesianCC(AggregativeCrispQuantifier, WithConfidenceABC):
|
|||
return np.asarray(samples.mean(axis=0), dtype=float)
|
||||
|
||||
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
|
||||
if confidence_level is None:
|
||||
confidence_level = self.confidence_level
|
||||
classif_predictions = self.classify(instances)
|
||||
point_estimate = self.aggregate(classif_predictions)
|
||||
samples = self.get_prevalence_samples() # available after calling "aggregate" function
|
||||
region = WithConfidenceABC.construct_region(samples, confidence_level=self.confidence_level, method=self.region)
|
||||
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
|
||||
return point_estimate, region
|
||||
|
||||
|
||||
|
|
@ -606,6 +608,7 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
|||
num_warmup: int = 500,
|
||||
num_samples: int = 1_000,
|
||||
stan_seed: int = 0,
|
||||
confidence_level: float = 0.95,
|
||||
region: str = 'intervals'):
|
||||
|
||||
if num_warmup <= 0:
|
||||
|
|
@ -622,9 +625,10 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
|||
self.fixed_bins = fixed_bins
|
||||
self.num_warmup = num_warmup
|
||||
self.num_samples = num_samples
|
||||
self.region = region
|
||||
self.stan_seed = stan_seed
|
||||
self.stan_code = _bayesian.load_stan_file()
|
||||
self.confidence_level = confidence_level
|
||||
self.region = region
|
||||
|
||||
def aggregation_fit(self, classif_predictions, labels):
|
||||
y_pred = classif_predictions[:, self.pos_label]
|
||||
|
|
@ -651,16 +655,23 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
|||
def aggregate(self, classif_predictions):
|
||||
Px_test = classif_predictions[:, self.pos_label]
|
||||
test_hist, _ = np.histogram(Px_test, bins=self.bin_limits)
|
||||
self.prev_distribution = _bayesian.pq_stan(
|
||||
prevs = _bayesian.pq_stan(
|
||||
self.stan_code, self.n_bins, self.pos_hist, self.neg_hist, test_hist,
|
||||
self.num_samples, self.num_warmup, self.stan_seed
|
||||
)
|
||||
return F.as_binary_prevalence(self.prev_distribution.mean())
|
||||
).flatten()
|
||||
self.prev_distribution = np.vstack([1-prevs, prevs]).T
|
||||
return self.prev_distribution.mean(axis=0)
|
||||
|
||||
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
|
||||
classif_predictions = self.classify(instances)
|
||||
point_estimate = self.aggregate(classif_predictions)
|
||||
def aggregate_conf(self, predictions, confidence_level=None):
|
||||
if confidence_level is None:
|
||||
confidence_level = self.confidence_level
|
||||
point_estimate = self.aggregate(predictions)
|
||||
samples = self.prev_distribution
|
||||
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
|
||||
return point_estimate, region
|
||||
|
||||
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
|
||||
predictions = self.classify(instances)
|
||||
return self.aggregate_conf(predictions, confidence_level=confidence_level)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue