training times added to globar report
This commit is contained in:
parent
498fd8b050
commit
93dd6cb1c1
examples
|
@ -1,5 +1,7 @@
|
||||||
import pickle
|
import pickle
|
||||||
import os
|
import os
|
||||||
|
from time import time
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
@ -38,9 +40,17 @@ def show_results(result_path):
|
||||||
df = pd.read_csv(result_path+'.csv', sep='\t')
|
df = pd.read_csv(result_path+'.csv', sep='\t')
|
||||||
pd.set_option('display.max_columns', None)
|
pd.set_option('display.max_columns', None)
|
||||||
pd.set_option('display.max_rows', None)
|
pd.set_option('display.max_rows', None)
|
||||||
pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE"], margins=True)
|
pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE", "t_train"], margins=True)
|
||||||
print(pv)
|
print(pv)
|
||||||
|
|
||||||
|
def load_timings(result_path):
|
||||||
|
import pandas as pd
|
||||||
|
timings = defaultdict(lambda: {})
|
||||||
|
if not Path(result_path + '.csv').exists():
|
||||||
|
return timings
|
||||||
|
|
||||||
|
df = pd.read_csv(result_path+'.csv', sep='\t')
|
||||||
|
return timings | df.pivot_table(index='Dataset', columns='Method', values='t_train').to_dict()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
@ -53,8 +63,9 @@ if __name__ == '__main__':
|
||||||
os.makedirs(result_dir, exist_ok=True)
|
os.makedirs(result_dir, exist_ok=True)
|
||||||
|
|
||||||
global_result_path = f'{result_dir}/allmethods'
|
global_result_path = f'{result_dir}/allmethods'
|
||||||
|
timings = load_timings(global_result_path)
|
||||||
with open(global_result_path + '.csv', 'wt') as csv:
|
with open(global_result_path + '.csv', 'wt') as csv:
|
||||||
csv.write(f'Method\tDataset\tMAE\tMRAE\n')
|
csv.write(f'Method\tDataset\tMAE\tMRAE\tt_train\n')
|
||||||
|
|
||||||
for method_name, quantifier, param_grid in METHODS:
|
for method_name, quantifier, param_grid in METHODS:
|
||||||
|
|
||||||
|
@ -64,9 +75,6 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
for dataset in qp.datasets.UCI_MULTICLASS_DATASETS:
|
for dataset in qp.datasets.UCI_MULTICLASS_DATASETS:
|
||||||
|
|
||||||
if dataset in []:
|
|
||||||
continue
|
|
||||||
|
|
||||||
print('init', dataset)
|
print('init', dataset)
|
||||||
|
|
||||||
local_result_path = os.path.join(Path(global_result_path).parent, method_name + '_' + dataset + '.dataframe')
|
local_result_path = os.path.join(Path(global_result_path).parent, method_name + '_' + dataset + '.dataframe')
|
||||||
|
@ -88,7 +96,8 @@ if __name__ == '__main__':
|
||||||
modsel = GridSearchQ(
|
modsel = GridSearchQ(
|
||||||
quantifier, param_grid, protocol, refit=True, n_jobs=-1, verbose=1, error='mae'
|
quantifier, param_grid, protocol, refit=True, n_jobs=-1, verbose=1, error='mae'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
t_init = time()
|
||||||
try:
|
try:
|
||||||
modsel.fit(train)
|
modsel.fit(train)
|
||||||
|
|
||||||
|
@ -99,6 +108,8 @@ if __name__ == '__main__':
|
||||||
except:
|
except:
|
||||||
print('something went wrong... trying to fit the default model')
|
print('something went wrong... trying to fit the default model')
|
||||||
quantifier.fit(train)
|
quantifier.fit(train)
|
||||||
|
timings[method_name][dataset] = time() - t_init
|
||||||
|
|
||||||
|
|
||||||
protocol = UPP(test, repeats=n_bags_test)
|
protocol = UPP(test, repeats=n_bags_test)
|
||||||
report = qp.evaluation.evaluation_report(
|
report = qp.evaluation.evaluation_report(
|
||||||
|
@ -107,7 +118,7 @@ if __name__ == '__main__':
|
||||||
report.to_csv(local_result_path)
|
report.to_csv(local_result_path)
|
||||||
|
|
||||||
means = report.mean(numeric_only=True)
|
means = report.mean(numeric_only=True)
|
||||||
csv.write(f'{method_name}\t{dataset}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\n')
|
csv.write(f'{method_name}\t{dataset}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\t{timings[method_name][dataset]:.3f}\n')
|
||||||
csv.flush()
|
csv.flush()
|
||||||
|
|
||||||
show_results(global_result_path)
|
show_results(global_result_path)
|
Loading…
Reference in New Issue