QuaPy/Census/table.py

480 lines
16 KiB
Python

import os
import subprocess
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Union, List

import numpy as np
import pandas as pd
from scipy.stats import wilcoxon, ttest_ind_from_stats
@dataclass
class CellFormat:
    r"""Rendering options shared by the cells of a table.

    The field order below defines the positional order of the generated
    ``__init__`` and must therefore stay as it is.
    """

    # number of decimals printed for the mean
    mean_prec: int = 3
    # number of decimals printed for the standard deviation
    std_prec: int = 3
    # append " $\pm$ <std>" after the mean
    show_std: bool = True
    # drop the zero before the decimal point of the mean (e.g. "0.123" -> ".123")
    remove_zero: bool = False
    # append a \cellcolor{...} background command to each cell
    color: bool = True
    # tint used for the most extreme cells, rendered as "color!tone"
    maxtone: int = 50
class Cell:
    """A single table entry that accumulates raw values and renders them as LaTeX.

    Every cell is registered in a ``CellGroup`` (the cells it competes against);
    the group decides which cell is best, provides the p-value of the comparison
    against the best cell, and the background colour.
    """

    def __init__(self, format: 'CellFormat', group: 'CellGroup'):
        self.values = []
        self.format = format
        self.touch()
        self.group = group
        # registering lets the group rank this cell against its siblings
        self.group.register_cell(self)

    def __len__(self):
        return len(self.values)

    def mean(self):
        """Return the mean of the accumulated values (cached until the next append)."""
        if self.mean_ is None:
            self.mean_ = np.mean(self.values)
        return self.mean_

    def std(self):
        """Return the population (ddof=0) standard deviation of the values (cached)."""
        if self.std_ is None:
            self.std_ = np.std(self.values)
        return self.std_

    def touch(self):
        """Invalidate the cached mean/std; called whenever the values change."""
        self.mean_ = None
        self.std_ = None

    def append(self, v: Union[float, Iterable]):
        """Add one value, or every element of an iterable, to the cell."""
        if isinstance(v, Iterable):
            self.values.extend(v)
        else:
            # a scalar; previously iterables were also appended as a single
            # (list) element after being extended, corrupting the values
            self.values.append(v)
        self.touch()

    def isEmpty(self):
        return len(self) == 0

    def isBest(self):
        """True if this cell is (or ties in mean with) the best cell of its group."""
        best = self.group.best()
        if best is not None:
            return (best == self) or (np.isclose(best.mean(), self.mean()))
        return False

    def print_mean(self):
        """Return the mean formatted with ``format.mean_prec`` decimals ('' if empty)."""
        if self.isEmpty():
            return ''
        return f'{self.mean():.{self.format.mean_prec}f}'

    @staticmethod
    def _strip_leading_zero(number: str) -> str:
        """Turn '0.123' into '.123' and '-0.5' into '-.5'; leave anything else untouched."""
        # only the integer-part zero is removed; a blind replace('0.', '.')
        # would also turn '10.5' into '1.5'
        if number.startswith('0.'):
            return number[1:]
        if number.startswith('-0.'):
            return '-' + number[2:]
        return number

    def print(self):
        """Render the cell as a LaTeX fragment.

        The best cell of the group is set in bold.
        Any other cell gets a marker derived from the group's p-value against
        the best cell:
        - no marker when p < 0.005;
        - ``$^{\\dag}$`` when 0.005 <= p < 0.05;
        - ``$^{\\ddag}$`` when p >= 0.05.

        Optionally ``$\\pm$ std`` and a background colour command are added.
        """
        if self.isEmpty():
            return ''

        # mean
        # ---------------------------------------------------
        mean = self.print_mean()
        if self.format.remove_zero:
            mean = self._strip_leading_zero(mean)

        # std ?
        # ---------------------------------------------------
        if self.format.show_std:
            std = f' $\\pm$ {self.std():.{self.format.std_prec}f}'
        else:
            std = ''

        # bold or statistical test
        # ---------------------------------------------------
        if self.isBest():
            str_cell = f'\\textbf{{{mean}{std}}}'
        else:
            comp_symbol = ''
            pval = self.group.compare(self)
            if pval is not None:
                if 0.005 > pval:
                    comp_symbol = ''
                elif 0.05 > pval >= 0.005:
                    comp_symbol = '$^{\\dag}$'
                elif pval >= 0.05:
                    # superscripted, consistent with the \dag marker
                    comp_symbol = '$^{\\ddag}$'
            str_cell = f'{mean}{comp_symbol}{std}'

        # color ?
        # ---------------------------------------------------
        if self.format.color:
            str_cell += ' ' + self.group.color(self)
        return str_cell
class CellGroup:
    """A set of cells ranked, statistically compared and coloured together.

    The group picks the best/worst non-empty cell by mean, returns the p-value of
    a test between the best cell and another cell, and maps a cell's mean to a
    LaTeX ``\\cellcolor`` command.
    """

    def __init__(self, lower_is_better=True, stat_test='wilcoxon', color_mode='local', color_global_min=None, color_global_max=None):
        assert stat_test in ['wilcoxon', 'ttest', None], \
            f"unknown {stat_test=}, valid ones are wilcoxon, ttest, or None"
        assert color_mode in ['local', 'global'], \
            f"unknown {color_mode=}, valid ones are local and global"
        if (color_global_min is not None or color_global_max is not None) and color_mode == 'local':
            # the global bounds are ignored in 'local' mode
            print('warning: color_global_min and color_global_max are only considered when color_mode==global')
        self.cells = []
        self.lower_is_better = lower_is_better
        self.stat_test = stat_test
        self.color_mode = color_mode
        self.color_global_min = color_global_min
        self.color_global_max = color_global_max

    def register_cell(self, cell: 'Cell'):
        self.cells.append(cell)

    def non_empty_cells(self):
        return [c for c in self.cells if not c.isEmpty()]

    def max(self):
        """Return the non-empty cell with the highest mean, or None."""
        cells = self.non_empty_cells()
        if len(cells) > 0:
            return cells[np.argmax([c.mean() for c in cells])]
        return None

    def min(self):
        """Return the non-empty cell with the lowest mean, or None."""
        cells = self.non_empty_cells()
        if len(cells) > 0:
            return cells[np.argmin([c.mean() for c in cells])]
        return None

    def best(self) -> Union['Cell', None]:
        return self.min() if self.lower_is_better else self.max()

    def worst(self) -> Union['Cell', None]:
        return self.max() if self.lower_is_better else self.min()

    def isEmpty(self):
        return len(self.non_empty_cells()) == 0

    def compare(self, cell: 'Cell'):
        """Return the p-value of ``self.stat_test`` between the best cell and ``cell``.

        Returns None when the group or ``cell`` is empty, when ``stat_test`` is None,
        when the Wilcoxon test raises ValueError (e.g. samples of different length),
        or when the t-test lacks at least two values per sample.

        Raises:
            ValueError: if ``stat_test`` is not a known test.
        """
        best = self.best()
        if best is None:
            return None
        best_n = len(best)
        cell_n = len(cell)
        if best_n == 0 or cell_n == 0 or self.stat_test is None:
            return None
        if self.stat_test == 'wilcoxon':
            try:
                _, p_val = wilcoxon(best.values, cell.values)
            except ValueError:
                p_val = None
            return p_val
        if self.stat_test == 'ttest':
            if best_n < 2 or cell_n < 2:
                return None
            # ttest_ind_from_stats expects corrected (ddof=1) sample standard
            # deviations; Cell.std() is the population (ddof=0) one
            best_std = np.std(best.values, ddof=1)
            cell_std = np.std(cell.values, ddof=1)
            _, p_val = ttest_ind_from_stats(best.mean(), best_std, best_n, cell.mean(), cell_std, cell_n)
            return p_val
        raise ValueError(f'unknown statistical test {self.stat_test}')

    def color(self, cell: 'Cell'):
        """Return a ``\\cellcolor{red|green!tone}`` command for ``cell`` ('' if not colourable).

        In 'local' mode the mean is normalised between the group's best and worst
        means; in 'global' mode between ``color_global_min`` and ``color_global_max``.
        Better-than-midpoint values get green, worse ones red, with a tone that grows
        up to ``cell.format.maxtone`` at the extremes.
        """
        if cell.isEmpty():
            return ''
        if self.color_mode == 'local':
            best = self.best()
            worst = self.worst()
            if best is None or worst is None:
                return ''
            best_mean = best.mean()
            worst_mean = worst.mean()
            if best_mean == worst_mean:
                return ''
            maxval = max(best_mean, worst_mean)
            minval = min(best_mean, worst_mean)
        else:
            maxval = self.color_global_max
            minval = self.color_global_min
            if maxval is None or minval is None or maxval == minval:
                return ''

        # normalize val in [0,1]
        normval = (cell.mean() - minval) / (maxval - minval)
        if np.isnan(normval):
            return ''
        if self.lower_is_better:
            normval = 1 - normval
        normval = np.clip(normval, 0, 1)
        normval = normval * 2 - 1  # rescale to [-1,1]
        if normval < 0:
            color = 'red'
            tone = cell.format.maxtone * (-normval)
        else:
            color = 'green'
            tone = cell.format.maxtone * normval
        return f'\\cellcolor{{{color}!{int(tone)}}}'
class Table:
    """A (benchmark x method) results table that renders to LaTeX.

    Raw values are accumulated per (benchmark, method) in ``Cell`` objects. All
    cells of one benchmark share a ``CellGroup``, which decides the bold best
    entry, the significance markers and the cell colours. The table can be
    printed to stdout, or exported in several forms:
    - a ``tabular``;
    - a ``table`` float;
    - a full LaTeX document;
    - a compiled PDF.
    """

    def __init__(self,
                 name,
                 benchmarks=None,
                 methods=None,
                 format: 'CellFormat' = None,
                 lower_is_better=True,
                 stat_test='wilcoxon',
                 color_mode='local',
                 with_mean=True
                 ):
        self.name = name
        # copies: get() appends to these lists, which must not mutate the caller's lists
        self.benchmarks = [] if benchmarks is None else list(benchmarks)
        self.methods = [] if methods is None else list(methods)
        self.format = format if format is not None else CellFormat()
        self.lower_is_better = lower_is_better
        self.stat_test = stat_test
        self.color_mode = color_mode
        self.with_mean = with_mean
        # if True, a method lacking a result for some benchmark gets an empty
        # 'Ave.' entry; if False, its mean over the available benchmarks is shown
        self.only_full_mean = True
        if self.color_mode == 'global':
            # colours are scaled against the fixed range [0, 1]
            self.color_global_min = 0
            self.color_global_max = 1
        else:
            self.color_global_min = None
            self.color_global_max = None
        self.T = {}  # (benchmark index, method index) -> Cell
        self.groups = {}  # benchmark -> CellGroup shared by that benchmark's cells

    def add(self, benchmark, method, v):
        """Append a value (or an iterable of values) to the (benchmark, method) cell."""
        cell = self.get(benchmark, method)
        cell.append(v)

    def get_benchmarks(self):
        return self.benchmarks

    def get_methods(self):
        return self.methods

    def n_benchmarks(self):
        return len(self.benchmarks)

    def n_methods(self):
        return len(self.methods)

    def _new_group(self):
        return CellGroup(self.lower_is_better, self.stat_test, color_mode=self.color_mode,
                         color_global_max=self.color_global_max, color_global_min=self.color_global_min)

    def get(self, benchmark, method) -> 'Cell':
        """Return the cell for (benchmark, method).

        Unknown benchmarks/methods are registered, and the cell is created if missing.
        """
        if benchmark not in self.benchmarks:
            self.benchmarks.append(benchmark)
        if benchmark not in self.groups:
            self.groups[benchmark] = self._new_group()
        if method not in self.methods:
            self.methods.append(method)
        idx = (self.benchmarks.index(benchmark), self.methods.index(method))
        if idx not in self.T:
            self.T[idx] = Cell(self.format, group=self.groups[benchmark])
        return self.T[idx]

    def get_value(self, benchmark, method) -> float:
        return self.get(benchmark, method).mean()

    def get_benchmark(self, benchmark):
        """Return the non-empty cells of ``benchmark`` (in method order)."""
        cells = [self.get(benchmark, method=m) for m in self.get_methods()]
        return [c for c in cells if not c.isEmpty()]

    def get_method(self, method):
        """Return the non-empty cells of ``method`` (in benchmark order)."""
        cells = [self.get(benchmark=b, method=method) for b in self.get_benchmarks()]
        return [c for c in cells if not c.isEmpty()]

    def get_method_means(self, method_order):
        """Return one 'average' Cell per method in ``method_order``.

        Each cell holds the method's per-benchmark means, skipping empty cells
        and NaN means. The returned cells share a fresh CellGroup, so the
        averages are ranked and compared among themselves. When
        ``self.only_full_mean`` is True, a method with a missing benchmark gets
        an empty cell, rendered as ''.
        """
        mean_group = self._new_group()
        cells = []
        for method in method_order:
            method_mean = Cell(self.format, group=mean_group)
            bench_means = []
            for bench in self.get_benchmarks():
                cell = self.get(bench, method)
                # test emptiness first: np.mean([]) warns and returns nan
                mean_value = np.nan if cell.isEmpty() else cell.mean()
                if not np.isnan(mean_value):
                    bench_means.append(mean_value)
            incomplete = len(bench_means) < self.n_benchmarks()
            if not (incomplete and self.only_full_mean):
                method_mean.append(bench_means)
            cells.append(method_mean)
        return cells

    def get_benchmark_values(self, benchmark):
        return np.asarray([c.mean() for c in self.get_benchmark(benchmark)])

    def get_method_values(self, method):
        return np.asarray([c.mean() for c in self.get_method(method)])

    def all_mean(self):
        """Mean of the means of all non-empty cells."""
        values = [c.mean() for c in self.T.values() if not c.isEmpty()]
        return np.mean(values)

    def print(self):
        """Print the table of means (benchmarks as rows, methods as columns) to stdout."""
        data_dict = {'Benchmark': list(self.get_benchmarks())}
        for method in self.get_methods():
            data_dict[method] = [self.get(bench, method).print_mean() for bench in self.get_benchmarks()]
        df = pd.DataFrame(data_dict)
        # scoped options instead of pd.set_option, which changed pandas' global state
        with pd.option_context('display.max_columns', None, 'display.max_rows', None):
            print(df.to_string(index=False))

    def tabular(self, path=None, benchmark_replace=None, method_replace=None, benchmark_order=None, method_order=None, transpose=False):
        """Render the table as a LaTeX ``tabular`` (optionally written to ``path``).

        Benchmarks are rows and methods columns, or vice versa if ``transpose``.
        ``*_replace`` map keys to display names; ``*_order`` select and order them.
        With ``self.with_mean`` an 'Ave.' row (or column when transposed) is added.
        Returns the LaTeX string.
        """
        if benchmark_replace is None:
            benchmark_replace = {}
        if method_replace is None:
            method_replace = {}
        if benchmark_order is None:
            benchmark_order = self.get_benchmarks()
        if method_order is None:
            method_order = self.get_methods()

        if transpose:
            row_order, row_replace = method_order, method_replace
            col_order, col_replace = benchmark_order, benchmark_replace
        else:
            row_order, row_replace = benchmark_order, benchmark_replace
            col_order, col_replace = method_order, method_replace

        n_cols = len(col_order)
        add_mean_col = self.with_mean and transpose
        add_mean_row = self.with_mean and not transpose
        last_col_idx = n_cols + 2 if add_mean_col else n_cols + 1

        if self.with_mean:
            mean_cells = self.get_method_means(method_order)

        lines = []
        lines.append('\\begin{tabular}{|c' + '|c' * n_cols + ('||c' if add_mean_col else '') + '|}')
        lines.append(f'\\cline{{2-{last_col_idx}}}')
        header = '\\multicolumn{1}{c|}{} & '
        header += ' & '.join([col_replace.get(col, col) for col in col_order])
        if add_mean_col:
            header += ' & Ave.'
        header += ' \\\\\\hline'
        lines.append(header)

        for i, row in enumerate(row_order):
            line = row_replace.get(row, row) + ' & '
            line += ' & '.join([
                self.get(benchmark=col if transpose else row, method=row if transpose else col).print()
                for col in col_order
            ])
            if add_mean_col:
                line += ' & ' + mean_cells[i].print()
            line += ' \\\\\\hline'
            lines.append(line)

        if add_mean_row:
            lines.append('\\hline')
            line = 'Ave. & '
            line += ' & '.join([mean_cell.print() for mean_cell in mean_cells])
            line += ' \\\\\\hline'
            lines.append(line)

        lines.append('\\end{tabular}')
        tabular_tex = '\n'.join(lines)

        if path is not None:
            Path(path).parent.mkdir(parents=True, exist_ok=True)
            with open(path, 'wt', encoding='utf-8') as foo:
                foo.write(tabular_tex)

        return tabular_tex

    def table(self, tabular_path, benchmark_replace=None, method_replace=None, resizebox=True, caption=None, label=None, benchmark_order=None, method_order=None, transpose=False, input_dir='tables'):
        """Render the table as a LaTeX ``table`` float and return the string.

        If ``tabular_path`` is given, the tabular is written there and included
        via ``\\input{<input_dir>/<file name>}``. ``input_dir`` is the include
        path relative to the main document. Otherwise the tabular is inlined. The
        default caption is ``tabular_path`` (or the table name) with '_' escaped.
        """
        if benchmark_replace is None:
            benchmark_replace = {}
        if method_replace is None:
            method_replace = {}

        lines = []
        lines.append('\\begin{table}')
        lines.append('\\center')
        if resizebox:
            lines.append('\\resizebox{\\textwidth}{!}{%')

        tabular_str = self.tabular(tabular_path, benchmark_replace, method_replace, benchmark_order, method_order, transpose)
        if tabular_path is None:
            lines.append(tabular_str)
        else:
            # LaTeX expects forward slashes regardless of the host OS
            input_path = Path(input_dir, Path(tabular_path).name).as_posix()
            lines.append(f'\\input{{{input_path}}}')

        if resizebox:
            lines.append('}%')
        if caption is None:
            caption_source = str(self.name) if tabular_path is None else str(tabular_path)
            caption = caption_source.replace('_', '\\_')
        lines.append(f'\\caption{{{caption}}}')
        if label is not None:
            lines.append(f'\\label{{{label}}}')
        lines.append('\\end{table}')
        return '\n'.join(lines)

    def document(self, tex_path, tabular_dir='tables', *args, **kwargs):
        """Write a standalone LaTeX document containing only this table; return its source."""
        # positional forwarding (as in LatexPDF): passing tables= by keyword made any
        # extra positional argument collide with the 'tables' parameter
        return type(self).Document(tex_path, [self], tabular_dir, *args, **kwargs)

    def latexPDF(self, pdf_path, tabular_dir='tables', *args, **kwargs):
        """Build a PDF containing only this table (see ``LatexPDF``)."""
        return type(self).LatexPDF(pdf_path, [self], tabular_dir, *args, **kwargs)

    @classmethod
    def Document(cls, tex_path, tables: List['Table'], tabular_dir='tables', landscape=True, *args, **kwargs):
        """Write a LaTeX document with one ``table`` per page to ``tex_path``.

        Each tabular is written to ``<dir of tex_path>/<tabular_dir>/<name>_table.tex``
        and included from the document. Extra ``*args``/``**kwargs`` go to ``table()``.
        Returns the document source.
        """
        lines = [
            '\\documentclass[10pt,a4paper]{article}',
            '\\usepackage[utf8]{inputenc}',
            '\\usepackage{amsmath}',
            '\\usepackage{amsfonts}',
            '\\usepackage{amssymb}',
            '\\usepackage{graphicx}',
            '\\usepackage{xcolor}',
            '\\usepackage{colortbl}',
        ]
        if landscape:
            lines.append('\\usepackage[landscape]{geometry}')
        lines.append('')
        lines.append('\\begin{document}')

        tex_dir = Path(tex_path).parent
        # the \input path must match the directory the tabulars are written to;
        # an explicit input_dir from the caller still takes precedence
        table_kwargs = {'input_dir': tabular_dir, **kwargs}
        for table in tables:
            tabular_path = os.path.join(tex_dir, tabular_dir, table.name + '_table.tex')
            lines.append('')
            lines.append(table.table(tabular_path, *args, **table_kwargs))
            lines.append('\n\\newpage\n')
        lines.append('\\end{document}')

        document = '\n'.join(lines)
        tex_dir.mkdir(parents=True, exist_ok=True)
        # utf-8 to match the \usepackage[utf8]{inputenc} declaration above
        with open(tex_path, 'wt', encoding='utf-8') as foo:
            foo.write(document)
        return document

    @classmethod
    def LatexPDF(cls, pdf_path: str, tables: List['Table'], tabular_dir: str = 'tables', *args, **kwargs):
        """Generate the LaTeX document for ``tables`` and compile it with pdflatex.

        The .tex file is ``pdf_path`` with its suffix replaced. pdflatex runs in the
        pdf's directory, and the .aux/.log files it leaves are removed. Extra
        ``*args``/``**kwargs`` are forwarded to ``Document``.
        """
        assert pdf_path.endswith('.pdf'), f'{pdf_path=} does not seem a valid name for a pdf file'
        # with_suffix only touches the extension (str.replace rewrote every '.pdf')
        tex_path = str(Path(pdf_path).with_suffix('.tex'))
        cls.Document(tex_path, tables, tabular_dir, *args, **kwargs)

        workdir = Path(pdf_path).parent
        print('currently in', os.getcwd())
        print("[Tables Done] runing latex")
        try:
            # the argument list avoids shell interpretation of the file name, and
            # cwd= avoids changing (and possibly not restoring) the process cwd
            subprocess.run(['pdflatex', Path(tex_path).name], cwd=workdir)
        except FileNotFoundError:
            print('warning: pdflatex not found; the .tex file was generated but not compiled')
        basename = Path(tex_path).stem
        for ext in ('.aux', '.log'):
            (workdir / (basename + ext)).unlink(missing_ok=True)
        print('[Done]')