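# QuaPy/BayesianKDEy/plot_simplex.py
#
# Plotting utilities for 3-class prevalence vectors on the 2-simplex.
# (listing metadata: 95 lines, 2.4 KiB, Python)
#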

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
def plot_prev_points(prevs, true_prev, *, show=True):
    """Scatter-plot 3-class prevalence vectors on the 2-simplex.

    Each prevalence vector ``(p0, p1, p2)`` is mapped to 2-D Cartesian
    coordinates inside an equilateral triangle whose vertices are
    ``y=0`` at (0, 0), ``y=1`` at (1, 0) and ``y=2`` at (0.5, sqrt(3)/2).

    :param prevs: array-like whose last axis has length 3 (e.g. shape
        ``(n, 3)``); drawn as large, very transparent markers so that regions
        with many vectors appear darker.
    :param true_prev: array-like whose last axis has length 3 (e.g. shape
        ``(3,)``); drawn as small opaque markers on top of ``prevs``.
    :param show: if True (default), call ``plt.show()`` before returning.
    :return: the ``(fig, ax)`` pair holding the plot.
    :raises ValueError: if either input does not have exactly 3 classes in
        its last axis.
    """
    height = np.sqrt(3) / 2  # height of the unit-side equilateral triangle

    def cartesian(p):
        # accept lists/tuples as well as arrays (p.shape would fail on a list)
        p = np.asarray(p, dtype=float)
        if p.ndim == 0 or p.shape[-1] != 3:
            raise ValueError(
                'plot_prev_points only supports 3-class prevalence vectors; '
                f'got input of shape {p.shape}'
            )
        p = p.reshape(-1, 3)
        # barycentric -> Cartesian: the class-0 vertex sits at the origin,
        # so only the class-1 and class-2 weights move the point
        x = p[:, 1] + p[:, 2] * 0.5
        y = p[:, 2] * height
        return x, y

    # simplex vertices in Cartesian coordinates
    v1 = np.array([0, 0])
    v2 = np.array([1, 0])
    v3 = np.array([0.5, height])

    x, y = cartesian(prevs)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(x, y, s=50, alpha=0.05, edgecolors='none')
    ax.scatter(*cartesian(true_prev), s=5, alpha=1)

    # triangle edges (closed polyline)
    triangle = np.array([v1, v2, v3, v1])
    ax.plot(triangle[:, 0], triangle[:, 1], color='black')

    # vertex labels
    ax.text(-0.05, -0.05, "y=0", ha='right', va='top')
    ax.text(1.05, -0.05, "y=1", ha='left', va='top')
    ax.text(0.5, height + 0.05, "y=2", ha='center', va='bottom')

    ax.set_aspect('equal')
    ax.axis('off')
    if show:
        plt.show()
    return fig, ax
def plot_prev_points_matplot(points, *, resolution=200, show_points=True, show=True):
    """Plot a Gaussian-KDE heat-map of 3-class prevalence vectors on the 2-simplex.

    The vectors are projected onto an equilateral triangle with vertices
    A=(1,0,0) at (0, 0), B=(0,1,0) at (1, 0) and C=(0,0,1) at (0.5, sqrt(3)/2).
    A 2-D ``scipy.stats.gaussian_kde`` is fitted on the projected points and
    evaluated on a regular grid over the triangle's bounding box. Grid cells
    outside the triangle are masked so they are not drawn.

    :param points: array-like of shape ``(n, 3)`` with the prevalence vectors.
    :param resolution: number of grid samples per axis used to evaluate the
        KDE (default 200).
    :param show_points: if True (default), overlay the projected points.
    :param show: if True (default), call ``plt.show()`` before returning.
    :return: the ``(fig, ax)`` pair holding the plot.
    :raises ValueError: if ``points`` is not a 2-D array with 3 columns.
    """
    points = np.asarray(points, dtype=float)
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError(
            'plot_prev_points_matplot expects an (n, 3) array of prevalence '
            f'vectors; got shape {points.shape}'
        )

    height = np.sqrt(3) / 2  # height of the unit-side equilateral triangle
    v1 = np.array([0, 0])
    v2 = np.array([1, 0])
    v3 = np.array([0.5, height])

    # project barycentric coordinates to 2-D
    x = points[:, 1] + points[:, 2] * 0.5
    y = points[:, 2] * height

    # fit the KDE on the projected points
    kde = gaussian_kde(np.vstack([x, y]))

    # evaluate the KDE on a regular grid covering the triangle's bounding box;
    # an imaginary mgrid step means "this many samples, endpoints included"
    xmin, xmax = 0, 1
    ymin, ymax = 0, height
    step = resolution * 1j
    xx, yy = np.mgrid[xmin:xmax:step, ymin:ymax:step]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    zz = np.reshape(kde(positions), xx.shape)

    # mask grid cells outside the simplex: y must lie between 0 and the
    # two slanted edges y = sqrt(3)*x and y = sqrt(3)*(1-x)
    inside = (yy >= 0) & (yy <= np.sqrt(3) * np.minimum(xx, 1 - xx))
    zz_masked = np.ma.array(zz, mask=~inside)

    fig, ax = plt.subplots(figsize=(6, 6))
    # mgrid indexes zz as [x, y]; rot90 turns it into [row = y (top row = ymax),
    # col = x], matching imshow's default origin='upper'
    ax.imshow(
        np.rot90(zz_masked),
        cmap=plt.cm.viridis,
        extent=[xmin, xmax, ymin, ymax],
        alpha=0.8,
    )

    # triangle edges (closed polyline)
    triangle = np.array([v1, v2, v3, v1])
    ax.plot(triangle[:, 0], triangle[:, 1], color='black', lw=2)

    # optional overlay of the raw projected points
    if show_points:
        ax.scatter(x, y, s=5, c='white', alpha=0.3)

    # vertex labels
    ax.text(-0.05, -0.05, "A (1,0,0)", ha='right', va='top')
    ax.text(1.05, -0.05, "B (0,1,0)", ha='left', va='top')
    ax.text(0.5, height + 0.05, "C (0,0,1)", ha='center', va='bottom')

    ax.set_aspect('equal')
    ax.axis('off')
    if show:
        plt.show()
    return fig, ax