# Listing metadata from the original paste: 113 lines, 3.3 KiB, Python.
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from scipy.stats import gaussian_kde
|
|
|
|
|
|
def plot_prev_points(prevs, true_prev, point_estim, train_prev, estim_label='KDEy-estim'):
    """Scatter 3-class prevalence vectors on the 2-simplex and show the figure.

    Every prevalence vector ``p = (p0, p1, p2)`` is drawn inside an equilateral
    triangle whose corners are the pure-class vectors: ``y=0`` at (0, 0),
    ``y=1`` at (1, 0) and ``y=2`` at (0.5, sqrt(3)/2).

    Font sizes are applied through ``plt.rc_context`` so they only affect this
    figure; matplotlib's global rcParams are left unchanged after the call.

    Args:
        prevs: array-like of sampled prevalence vectors whose last axis has
            length 3, e.g. shape ``(n_samples, 3)``. A single vector of shape
            ``(3,)`` is treated as one sample. Their mean is plotted as well.
        true_prev: array-like true prevalence vector(s), last axis of length 3.
        point_estim: array-like estimated prevalence vector(s), last axis of
            length 3.
        train_prev: array-like training prevalence vector(s), last axis of
            length 3.
        estim_label: legend label used for ``point_estim``.

    Raises:
        ValueError: if any input does not have exactly 3 entries along its
            last axis. Raised before any figure is created.
    """

    def cartesian(p):
        # Barycentric -> Cartesian coordinates of the triangle described above.
        # p0 is implied by p1 and p2, so this presumably assumes each vector
        # sums to 1 (not checked).
        p = np.asarray(p, dtype=float)
        if p.ndim == 0 or p.shape[-1] != 3:
            raise ValueError(
                f'expected prevalence vectors with 3 classes, got shape {p.shape}'
            )
        p = p.reshape(-1, 3)
        x = p[:, 1] + p[:, 2] * 0.5
        y = p[:, 2] * np.sqrt(3) / 2
        return x, y

    prevs = np.atleast_2d(np.asarray(prevs, dtype=float))

    # Project everything first so invalid input fails before a figure is opened.
    samples_xy = cartesian(prevs)
    mean_xy = cartesian(prevs.mean(axis=0))
    true_xy = cartesian(true_prev)
    estim_xy = cartesian(point_estim)
    train_xy = cartesian(train_prev)

    # Simplex corners in the plane.
    v1 = np.array([0, 0])
    v2 = np.array([1, 0])
    v3 = np.array([0.5, np.sqrt(3) / 2])

    font_sizes = {
        'font.size': 10,  # base size for all text
        'axes.titlesize': 12,  # axes title
        'axes.labelsize': 10,  # axis labels
        'xtick.labelsize': 8,  # tick labels
        'ytick.labelsize': 8,
        'legend.fontsize': 9,  # legend
    }

    # rc_context restores the previous settings on exit instead of changing
    # the process-wide defaults used by every later figure.
    with plt.rc_context(font_sizes):
        fig, ax = plt.subplots(figsize=(6, 6))

        # Keep this call order: it decides which colour each series takes
        # from the axes' colour cycle.
        ax.scatter(*samples_xy, s=10, alpha=0.5, edgecolors='none', label='samples')
        ax.scatter(*mean_xy, s=10, alpha=1, label='sample-mean', edgecolors='black')
        ax.scatter(*true_xy, s=10, alpha=1, label='true-prev', edgecolors='black')
        ax.scatter(*estim_xy, s=10, alpha=1, label=estim_label, edgecolors='black')
        ax.scatter(*train_xy, s=10, alpha=1, label='train-prev', edgecolors='black')

        # Triangle edges.
        triangle = np.array([v1, v2, v3, v1])
        ax.plot(triangle[:, 0], triangle[:, 1], color='black')

        # Vertex labels.
        ax.text(-0.05, -0.05, "y=0", ha='right', va='top')
        ax.text(1.05, -0.05, "y=1", ha='left', va='top')
        ax.text(0.5, np.sqrt(3) / 2 + 0.05, "y=2", ha='center', va='bottom')

        ax.set_aspect('equal')
        ax.axis('off')
        # Legend is anchored outside the axes, to the right of the triangle.
        ax.legend(loc='center left', bbox_to_anchor=(1.05, 0.5))

        fig.tight_layout()
        plt.show()
|
|
|
|
|
|
def plot_prev_points_matplot(points, *, bw_method=0.25, resolution=200, show_points=True):
    """Show a Gaussian-KDE density heatmap of 3-class prevalence vectors.

    Rows of ``points`` are projected onto an equilateral triangle with
    A (1,0,0) at (0, 0), B (0,1,0) at (1, 0) and C (0,0,1) at
    (0.5, sqrt(3)/2). A 2-D ``scipy.stats.gaussian_kde`` is fitted to the
    projected points and evaluated on a regular grid; grid nodes outside the
    triangle are masked so only the simplex is coloured.

    Args:
        points: array-like of shape ``(n_points, 3)``; each row is presumably
            a prevalence vector summing to 1 (not checked).
        bw_method: bandwidth forwarded to ``gaussian_kde``.
        resolution: number of grid nodes along each axis of the heatmap.
        show_points: if True, overlay the projected points on the heatmap.

    Raises:
        ValueError: if ``points`` is not 2-D with exactly 3 columns.
    """
    points = np.asarray(points, dtype=float)
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError(f'expected points of shape (n, 3), got {points.shape}')

    # Simplex corners in the plane.
    v1 = np.array([0, 0])
    v2 = np.array([1, 0])
    v3 = np.array([0.5, np.sqrt(3) / 2])

    # Project to 2-D; p0 is implied by p1 and p2.
    x = points[:, 1] + points[:, 2] * 0.5
    y = points[:, 2] * np.sqrt(3) / 2

    kde = gaussian_kde(np.vstack([x, y]), bw_method=bw_method)

    # Bounding box of the triangle, sampled on a resolution x resolution grid
    # (a complex step makes mgrid include both end points).
    xmin, xmax = 0, 1
    ymin, ymax = 0, np.sqrt(3) / 2
    xx, yy = np.mgrid[xmin:xmax:resolution * 1j, ymin:ymax:resolution * 1j]

    # Grid nodes inside the triangle: above the base, under both slanted edges.
    inside = (yy >= 0) & (yy <= np.sqrt(3) * np.minimum(xx, 1 - xx))

    # Evaluate the KDE only where it will be shown; the rest is masked out.
    density = np.zeros(xx.shape)
    density[inside] = kde(np.vstack([xx[inside], yy[inside]]))
    zz_masked = np.ma.array(density, mask=~inside)

    _, ax = plt.subplots(figsize=(6, 6))

    # mgrid arrays are indexed [x, y]; rot90 turns them into rows of
    # decreasing y and columns of increasing x, which is what imshow's
    # default origin='upper' expects with this extent.
    ax.imshow(
        np.rot90(zz_masked),
        cmap=plt.cm.viridis,
        extent=[xmin, xmax, ymin, ymax],
        alpha=0.8,
    )

    # Triangle edges.
    triangle = np.array([v1, v2, v3, v1])
    ax.plot(triangle[:, 0], triangle[:, 1], color='black', lw=2)

    # Optional overlay of the raw projected points.
    if show_points:
        ax.scatter(x, y, s=5, c='white', alpha=0.3)

    # Vertex labels.
    ax.text(-0.05, -0.05, "A (1,0,0)", ha='right', va='top')
    ax.text(1.05, -0.05, "B (0,1,0)", ha='left', va='top')
    ax.text(0.5, np.sqrt(3) / 2 + 0.05, "C (0,0,1)", ha='center', va='bottom')

    ax.set_aspect('equal')
    ax.axis('off')
    plt.show()
|
|
|
|
if __name__ == '__main__':
    # Demo: draw random 3-class prevalence vectors from a Dirichlet
    # distribution and display their KDE heatmap on the simplex.
    concentration = [2, 3, 4]
    num_samples = 1000
    demo_points = np.random.dirichlet(concentration, size=num_samples)
    plot_prev_points_matplot(demo_points)
|