Improve the plot, add more comments.
This commit is contained in:
parent
5cdd158fcc
commit
2db7cf20bd
|
@ -35,6 +35,7 @@ FIGURE_PATH = "bayesian_quantification.pdf"
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SimulatedData:
|
class SimulatedData:
|
||||||
|
"""Auxiliary class to keep the training and test data sets."""
|
||||||
n_classes: int
|
n_classes: int
|
||||||
X_train: np.ndarray
|
X_train: np.ndarray
|
||||||
Y_train: np.ndarray
|
Y_train: np.ndarray
|
||||||
|
@ -44,13 +45,16 @@ class SimulatedData:
|
||||||
|
|
||||||
def simulate_data(rng) -> SimulatedData:
|
def simulate_data(rng) -> SimulatedData:
|
||||||
"""Generates a simulated data set with three classes."""
|
"""Generates a simulated data set with three classes."""
|
||||||
cov = np.eye(2)
|
|
||||||
|
|
||||||
|
# Number of examples of each class in both data sets
|
||||||
n_train = [400, 400, 400]
|
n_train = [400, 400, 400]
|
||||||
n_test = [40, 25, 15]
|
n_test = [40, 25, 15]
|
||||||
|
|
||||||
|
# Mean vectors and shared covariance of P(X|Y) distributions
|
||||||
mus = [np.zeros(2), np.array([1, 1.5]), np.array([1.5, 1])]
|
mus = [np.zeros(2), np.array([1, 1.5]), np.array([1.5, 1])]
|
||||||
|
cov = np.eye(2)
|
||||||
|
|
||||||
|
# Generate the features accordingly
|
||||||
X_train = np.concatenate([
|
X_train = np.concatenate([
|
||||||
rng.multivariate_normal(mus[i], cov, size=n_train[i])
|
rng.multivariate_normal(mus[i], cov, size=n_train[i])
|
||||||
for i in range(3)
|
for i in range(3)
|
||||||
|
@ -95,6 +99,8 @@ def plot_simulated_data(axs, data: SimulatedData) -> None:
|
||||||
ax.set_aspect("equal")
|
ax.set_aspect("equal")
|
||||||
ax.set_xlim(*xlim)
|
ax.set_xlim(*xlim)
|
||||||
ax.set_ylim(*ylim)
|
ax.set_ylim(*ylim)
|
||||||
|
ax.set_xticks([])
|
||||||
|
ax.set_yticks([])
|
||||||
|
|
||||||
ax = axs[0]
|
ax = axs[0]
|
||||||
ax.set_title("Training set")
|
ax.set_title("Training set")
|
||||||
|
@ -110,10 +116,14 @@ def plot_simulated_data(axs, data: SimulatedData) -> None:
|
||||||
ax.set_title("Test set\n(as observed)")
|
ax.set_title("Test set\n(as observed)")
|
||||||
ax.scatter(data.X_test[:, 0], data.X_test[:, 1], c="C5", s=3, rasterized=True)
|
ax.scatter(data.X_test[:, 0], data.X_test[:, 1], c="C5", s=3, rasterized=True)
|
||||||
|
|
||||||
|
|
||||||
def get_random_forest() -> RandomForestClassifier:
|
def get_random_forest() -> RandomForestClassifier:
|
||||||
|
"""An auxiliary factory method to generate a random forest."""
|
||||||
return RandomForestClassifier(n_estimators=10, random_state=5)
|
return RandomForestClassifier(n_estimators=10, random_state=5)
|
||||||
|
|
||||||
|
|
||||||
def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None:
|
def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None:
|
||||||
|
"""Fits Bayesian quantification and plots posterior mean as well as individual samples"""
|
||||||
quantifier = BayesianCC(classifier=get_random_forest())
|
quantifier = BayesianCC(classifier=get_random_forest())
|
||||||
quantifier.fit(training)
|
quantifier.fit(training)
|
||||||
|
|
||||||
|
@ -129,10 +139,12 @@ def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledColle
|
||||||
|
|
||||||
|
|
||||||
def _get_estimate(estimator_class, training: LabelledCollection, test: np.ndarray) -> None:
|
def _get_estimate(estimator_class, training: LabelledCollection, test: np.ndarray) -> None:
|
||||||
|
"""Auxiliary method for running ACC and PACC."""
|
||||||
estimator = estimator_class(get_random_forest())
|
estimator = estimator_class(get_random_forest())
|
||||||
estimator.fit(training)
|
estimator.fit(training)
|
||||||
return estimator.quantify(test)
|
return estimator.quantify(test)
|
||||||
|
|
||||||
|
|
||||||
def train_and_plot_acc(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None:
|
def train_and_plot_acc(ax: plt.Axes, training: LabelledCollection, test: np.ndarray, n_classes: int) -> None:
|
||||||
estimate = _get_estimate(ACC, training, test)
|
estimate = _get_estimate(ACC, training, test)
|
||||||
ax.plot(np.arange(n_classes), estimate, c="darkblue", linewidth=2, linestyle=":", label="ACC")
|
ax.plot(np.arange(n_classes), estimate, c="darkblue", linewidth=2, linestyle=":", label="ACC")
|
||||||
|
@ -144,6 +156,7 @@ def train_and_plot_pacc(ax: plt.Axes, training: LabelledCollection, test: np.nda
|
||||||
|
|
||||||
|
|
||||||
def plot_true_proportions(ax: plt.Axes, test_labels: np.ndarray, n_classes: int) -> None:
|
def plot_true_proportions(ax: plt.Axes, test_labels: np.ndarray, n_classes: int) -> None:
|
||||||
|
"""Plots the true proportions."""
|
||||||
counts = np.bincount(test_labels, minlength=n_classes)
|
counts = np.bincount(test_labels, minlength=n_classes)
|
||||||
proportion = counts / counts.sum()
|
proportion = counts / counts.sum()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue