Improve the plot, add more comments.
This commit is contained in:
parent
5cdd158fcc
commit
2db7cf20bd
|
@ -35,6 +35,7 @@ FIGURE_PATH = "bayesian_quantification.pdf"
|
|||
|
||||
@dataclass
|
||||
class SimulatedData:
|
||||
"""Auxiliary class to keep the training and test data sets."""
|
||||
n_classes: int
|
||||
X_train: np.ndarray
|
||||
Y_train: np.ndarray
|
||||
|
@ -44,13 +45,16 @@ class SimulatedData:
|
|||
|
||||
def simulate_data(rng) -> SimulatedData:
|
||||
"""Generates a simulated data set with three classes."""
|
||||
cov = np.eye(2)
|
||||
|
||||
# Number of examples of each class in both data sets
|
||||
n_train = [400, 400, 400]
|
||||
n_test = [40, 25, 15]
|
||||
|
||||
# Mean vectors and shared covariance of P(X|Y) distributions
|
||||
mus = [np.zeros(2), np.array([1, 1.5]), np.array([1.5, 1])]
|
||||
cov = np.eye(2)
|
||||
|
||||
# Generate the features accordingly
|
||||
X_train = np.concatenate([
|
||||
rng.multivariate_normal(mus[i], cov, size=n_train[i])
|
||||
for i in range(3)
|
||||
|
@ -95,6 +99,8 @@ def plot_simulated_data(axs, data: SimulatedData) -> None:
|
|||
ax.set_aspect("equal")
|
||||
ax.set_xlim(*xlim)
|
||||
ax.set_ylim(*ylim)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
ax = axs[0]
|
||||
ax.set_title("Training set")
|
||||
|
@ -110,10 +116,14 @@ def plot_simulated_data(axs, data: SimulatedData) -> None:
|
|||
ax.set_title("Test set\n(as observed)")
|
||||
ax.scatter(data.X_test[:, 0], data.X_test[:, 1], c="C5", s=3, rasterized=True)
|
||||
|
||||
|
||||
def get_random_forest() -> RandomForestClassifier:
    """Create the random forest handed to the quantifiers in this script.

    Every call builds a new, unfitted classifier with 10 trees and a fixed
    ``random_state`` of 5, so each quantifier gets its own identically
    configured forest.
    """
    forest = RandomForestClassifier(
        n_estimators=10,
        random_state=5,
    )
    return forest
|
||||
|
||||
|
||||
def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None:
|
||||
"""Fits Bayesian quantification and plots posterior mean as well as individual samples"""
|
||||
quantifier = BayesianCC(classifier=get_random_forest())
|
||||
quantifier.fit(training)
|
||||
|
||||
|
@ -129,10 +139,12 @@ def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledColle
|
|||
|
||||
|
||||
def _get_estimate(estimator_class, training: LabelledCollection, test: np.ndarray) -> np.ndarray:
    """Auxiliary method for running ACC and PACC.

    Wraps a fresh random forest in ``estimator_class``, fits the resulting
    quantifier on ``training`` and quantifies ``test``.

    Args:
        estimator_class: Callable that accepts a classifier and returns an
            object exposing ``fit`` and ``quantify`` (e.g. ``ACC``).
        training: Labelled collection passed to ``fit``.
        test: Unlabelled test instances, passed unchanged to ``quantify``.

    Returns:
        The result of ``estimator.quantify(test)``. Callers plot it against
        ``np.arange(n_classes)``, so it is presumably one estimated
        prevalence per class -- confirm against the quantifier's API.
    """
    estimator = estimator_class(get_random_forest())
    estimator.fit(training)
    return estimator.quantify(test)
|
||||
|
||||
|
||||
def train_and_plot_acc(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None:
|
||||
estimate = _get_estimate(ACC, training, test)
|
||||
ax.plot(np.arange(n_classes), estimate, c="darkblue", linewidth=2, linestyle=":", label="ACC")
|
||||
|
@ -144,6 +156,7 @@ def train_and_plot_pacc(ax: plt.Axes, training: LabelledCollection, test: np.nda
|
|||
|
||||
|
||||
def plot_true_proportions(ax: plt.Axes, test_labels: np.ndarray, n_classes: int) -> None:
|
||||
"""Plots the true proportions."""
|
||||
counts = np.bincount(test_labels, minlength=n_classes)
|
||||
proportion = counts / counts.sum()
|
||||
|
||||
|
|
Loading…
Reference in New Issue