diff --git a/examples/bayesian_quantification.py b/examples/bayesian_quantification.py index fda97cd..3bca084 100644 --- a/examples/bayesian_quantification.py +++ b/examples/bayesian_quantification.py @@ -35,6 +35,7 @@ FIGURE_PATH = "bayesian_quantification.pdf" @dataclass class SimulatedData: + """Auxiliary class to keep the training and test data sets.""" n_classes: int X_train: np.ndarray Y_train: np.ndarray @@ -44,13 +45,16 @@ class SimulatedData: def simulate_data(rng) -> SimulatedData: """Generates a simulated data set with three classes.""" - cov = np.eye(2) + # Number of examples of each class in both data sets n_train = [400, 400, 400] n_test = [40, 25, 15] + # Mean vectors and shared covariance of P(X|Y) distributions mus = [np.zeros(2), np.array([1, 1.5]), np.array([1.5, 1])] + cov = np.eye(2) + # Generate the features accordingly X_train = np.concatenate([ rng.multivariate_normal(mus[i], cov, size=n_train[i]) for i in range(3) @@ -95,6 +99,8 @@ def plot_simulated_data(axs, data: SimulatedData) -> None: ax.set_aspect("equal") ax.set_xlim(*xlim) ax.set_ylim(*ylim) + ax.set_xticks([]) + ax.set_yticks([]) ax = axs[0] ax.set_title("Training set") @@ -110,10 +116,14 @@ def plot_simulated_data(axs, data: SimulatedData) -> None: ax.set_title("Test set\n(as observed)") ax.scatter(data.X_test[:, 0], data.X_test[:, 1], c="C5", s=3, rasterized=True) + def get_random_forest() -> RandomForestClassifier: + """An auxiliary factory method to generate a random forest.""" return RandomForestClassifier(n_estimators=10, random_state=5) + def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None: + """Fits Bayesian quantification and plots posterior mean as well as individual samples""" quantifier = BayesianCC(classifier=get_random_forest()) quantifier.fit(training) @@ -129,10 +139,12 @@ def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledColle def _get_estimate(estimator_class, training: LabelledCollection, test: np.ndarray) -> None: + """Auxiliary method for running ACC and PACC.""" estimator = estimator_class(get_random_forest()) estimator.fit(training) return estimator.quantify(test) + def train_and_plot_acc(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None: estimate = _get_estimate(ACC, training, test) ax.plot(np.arange(n_classes), estimate, c="darkblue", linewidth=2, linestyle=":", label="ACC") @@ -144,6 +156,7 @@ def train_and_plot_pacc(ax: plt.Axes, training: LabelledCollection, test: np.nda def plot_true_proportions(ax: plt.Axes, test_labels: np.ndarray, n_classes: int) -> None: + """Plots the true proportions.""" counts = np.bincount(test_labels, minlength=n_classes) proportion = counts / counts.sum()