some plots
This commit is contained in:
parent
ede214aa54
commit
faba2494b2
|
@ -1,3 +1,5 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from os.path import join
|
||||
|
@ -5,7 +7,7 @@ import quapy as qp
|
|||
from quapy.protocol import UPP
|
||||
from quapy.method.aggregative import KDEyML
|
||||
|
||||
DEBUG = True
|
||||
DEBUG = False
|
||||
|
||||
qp.environ["SAMPLE_SIZE"] = 100 if DEBUG else 500
|
||||
val_repeats = 100 if DEBUG else 500
|
||||
|
@ -21,7 +23,7 @@ if DEBUG:
|
|||
bandwidth_range = np.linspace(0.01, 0.20, 10)
|
||||
|
||||
def datasets():
|
||||
for dataset_name in qp.datasets.UCI_MULTICLASS_DATASETS[:4]:
|
||||
for dataset_name in qp.datasets.UCI_MULTICLASS_DATASETS:
|
||||
dataset = qp.datasets.fetch_UCIMulticlassDataset(dataset_name)
|
||||
if DEBUG:
|
||||
dataset = dataset.reduce(random_state=0)
|
||||
|
@ -40,7 +42,8 @@ def experiment_dataset(dataset):
|
|||
param_grid={'bandwidth': bandwidth_range},
|
||||
protocol=UPP(train_va, repeats=val_repeats),
|
||||
refit=False,
|
||||
n_jobs=-1
|
||||
n_jobs=-1,
|
||||
verbose=True
|
||||
).fit(train_tr)
|
||||
chosen_bandwidth = modsel.best_params_['bandwidth']
|
||||
modsel_choice = float(chosen_bandwidth)
|
||||
|
@ -83,7 +86,10 @@ def plot_bandwidth(val_choice, test_results):
|
|||
|
||||
# Mostrar la gráfica
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
# plt.show()
|
||||
os.makedirs('./plots', exist_ok=True)
|
||||
plt.savefig(f'./plots/{dataset_name}.png')
|
||||
|
||||
|
||||
|
||||
for dataset in datasets():
|
||||
|
|
Loading…
Reference in New Issue