Improve the plot, add more comments.

This commit is contained in:
Paweł Czyż 2024-03-16 12:14:42 +01:00
parent 5cdd158fcc
commit 2db7cf20bd
1 changed files with 14 additions and 1 deletions

View File

@ -35,6 +35,7 @@ FIGURE_PATH = "bayesian_quantification.pdf"
@dataclass @dataclass
class SimulatedData: class SimulatedData:
"""Auxiliary class to keep the training and test data sets."""
n_classes: int n_classes: int
X_train: np.ndarray X_train: np.ndarray
Y_train: np.ndarray Y_train: np.ndarray
@ -44,13 +45,16 @@ class SimulatedData:
def simulate_data(rng) -> SimulatedData: def simulate_data(rng) -> SimulatedData:
"""Generates a simulated data set with three classes.""" """Generates a simulated data set with three classes."""
cov = np.eye(2)
# Number of examples of each class in both data sets
n_train = [400, 400, 400] n_train = [400, 400, 400]
n_test = [40, 25, 15] n_test = [40, 25, 15]
# Mean vectors and shared covariance of P(X|Y) distributions
mus = [np.zeros(2), np.array([1, 1.5]), np.array([1.5, 1])] mus = [np.zeros(2), np.array([1, 1.5]), np.array([1.5, 1])]
cov = np.eye(2)
# Generate the features accordingly
X_train = np.concatenate([ X_train = np.concatenate([
rng.multivariate_normal(mus[i], cov, size=n_train[i]) rng.multivariate_normal(mus[i], cov, size=n_train[i])
for i in range(3) for i in range(3)
@ -95,6 +99,8 @@ def plot_simulated_data(axs, data: SimulatedData) -> None:
ax.set_aspect("equal") ax.set_aspect("equal")
ax.set_xlim(*xlim) ax.set_xlim(*xlim)
ax.set_ylim(*ylim) ax.set_ylim(*ylim)
ax.set_xticks([])
ax.set_yticks([])
ax = axs[0] ax = axs[0]
ax.set_title("Training set") ax.set_title("Training set")
@ -110,10 +116,14 @@ def plot_simulated_data(axs, data: SimulatedData) -> None:
ax.set_title("Test set\n(as observed)") ax.set_title("Test set\n(as observed)")
ax.scatter(data.X_test[:, 0], data.X_test[:, 1], c="C5", s=3, rasterized=True) ax.scatter(data.X_test[:, 0], data.X_test[:, 1], c="C5", s=3, rasterized=True)
def get_random_forest() -> RandomForestClassifier: def get_random_forest() -> RandomForestClassifier:
"""An auxiliary factory method to generate a random forest."""
return RandomForestClassifier(n_estimators=10, random_state=5) return RandomForestClassifier(n_estimators=10, random_state=5)
def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None: def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None:
"""Fits Bayesian quantification and plots posterior mean as well as individual samples"""
quantifier = BayesianCC(classifier=get_random_forest()) quantifier = BayesianCC(classifier=get_random_forest())
quantifier.fit(training) quantifier.fit(training)
@ -129,10 +139,12 @@ def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledColle
def _get_estimate(estimator_class, training: LabelledCollection, test: np.ndarray) -> None: def _get_estimate(estimator_class, training: LabelledCollection, test: np.ndarray) -> None:
"""Auxiliary method for running ACC and PACC."""
estimator = estimator_class(get_random_forest()) estimator = estimator_class(get_random_forest())
estimator.fit(training) estimator.fit(training)
return estimator.quantify(test) return estimator.quantify(test)
def train_and_plot_acc(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None: def train_and_plot_acc(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None:
estimate = _get_estimate(ACC, training, test) estimate = _get_estimate(ACC, training, test)
ax.plot(np.arange(n_classes), estimate, c="darkblue", linewidth=2, linestyle=":", label="ACC") ax.plot(np.arange(n_classes), estimate, c="darkblue", linewidth=2, linestyle=":", label="ACC")
@ -144,6 +156,7 @@ def train_and_plot_pacc(ax: plt.Axes, training: LabelledCollection, test: np.nda
def plot_true_proportions(ax: plt.Axes, test_labels: np.ndarray, n_classes: int) -> None: def plot_true_proportions(ax: plt.Axes, test_labels: np.ndarray, n_classes: int) -> None:
"""Plots the true proportions."""
counts = np.bincount(test_labels, minlength=n_classes) counts = np.bincount(test_labels, minlength=n_classes)
proportion = counts / counts.sum() proportion = counts / counts.sum()