adding one plot, the error by drift plot... have to fix the legend though

Alejandro Moreo Fernandez 2024-04-09 16:52:53 +02:00
parent 49a8cf3b0d
commit 9555c4a731
2 changed files with 60 additions and 16 deletions

@@ -49,6 +49,8 @@ class Benchmark(ABC):
         makedirs(join(home_dir, 'tables'))
         makedirs(join(home_dir, 'plots'))
+
+        self.train_prevalence = {}
 
     def _run_id(self, method: MethodDescriptor, dataset: str):
         sep = Benchmark.ID_SEPARATOR
         assert sep not in method.id, \
@@ -104,10 +106,34 @@ class Benchmark(ABC):
         Table.LatexPDF(join(self.home_dir, 'tables', 'results.pdf'), list(tables.values()))
 
-    def gen_plots(self):
-        pass
+    def gen_plots(self, results, metrics=None):
+        import matplotlib.pyplot as plt
+        plt.rcParams.update({'font.size': 11})
+
+        if metrics is None:
+            metrics = ['ae']
+
+        for metric in metrics:
+            method_names, true_prevs, estim_prevs, train_prevs = [], [], [], []
+            skip=False
+            for (method, dataset, result) in results:
+                method_names.append(method.name)
+                true_prevs.append(np.vstack(result['true-prev'].values))
+                estim_prevs.append(np.vstack(result['estim-prev'].values))
+                train_prevs.append(self.get_training_prevalence(dataset))
+            if not skip:
+                path = join(self.home_dir, 'plots', f'err_by_drift_{metric}.pdf')
+                qp.plot.error_by_drift(method_names, true_prevs, estim_prevs, train_prevs, error_name=metric, n_bins=20, savepath=path)
 
-    def show_report(self, method, dataset, report: pd.DataFrame):
+    def _show_report(self, method, dataset, report: pd.DataFrame):
         id = method.id
         MAE = report['mae'].mean()
         mae_std = report['mae'].std()
@@ -146,19 +172,20 @@ class Benchmark(ABC):
             seed=0,
             asarray=False
         )
-        results += [(method, dataset, result) for (method, dataset), result in zip(pending_job_args, remaining_results)]
+        results += [
+            (method, dataset, result) for (method, dataset), result in zip(pending_job_args, remaining_results)
+        ]
 
         # print results
         for method, dataset, result in results:
-            self.show_report(method, dataset, result)
+            self._show_report(method, dataset, result)
 
         self.gen_tables(results)
-        self.gen_plots()
-        # def gen_plots(self, methods=None):
-        #     if methods is None:
+        self.gen_plots(results)
 
     @abstractmethod
     def get_training_prevalence(self, dataset: str):
         ...
 
     def __add__(self, other: 'Benchmark'):
         return CombinedBenchmark(self, other, self.n_jobs)
@@ -192,6 +219,10 @@ class TypicalBenchmark(Benchmark):
     def get_sample_size(self)-> int:
         ...
 
+    @abstractmethod
+    def get_training(self, dataset:str)-> LabelledCollection:
+        ...
+
     @abstractmethod
     def get_trModsel_valprotModsel_trEval_teprotEval(self, dataset:str)->\
             (LabelledCollection, AbstractProtocol, LabelledCollection, AbstractProtocol):
@@ -212,7 +243,8 @@ class TypicalBenchmark(Benchmark):
         with qp.util.temp_seed(random_state):
 
             # data split
             trModSel, valprotModSel, trEval, teprotEval = self.get_trModsel_valprotModsel_trEval_teprotEval(dataset)
+            self.train_prevalence[dataset] = trEval.prevalence()
 
             # model selection
             modsel = GridSearchQ(
@@ -247,6 +279,12 @@ class TypicalBenchmark(Benchmark):
         return report
 
+    def get_training_prevalence(self, dataset: str):
+        if not dataset in self.train_prevalence:
+            training = self.get_training(dataset)
+            self.train_prevalence[dataset] = training.prevalence()
+        return self.train_prevalence[dataset]
+
 
 class UCIBinaryBenchmark(TypicalBenchmark):
@@ -259,6 +297,9 @@ class UCIBinaryBenchmark(TypicalBenchmark):
         testprotModsel = APP(teEval, n_prevalences=21, repeats=100)
         return trModsel, valprotModsel, trEval, testprotModsel
 
+    def get_training(self, dataset:str) -> LabelledCollection:
+        return qp.datasets.fetch_UCIBinaryDataset(dataset).training
+
     def get_sample_size(self) -> int:
         return 100
@@ -284,6 +325,9 @@ class UCIMultiBenchmark(TypicalBenchmark):
         testprotModsel = UPP(teEval, repeats=1000)
         return trModsel, valprotModsel, trEval, testprotModsel
 
+    def get_training(self, dataset:str) -> LabelledCollection:
+        return qp.datasets.fetch_UCIMulticlassDataset(dataset).training
+
     def get_sample_size(self) -> int:
         return 500
@@ -291,6 +335,7 @@ class UCIMultiBenchmark(TypicalBenchmark):
         return 'mae'
 
 
 if __name__ == '__main__':
     from quapy.benchmarking.typical import *
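
For reference, the new gen_plots ultimately delegates to qp.plot.error_by_drift, passing one entry per (method, dataset) experiment. Below is a minimal, self-contained sketch of that call; only the argument layout (parallel lists plus error_name, n_bins and savepath) mirrors the call in gen_plots above, while the method names, random prevalence matrices and output path are synthetic placeholders, not part of this commit.

    import numpy as np
    import quapy as qp

    rng = np.random.default_rng(0)

    # one entry per experiment: a label, the true and estimated prevalences of its
    # test samples (one row per sample), and the training prevalence of its dataset
    method_names = ['CC', 'SLD']
    true_prevs, estim_prevs, tr_prevs = [], [], []
    for _ in method_names:
        true = rng.dirichlet([1, 1], size=50)                      # 50 binary test samples
        estim = np.clip(true + rng.normal(scale=0.05, size=true.shape), 0, 1)
        estim /= estim.sum(axis=1, keepdims=True)                  # rows must sum to 1
        true_prevs.append(true)
        estim_prevs.append(estim)
        tr_prevs.append(np.asarray([0.7, 0.3]))                    # training prevalence

    # bins test samples by prior-probability shift w.r.t. training and plots the
    # mean error per bin for each method
    qp.plot.error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
                           error_name='ae', n_bins=20, savepath='err_by_drift_ae.pdf')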

@@ -259,7 +259,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
     data = _join_data_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, x_error, y_error, method_order)
 
     if method_order is None:
-        method_order = method_names
+        method_order = np.unique(method_names)
 
     _set_colors(ax, n_methods=len(method_order))
@@ -329,10 +329,9 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
     if show_legend:
-        fig.legend(loc='lower center',
-                   bbox_to_anchor=(1, 0.5),
-                   ncol=(len(method_names)+1)//2)
+        # fig.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=(len(method_order)//2)+1)
+        fig.legend(loc='upper right', bbox_to_anchor=(1, 0.6))
 
     _save_or_show(savepath)
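
On the pending legend issue flagged in the commit message: after switching to method_order = np.unique(method_names), the legend should list each method once rather than once per experiment, and the bbox_to_anchor placement is still being tuned. The snippet below is a generic matplotlib sketch, not code from this repository, of one common way to handle this: collect the axis handles, keep one handle per distinct label, and anchor a figure-level legend outside the axes. The labels, coordinates and output file are illustrative only.

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    # example axis with repeated labels, as when the same method appears in several experiments
    for name in ['CC', 'SLD', 'CC']:
        ax.plot([0, 1], [0, 1], label=name)

    # keep one handle per distinct label (the dict preserves first-seen label order)
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    fig.legend(by_label.values(), by_label.keys(),
               loc='center left', bbox_to_anchor=(1, 0.5), ncol=1)

    fig.savefig('legend_demo.pdf', bbox_inches='tight')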