some plots

This commit is contained in:
Alejandro Moreo Fernandez 2024-09-16 17:50:34 +02:00
parent ede214aa54
commit faba2494b2
1 changed files with 10 additions and 4 deletions

View File

@ -1,3 +1,5 @@
import os
import numpy as np
from sklearn.linear_model import LogisticRegression
from os.path import join
@ -5,7 +7,7 @@ import quapy as qp
from quapy.protocol import UPP
from quapy.method.aggregative import KDEyML
DEBUG = True
DEBUG = False
qp.environ["SAMPLE_SIZE"] = 100 if DEBUG else 500
val_repeats = 100 if DEBUG else 500
@ -21,7 +23,7 @@ if DEBUG:
bandwidth_range = np.linspace(0.01, 0.20, 10)
def datasets():
for dataset_name in qp.datasets.UCI_MULTICLASS_DATASETS[:4]:
for dataset_name in qp.datasets.UCI_MULTICLASS_DATASETS:
dataset = qp.datasets.fetch_UCIMulticlassDataset(dataset_name)
if DEBUG:
dataset = dataset.reduce(random_state=0)
@ -40,7 +42,8 @@ def experiment_dataset(dataset):
param_grid={'bandwidth': bandwidth_range},
protocol=UPP(train_va, repeats=val_repeats),
refit=False,
n_jobs=-1
n_jobs=-1,
verbose=True
).fit(train_tr)
chosen_bandwidth = modsel.best_params_['bandwidth']
modsel_choice = float(chosen_bandwidth)
@ -83,7 +86,10 @@ def plot_bandwidth(val_choice, test_results):
# Mostrar la gráfica
plt.grid(True)
plt.show()
# plt.show()
os.makedirs('./plots', exist_ok=True)
plt.savefig(f'./plots/{dataset_name}.png')
for dataset in datasets():