import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

# Height of the equilateral simplex triangle with unit base.
_TRIANGLE_HEIGHT = np.sqrt(3) / 2

# Triangle vertices in Cartesian coordinates.  Under the projection in
# `_to_cartesian`, (1,0,0) -> (0,0), (0,1,0) -> (1,0), (0,0,1) -> (0.5, h).
_VERTICES = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, _TRIANGLE_HEIGHT]])


def _to_cartesian(p):
    """Project 3-component barycentric (prevalence) vectors onto the triangle.

    Args:
        p: A single vector or an array-like whose last axis holds the
            components (lists, tuples and DataFrames are accepted).  All
            leading axes are flattened.

    Returns:
        Tuple ``(x, y)`` of 1-D arrays of Cartesian coordinates.

    Only components 1 and 2 enter the formula; component 0 is implied,
    presumably because each vector sums to 1 (not checked here).
    """
    p = np.asarray(p)
    p = p.reshape(-1, p.shape[-1])
    x = p[:, 1] + p[:, 2] * 0.5
    y = p[:, 2] * _TRIANGLE_HEIGHT
    return x, y


def _in_triangle(x, y):
    """Return a boolean mask of the points (x, y) lying inside the triangle."""
    # y >= 0 plus the two slanted edges; min(x, 1 - x) < 0 outside [0, 1],
    # so this also rejects points left/right of the base.
    return (y >= 0) & (y <= np.sqrt(3) * np.minimum(x, 1 - x))


def _draw_triangle(ax, **line_kwargs):
    """Draw the closed black outline of the simplex triangle on *ax*."""
    outline = _VERTICES[[0, 1, 2, 0]]  # repeat the first vertex to close it
    ax.plot(outline[:, 0], outline[:, 1], color='black', **line_kwargs)


def plot_prev_points(prevs, true_prev):
    """Scatter-plot 3-class prevalence vectors on the probability simplex.

    Args:
        prevs: Prevalence vectors with a last axis of length 3 (e.g. shape
            ``(n, 3)``); leading axes are flattened.  Drawn as large markers
            with very low alpha, so overlapping points accumulate visually.
        true_prev: A prevalence vector (or array of them) to highlight;
            drawn as small opaque markers on top.  Lists are accepted.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    x, y = _to_cartesian(prevs)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(x, y, s=50, alpha=0.05, edgecolors='none')
    ax.scatter(*_to_cartesian(true_prev), s=5, alpha=1)

    _draw_triangle(ax)

    # Each vertex is the prevalence vector putting all mass on one class.
    ax.text(-0.05, -0.05, "y=0", ha='right', va='top')
    ax.text(1.05, -0.05, "y=1", ha='left', va='top')
    ax.text(0.5, _TRIANGLE_HEIGHT + 0.05, "y=2", ha='center', va='bottom')

    ax.set_aspect('equal')
    ax.axis('off')
    plt.show()


def plot_prev_points_matplot(points):
    """Plot a Gaussian-KDE density heatmap of 3-class prevalence vectors.

    Args:
        points: Prevalence vectors with a last axis of length 3 (e.g. shape
            ``(n, 3)``); lists are accepted.  The KDE is fitted on the
            projected 2-D coordinates, so ``gaussian_kde`` may raise when the
            projected points are too few or degenerate (e.g. collinear).

    The density is evaluated on a 200x200 grid over the triangle's bounding
    box, masked outside the triangle, and the raw points are overlaid.  The
    figure is displayed with ``plt.show()``; nothing is returned.
    """
    x, y = _to_cartesian(points)

    kde = gaussian_kde(np.vstack([x, y]))

    xmin, xmax = 0, 1
    ymin, ymax = 0, _TRIANGLE_HEIGHT

    # Evaluation grid; mgrid indexes as zz[i, j] <-> (x_i, y_j).
    xx, yy = np.mgrid[xmin:xmax:200j, ymin:ymax:200j]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    zz = np.reshape(kde(positions), xx.shape)

    # Hide grid cells that fall outside the simplex.
    zz_masked = np.ma.array(zz, mask=~_in_triangle(xx, yy))

    fig, ax = plt.subplots(figsize=(6, 6))
    # rot90 turns the [x, y]-indexed grid into image layout (rows = y
    # descending, columns = x ascending), matching imshow's default
    # origin='upper' together with the given extent.
    ax.imshow(
        np.rot90(zz_masked),
        cmap=plt.cm.viridis,
        extent=[xmin, xmax, ymin, ymax],
        alpha=0.8,
    )

    # Triangle edges
    _draw_triangle(ax, lw=2)

    # Raw sample points (optional overlay)
    ax.scatter(x, y, s=5, c='white', alpha=0.3)

    # Vertex labels
    ax.text(-0.05, -0.05, "A (1,0,0)", ha='right', va='top')
    ax.text(1.05, -0.05, "B (0,1,0)", ha='left', va='top')
    ax.text(0.5, np.sqrt(3) / 2 + 0.05, "C (0,0,1)", ha='center', va='bottom')

    ax.set_aspect('equal')
    ax.axis('off')
    plt.show()