81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
import time
|
|
from functools import wraps
|
|
import os
|
|
from os.path import join
|
|
from result_table.src.table import Table
|
|
import numpy as np
|
|
from constants import *
|
|
|
|
def measuretime(func):
    """Decorate *func* so that every call also reports its elapsed time.

    The wrapper forwards all positional and keyword arguments to *func*
    unchanged and measures how long the call takes, in seconds.

    Returns:
        A wrapped callable. If *func* returns a tuple, the wrapper returns
        that tuple with the elapsed time appended as one extra trailing
        element; for any other return value it returns the pair
        ``(result, elapsed_seconds)``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is monotonic and high resolution; time.time() follows
        # the system clock and can jump (even backwards) when it is adjusted.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        time_it_took = time.perf_counter() - start_time

        if isinstance(result, tuple):
            # Flatten tuple results so callers unpack the time alongside the
            # original values instead of receiving a nested tuple.
            return (*result, time_it_took)
        return result, time_it_took

    return wrapper
|
|
|
|
|
|
def plot_bandwidth(dataset_name, test_results, bandwidths, triplet_list_results):
    """Plot test results against bandwidth and save the figure as a PNG.

    ``test_results`` is drawn versus ``bandwidths`` as a black line with
    circle markers (y axis labelled ``'MAE'``). For every entry of
    ``triplet_list_results`` a dashed vertical line is drawn at the bandwidth
    that entry selected, labelled with the method name in the legend.

    The image is written to ``./plots/<dataset_name>.png``, or to
    ``./plots_debug/<dataset_name>.png`` when ``DEBUG`` is truthy; the
    directory is created if it does not exist.

    Args:
        dataset_name: Used as the plot title and as the output file stem.
        test_results: y values, one per entry of ``bandwidths``.
        bandwidths: x values of the curve.
        triplet_list_results: Sized iterable of
            ``(method_name, method_choice, method_time)`` triplets; the time
            component is not used by this function.
    """
    # Imported here rather than at module level, so matplotlib is only
    # required when a plot is actually produced.
    import matplotlib.pyplot as plt

    print("PLOT", dataset_name)
    print(dataset_name)

    fig = plt.figure(figsize=(8, 6))
    try:
        # Show test results.
        plt.plot(bandwidths, test_results, marker='o', color='k')

        # One colour per method, sampled evenly from the tab10 colormap.
        colors = plt.cm.tab10(np.linspace(0, 1, len(triplet_list_results)))
        for i, (method_name, method_choice, _method_time) in enumerate(triplet_list_results):
            plt.axvline(x=method_choice, linestyle='--', label=method_name, color=colors[i])

        # Add axis labels and title.
        plt.xlabel('Bandwidth')
        plt.ylabel('MAE')
        plt.title(dataset_name)

        # Place the legend outside the axes, to the right.
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

        # Draw the background grid.
        plt.grid(True)

        plotdir = './plots'
        if DEBUG:
            plotdir = './plots_debug'
        os.makedirs(plotdir, exist_ok=True)
        plt.tight_layout()
        plt.savefig(f'{plotdir}/{dataset_name}.png')
    finally:
        # Close this specific figure even if plotting or saving failed, so it
        # does not stay registered in pyplot's global state.
        plt.close(fig)
|
|
|
|
def error_table(dataset_name, test_results, bandwidth_range, triplet_list_results):
    """Build a per-method error table for one dataset and render it to PDF.

    The best bandwidth is the entry of ``bandwidth_range`` at the position of
    the minimum of ``test_results``. For every ``(method_name, method_choice,
    took)`` triplet the following values are added to the table under the
    method's name:

    * ``'Choice'``: the bandwidth the method selected,
    * ``'ScoreChoice'``: the test result at that bandwidth,
    * ``'Best'`` / ``'ScoreBest'``: the best bandwidth and its test result,
    * ``'AE'``: ``abs(ScoreBest - ScoreChoice)``,
    * ``'Time'``: the ``took`` value of the triplet.

    The table is written to ``./tables/<dataset_name>.pdf`` (``./tables_debug``
    when ``DEBUG`` is truthy); the directory is created if needed.

    Args:
        dataset_name: Table name and output file stem.
        test_results: Scores, one per entry of ``bandwidth_range``.
        bandwidth_range: Candidate bandwidths (list or NumPy array).
        triplet_list_results: Iterable of ``(method_name, method_choice,
            took)`` triplets.

    Raises:
        ValueError: If ``test_results`` is empty (raised by ``np.argmin``).
    """
    best_bandwidth = bandwidth_range[np.argmin(test_results)]
    best_score = np.min(test_results)
    print('Method\tChoice\tAE\tTime')

    table = Table(name=dataset_name)
    table.format.with_mean = False
    table.format.with_rank_mean = False
    table.format.show_std = False

    # Compare against an ndarray so the element-wise lookup below also works
    # when the caller passes a plain list (list == scalar is a single bool).
    bandwidth_grid = np.asarray(bandwidth_range)

    for method_name, method_choice, took in triplet_list_results:
        matches = np.where(bandwidth_grid == method_choice)[0]
        if matches.size:
            method_score = test_results[matches[0]]
        else:
            # The choice is not one of the evaluated bandwidths, so no test
            # result exists for it.
            # NOTE(review): 1 is presumably meant as a worst-case score - confirm.
            method_score = 1
        error = np.abs(best_score - method_score)

        table.add(benchmark='Choice', method=method_name, v=method_choice)
        table.add(benchmark='ScoreChoice', method=method_name, v=method_score)
        table.add(benchmark='Best', method=method_name, v=best_bandwidth)
        table.add(benchmark='ScoreBest', method=method_name, v=best_score)
        table.add(benchmark='AE', method=method_name, v=error)
        table.add(benchmark='Time', method=method_name, v=took)

    outpath = './tables'
    if DEBUG:
        outpath = './tables_debug'
    # Create the output directory up front, as plot_bandwidth does for plots.
    os.makedirs(outpath, exist_ok=True)
    table.latexPDF(join(outpath, dataset_name + '.pdf'), transpose=True)