adding all plots
This commit is contained in:
parent
4e6014c0f2
commit
e9536af69e
|
|
@ -0,0 +1,73 @@
|
|||
|
||||
import math
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import train_test_split, cross_val_predict
|
||||
from sklearn.neighbors import KernelDensity
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
scale = 100
|
||||
|
||||
|
||||
import quapy as qp
|
||||
|
||||
dataset='wa'
|
||||
data = qp.datasets.fetch_twitter(dataset, min_df=3, pickle=True, for_model_selection=False)
|
||||
|
||||
X, y = data.training.Xy
|
||||
|
||||
cls = LogisticRegression(C=0.0001, random_state=0)
|
||||
|
||||
|
||||
posteriors = cross_val_predict(cls, X=X, y=y, method='predict_proba', n_jobs=-1, cv=3)
|
||||
|
||||
cls.fit(X, y)
|
||||
|
||||
Xte, yte = data.test.sampling(1000, *[0.7, 0.2, 0.1], random_state=0).Xy
|
||||
|
||||
post_c1 = posteriors[y==0]
|
||||
post_c2 = posteriors[y==1]
|
||||
post_c3 = posteriors[y==2]
|
||||
|
||||
|
||||
print(len(post_c1))
|
||||
print(len(post_c2))
|
||||
print(len(post_c3))
|
||||
|
||||
post_test = cls.predict_proba(Xte)
|
||||
|
||||
alpha = qp.functional.prevalence_from_labels(yte, classes=[0, 1, 2])
|
||||
|
||||
|
||||
nbins = 20
|
||||
|
||||
plt.rcParams.update({'font.size': 7})
|
||||
|
||||
fig = plt.figure()
|
||||
positions = np.asarray([2,1,0])
|
||||
colors = ['r', 'g', 'b']
|
||||
|
||||
for i, post_set in enumerate([post_c1, post_c2, post_c3, post_test]):
|
||||
ax = fig.add_subplot(141+i, projection='3d')
|
||||
for post, c, z in zip(post_set.T, colors, positions):
|
||||
|
||||
hist, bins = np.histogram(post, bins=nbins, density=True, range=[0,1])
|
||||
xs = (bins[:-1] + bins[1:])/2
|
||||
|
||||
ax.bar(xs, hist, width=1/nbins, zs=z, zdir='y', color=c, ec=c, alpha=0.6)
|
||||
|
||||
ax.yaxis.set_ticks(positions)
|
||||
ax.yaxis.set_ticklabels(['$y=1$', '$y=2$', '$y=3$'])
|
||||
ax.xaxis.set_ticks([])
|
||||
ax.xaxis.set_ticklabels([], minor=True)
|
||||
ax.zaxis.set_ticks([])
|
||||
ax.zaxis.set_ticklabels([], minor=True)
|
||||
|
||||
|
||||
# plt.figure(figsize=(10,6))
|
||||
# plt.show()
|
||||
plt.savefig(f'./plots_ieee/multiclasshistograms_{dataset}.pdf')
|
||||
|
||||
|
||||
|
|
@ -68,7 +68,7 @@ def prepare_xy_date_blocks(df, freq="M"):
|
|||
return X, y, date, idx2date
|
||||
|
||||
|
||||
def prepare_labelled_collections():
|
||||
def prepare_labelled_collections(filter_neutral):
|
||||
# loads and prepares the Twitter US Arlines Sentiment dataset (from Kaggle)
|
||||
# returns a labelled collection for the training data (day 0 and 1), and a list of the
|
||||
# test sets (days 2 to 8) and the time limits for each test period
|
||||
|
|
@ -80,12 +80,14 @@ def prepare_labelled_collections():
|
|||
X, y, date, idx2date = prepare_xy_date_blocks(df, freq="D")
|
||||
|
||||
# binarize
|
||||
|
||||
keep_idx = (y!='neutral')
|
||||
X = X[keep_idx]
|
||||
y = y[keep_idx]
|
||||
date = date[keep_idx]
|
||||
y[y != 'negative'] = 1
|
||||
if filter_neutral:
|
||||
keep_idx = (y!='neutral')
|
||||
X = X[keep_idx]
|
||||
y = y[keep_idx]
|
||||
date = date[keep_idx]
|
||||
else:
|
||||
y[y == 'neutral'] = 2
|
||||
y[y == 'positive'] = 1
|
||||
y[y == 'negative'] = 0
|
||||
y = y.astype(int)
|
||||
|
||||
|
|
@ -248,7 +250,7 @@ class HFTextClassifier(BaseEstimator, ClassifierMixin):
|
|||
USE_LOGISTIC_REGRESSION = True
|
||||
|
||||
if USE_LOGISTIC_REGRESSION:
|
||||
new_classifier = lambda:LR()
|
||||
new_classifier = lambda:LR(C=1)
|
||||
to_fit = True
|
||||
else:
|
||||
pretrained = HFTextClassifier()
|
||||
|
|
@ -266,7 +268,7 @@ def methods():
|
|||
yield 'KDEy', KDEyML(new_classifier(), fit_classifier=to_fit)
|
||||
|
||||
|
||||
train, tests, test_init = prepare_labelled_collections()
|
||||
train, tests, test_init = prepare_labelled_collections(filter_neutral=True)
|
||||
|
||||
if USE_LOGISTIC_REGRESSION:
|
||||
# vectorize text for logistic regression
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ plt.title('')
|
|||
fig.set(yticklabels=[])
|
||||
fig.set(ylabel=None)
|
||||
setframe()
|
||||
fig.get_figure().savefig('plots_cacm/training.pdf')
|
||||
fig.get_figure().savefig('plots_ieee/training.pdf')
|
||||
|
||||
# -------------------------------------------------------------
|
||||
|
||||
|
|
@ -68,6 +68,6 @@ plt.title('')
|
|||
fig.set(yticklabels=[])
|
||||
fig.set(ylabel=None)
|
||||
setframe()
|
||||
fig.get_figure().savefig('plots_cacm/test.pdf')
|
||||
fig.get_figure().savefig('plots_ieee/test.pdf')
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from sklearn.neighbors import KernelDensity
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from data import LabelledCollection
|
||||
|
||||
scale = 100
|
||||
|
||||
|
|
@ -57,6 +56,6 @@ ax.zaxis.set_ticklabels([], minor=True)
|
|||
|
||||
#plt.figure(figsize=(10,6))
|
||||
#plt.show()
|
||||
plt.savefig('./histograms3d_CACM2023.pdf')
|
||||
plt.savefig('./histograms3d_IEEE2025.pdf')
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,154 @@
|
|||
import math
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.neighbors import KernelDensity
|
||||
|
||||
from data import LabelledCollection
|
||||
|
||||
scale = 10
|
||||
|
||||
|
||||
# con ternary (una lib de matplotlib) salen bien pero no puedo crear contornos, o no se
|
||||
# con plotly salen los contornos bien, pero es un poco un jaleo porque utiliza el navegador...
|
||||
|
||||
def plot_simplex_(ax, density, title='', fontsize=9, points=None):
|
||||
import ternary
|
||||
|
||||
tax = ternary.TernaryAxesSubplot(ax=ax, scale=scale)
|
||||
tax.heatmapf(density, boundary=True, style="triangular", colorbar=False, cmap='viridis') #cmap='magma')
|
||||
tax.boundary(linewidth=1.0)
|
||||
corner_fontsize = 5*fontsize//6
|
||||
tax.right_corner_label("$y=3$", fontsize=corner_fontsize)
|
||||
tax.top_corner_label("$y=2$", fontsize=corner_fontsize)
|
||||
tax.left_corner_label("$y=1$", fontsize=corner_fontsize)
|
||||
if title:
|
||||
tax.set_title(title, loc='center', y=-0.11, fontsize=fontsize)
|
||||
if points is not None:
|
||||
tax.scatter(points*scale, marker='o', color='w', alpha=0.25, zorder=10)
|
||||
tax.get_axes().axis('off')
|
||||
tax.clear_matplotlib_ticks()
|
||||
|
||||
return tax
|
||||
|
||||
|
||||
def plot_simplex(ax, coord, kde_scores, title='', fontsize=11, points=None, savepath=None):
|
||||
import plotly.figure_factory as ff
|
||||
|
||||
tax = ff.create_ternary_contour(coord.T, kde_scores, pole_labels=['y=1', 'y=2', 'y=3'],
|
||||
interp_mode='cartesian',
|
||||
ncontours=20,
|
||||
colorscale='Viridis',
|
||||
showscale=True,
|
||||
title=title)
|
||||
if savepath is None:
|
||||
tax.show()
|
||||
else:
|
||||
tax.write_image(savepath)
|
||||
return tax
|
||||
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
def plot_3class_problem(post_c1, post_c2, post_c3, post_test, alpha, bandwidth):
|
||||
import ternary
|
||||
|
||||
post_c1 = np.flip(post_c1, axis=1)
|
||||
post_c2 = np.flip(post_c2, axis=1)
|
||||
post_c3 = np.flip(post_c3, axis=1)
|
||||
post_test = np.flip(post_test, axis=1)
|
||||
|
||||
size_=10
|
||||
fig = ternary.plt.figure(figsize=(4*size_, 1*size_))
|
||||
fig.tight_layout()
|
||||
ax1 = fig.add_subplot(1, 4, 1)
|
||||
divider = make_axes_locatable(ax1)
|
||||
ax2 = fig.add_subplot(1, 4, 2)
|
||||
divider = make_axes_locatable(ax2)
|
||||
ax3 = fig.add_subplot(1, 4, 3)
|
||||
divider = make_axes_locatable(ax3)
|
||||
ax4 = fig.add_subplot(1, 4, 4)
|
||||
divider = make_axes_locatable(ax4)
|
||||
|
||||
kde1 = KernelDensity(bandwidth=bandwidth).fit(post_c1)
|
||||
kde2 = KernelDensity(bandwidth=bandwidth).fit(post_c2)
|
||||
kde3 = KernelDensity(bandwidth=bandwidth).fit(post_c3)
|
||||
|
||||
#post_c1 = np.concatenate([post_c1, np.eye(3, dtype=float)])
|
||||
#post_c2 = np.concatenate([post_c2, np.eye(3, dtype=float)])
|
||||
#post_c3 = np.concatenate([post_c3, np.eye(3, dtype=float)])
|
||||
|
||||
#plot_simplex_(ax1, lambda x:0, title='$f_1(\mathbf{x})=p(s(\mathbf{x})|y=1)$')
|
||||
#plot_simplex_(ax2, lambda x:0, title='$f_1(\mathbf{x})=p(s(\mathbf{x})|y=1)$')
|
||||
#plot_simplex_(ax3, lambda x:0, title='$f_1(\mathbf{x})=p(s(\mathbf{x})|y=1)$')
|
||||
def density(kde):
|
||||
def d(p):
|
||||
return np.exp(kde([p])).item()
|
||||
return d
|
||||
|
||||
plot_simplex_(ax1, density(kde1.score_samples), title='$p_1$')
|
||||
plot_simplex_(ax2, density(kde2.score_samples), title='$p_2$')
|
||||
plot_simplex_(ax3, density(kde3.score_samples), title='$p_3$')
|
||||
#plot_simplex(ax1, post_c1, np.exp(kde1.score_samples(post_c1)), title='$f_1(\mathbf{x})=p(s(\mathbf{x})|y=1)$') #, savepath='figure/y1.png')
|
||||
#plot_simplex(ax2, post_c2, np.exp(kde2.score_samples(post_c2)), title='$f_2(\mathbf{x})=p(s(\mathbf{x})|y=2)$') #, savepath='figure/y2.png')
|
||||
#plot_simplex(ax3, post_c3, np.exp(kde3.score_samples(post_c3)), title='$f_3(\mathbf{x})=p(s(\mathbf{x})|y=3)$') #, savepath='figure/y3.png')
|
||||
|
||||
def mixture_(prevs, kdes):
|
||||
def m(p):
|
||||
total_density = 0
|
||||
for prev, kde in zip(prevs, kdes):
|
||||
log_density = kde.score_samples([p]).item()
|
||||
density = np.exp(log_density)
|
||||
density *= prev
|
||||
total_density += density
|
||||
#print(total_density)
|
||||
return total_density
|
||||
return m
|
||||
|
||||
title = '$p_{\mathbf{\\alpha}} = \sum_{i \in n} \\alpha_i p_i$'
|
||||
|
||||
plot_simplex_(ax4, mixture_(alpha, [kde1, kde2, kde3]), title=title, points=post_test)
|
||||
#mixture(alpha, [kde1, kde2, kde3])
|
||||
|
||||
#post_test = np.concatenate([post_test, np.eye(3, dtype=float)])
|
||||
#test_scores = sum(alphai*np.exp(kdei.score_samples(post_test)) for alphai, kdei in zip(alpha, [kde1,kde2,kde3]))
|
||||
#plot_simplex(ax4, post_test, test_scores, title=title, points=post_test)
|
||||
|
||||
#ternary.plt.show()
|
||||
ternary.plt.savefig('./simplex.png')
|
||||
|
||||
|
||||
import quapy as qp
|
||||
|
||||
|
||||
data = qp.datasets.fetch_twitter('wb', min_df=3, pickle=True, for_model_selection=False)
|
||||
|
||||
X, y = data.training.Xy
|
||||
|
||||
cls = LogisticRegression(C=0.0001, random_state=0)
|
||||
|
||||
Xtr, Xte, ytr, yte = train_test_split(X, y, train_size=0.7, stratify=y, random_state=0)
|
||||
|
||||
cls.fit(Xtr, ytr)
|
||||
|
||||
test = LabelledCollection(Xte, yte)
|
||||
test = test.sampling(100, *[0.2, 0.1, 0.7])
|
||||
|
||||
Xte, yte = test.Xy
|
||||
|
||||
post_c1 = cls.predict_proba(Xte[yte==0])
|
||||
post_c2 = cls.predict_proba(Xte[yte==1])
|
||||
post_c3 = cls.predict_proba(Xte[yte==2])
|
||||
|
||||
post_test = cls.predict_proba(Xte)
|
||||
print(post_test)
|
||||
alpha = qp.functional.prevalence_from_labels(yte, classes=[0, 1, 2])
|
||||
|
||||
#post_c1 = np.random.dirichlet([10,3,1], 30)
|
||||
#post_c2 = np.random.dirichlet([1,11,6], 30)
|
||||
#post_c3 = np.random.dirichlet([1,5,20], 30)
|
||||
#post_test = np.random.dirichlet([5,1,6], 100)
|
||||
#alpha = [0.5, 0.3, 0.2]
|
||||
|
||||
|
||||
print(f'test alpha {alpha}')
|
||||
plot_3class_problem(post_c1, post_c2, post_c3, post_test, alpha, bandwidth=0.1)
|
||||
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
import ternary
|
||||
import math
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import train_test_split, cross_val_predict
|
||||
from sklearn.neighbors import KernelDensity
|
||||
import plotly.figure_factory as ff
|
||||
|
||||
from quapy.data import LabelledCollection
|
||||
|
||||
scale = 100
|
||||
|
||||
|
||||
# con ternary (una lib de matplotlib) salen bien pero no puedo crear contornos, o no se
|
||||
# con plotly salen los contornos bien, pero es un poco un jaleo porque utiliza el navegador...
|
||||
|
||||
def plot_simplex_(ax, density, title='', fontsize=30, points=None):
|
||||
|
||||
tax = ternary.TernaryAxesSubplot(ax=ax, scale=scale)
|
||||
tax.heatmapf(density, boundary=True, style="triangular", colorbar=False, cmap='viridis') #cmap='magma')
|
||||
tax.boundary(linewidth=1.0)
|
||||
corner_fontsize = int(5*fontsize//6)
|
||||
tax.right_corner_label("$y=3$", fontsize=corner_fontsize)
|
||||
tax.top_corner_label("$y=2$", fontsize=corner_fontsize)
|
||||
tax.left_corner_label("$y=1$", fontsize=corner_fontsize)
|
||||
if title:
|
||||
tax.set_title(title, loc='center', y=-0.11, fontsize=fontsize)
|
||||
if points is not None:
|
||||
tax.scatter(points*scale, marker='o', color='w', alpha=0.25, zorder=10, s=5*scale)
|
||||
tax.get_axes().axis('off')
|
||||
tax.clear_matplotlib_ticks()
|
||||
|
||||
return tax
|
||||
|
||||
|
||||
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
def plot_3class_problem(post_c1, post_c2, post_c3, post_test, alpha, bandwidth):
|
||||
post_c1 = np.flip(post_c1, axis=1)
|
||||
post_c2 = np.flip(post_c2, axis=1)
|
||||
post_c3 = np.flip(post_c3, axis=1)
|
||||
post_test = np.flip(post_test, axis=1)
|
||||
|
||||
size_=10
|
||||
fig = ternary.plt.figure(figsize=(5*size_, 1*size_))
|
||||
fig.tight_layout()
|
||||
ax1 = fig.add_subplot(1, 4, 1)
|
||||
divider = make_axes_locatable(ax1)
|
||||
ax2 = fig.add_subplot(1, 4, 2)
|
||||
divider = make_axes_locatable(ax2)
|
||||
ax3 = fig.add_subplot(1, 4, 3)
|
||||
divider = make_axes_locatable(ax3)
|
||||
ax4 = fig.add_subplot(1, 4, 4)
|
||||
divider = make_axes_locatable(ax4)
|
||||
|
||||
kde1 = KernelDensity(bandwidth=bandwidth).fit(post_c1)
|
||||
kde2 = KernelDensity(bandwidth=bandwidth).fit(post_c2)
|
||||
kde3 = KernelDensity(bandwidth=bandwidth).fit(post_c3)
|
||||
|
||||
#post_c1 = np.concatenate([post_c1, np.eye(3, dtype=float)])
|
||||
#post_c2 = np.concatenate([post_c2, np.eye(3, dtype=float)])
|
||||
#post_c3 = np.concatenate([post_c3, np.eye(3, dtype=float)])
|
||||
|
||||
#plot_simplex_(ax1, lambda x:0, title='$f_1(\mathbf{x})=p(s(\mathbf{x})|y=1)$')
|
||||
#plot_simplex_(ax2, lambda x:0, title='$f_1(\mathbf{x})=p(s(\mathbf{x})|y=1)$')
|
||||
#plot_simplex_(ax3, lambda x:0, title='$f_1(\mathbf{x})=p(s(\mathbf{x})|y=1)$')
|
||||
def density(kde):
|
||||
def d(p):
|
||||
return np.exp(kde([p])).item()
|
||||
return d
|
||||
|
||||
plot_simplex_(ax1, density(kde1.score_samples), title='$p_1$')
|
||||
plot_simplex_(ax2, density(kde2.score_samples), title='$p_2$')
|
||||
plot_simplex_(ax3, density(kde3.score_samples), title='$p_3$')
|
||||
#plot_simplex(ax1, post_c1, np.exp(kde1.score_samples(post_c1)), title='$f_1(\mathbf{x})=p(s(\mathbf{x})|y=1)$') #, savepath='figure/y1.png')
|
||||
#plot_simplex(ax2, post_c2, np.exp(kde2.score_samples(post_c2)), title='$f_2(\mathbf{x})=p(s(\mathbf{x})|y=2)$') #, savepath='figure/y2.png')
|
||||
#plot_simplex(ax3, post_c3, np.exp(kde3.score_samples(post_c3)), title='$f_3(\mathbf{x})=p(s(\mathbf{x})|y=3)$') #, savepath='figure/y3.png')
|
||||
|
||||
def mixture_(prevs, kdes):
|
||||
def m(p):
|
||||
total_density = 0
|
||||
for prev, kde in zip(prevs, kdes):
|
||||
log_density = kde.score_samples([p]).item()
|
||||
density = np.exp(log_density)
|
||||
density *= prev
|
||||
total_density += density
|
||||
#print(total_density)
|
||||
return total_density
|
||||
return m
|
||||
|
||||
title = '' # r'$\mathbf{p}_{\mathbf{\\alpha}} = \sum_{i \in n} \\alpha_i p_i$'
|
||||
|
||||
plot_simplex_(ax4, mixture_(alpha, [kde1, kde2, kde3]), title=title, points=post_test)
|
||||
|
||||
#ternary.plt.show()
|
||||
ternary.plt.savefig(f'./plots_ieee/simplex_{dataset}.png', dpi=300)
|
||||
|
||||
|
||||
import quapy as qp
|
||||
|
||||
dataset = 'wa'
|
||||
data = qp.datasets.fetch_twitter(dataset, min_df=3, pickle=True, for_model_selection=False)
|
||||
|
||||
Xtr, ytr = data.training.Xy
|
||||
Xte, yte = data.test.sampling(150, *[0.7, 0.2, 0.1], random_state=0).Xy
|
||||
|
||||
cls = LogisticRegression(C=0.0001, random_state=0)
|
||||
|
||||
post_tr = cross_val_predict(cls, Xtr, ytr, n_jobs=-1, method='predict_proba')
|
||||
post_c1 = post_tr[ytr==0]
|
||||
post_c2 = post_tr[ytr==1]
|
||||
post_c3 = post_tr[ytr==2]
|
||||
cls.fit(Xtr, ytr)
|
||||
|
||||
post_test = cls.predict_proba(Xte)
|
||||
|
||||
alpha = qp.functional.prevalence_from_labels(yte, classes=[0, 1, 2])
|
||||
|
||||
print(f'test alpha {alpha}')
|
||||
plot_3class_problem(post_c1, post_c2, post_c3, post_test, alpha, bandwidth=0.1)
|
||||
|
||||
Loading…
Reference in New Issue