Merge pull request #3 from lorenzovolpi/report_rework

Report rework
This commit is contained in:
Lorenzo Volpi 2023-10-31 15:08:16 +01:00 committed by GitHub
commit 9c48c725e7
13 changed files with 574 additions and 402 deletions

1
.vscode/launch.json vendored
View File

@ -5,6 +5,7 @@
"version": "0.2.0", "version": "0.2.0",
"configurations": [ "configurations": [
{ {
"name": "main", "name": "main",
"type": "python", "type": "python",

View File

@ -2,17 +2,24 @@ debug_conf: &debug_conf
global: global:
METRICS: METRICS:
- acc - acc
DATASET_N_PREVS: 1 DATASET_N_PREVS: 5
DATASET_PREVS:
- 0.5
- 0.1
datasets: confs:
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
- DATASET_NAME: imdb - DATASET_NAME: imdb
plot_confs: plot_confs:
debug: debug:
PLOT_ESTIMATORS: PLOT_ESTIMATORS:
# - mul_sld_bcts - ref
- atc_mc
- atc_ne
PLOT_STDEV: true
debug_plus:
PLOT_ESTIMATORS:
- mul_sld_bcts
- mul_sld - mul_sld
- ref - ref
- atc_mc - atc_mc
@ -23,11 +30,15 @@ test_conf: &test_conf
METRICS: METRICS:
- acc - acc
- f1 - f1
DATASET_N_PREVS: 3 DATASET_N_PREVS: 2
DATASET_PREVS:
- 0.5
- 0.1
datasets: confs:
- DATASET_NAME: rcv1 # - DATASET_NAME: rcv1
DATASET_TARGET: CCAT # DATASET_TARGET: CCAT
- DATASET_NAME: imdb
plot_confs: plot_confs:
best_vs_atc: best_vs_atc:
@ -49,15 +60,15 @@ main_conf: &main_conf
- f1 - f1
DATASET_N_PREVS: 9 DATASET_N_PREVS: 9
datasets: confs:
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
confs_bck:
- DATASET_NAME: imdb - DATASET_NAME: imdb
datasets_bck:
- DATASET_NAME: rcv1 - DATASET_NAME: rcv1
DATASET_TARGET: GCAT DATASET_TARGET: GCAT
- DATASET_NAME: rcv1 - DATASET_NAME: rcv1
DATASET_TARGET: MCAT DATASET_TARGET: MCAT
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
plot_confs: plot_confs:
gs_vs_atc: gs_vs_atc:
@ -100,4 +111,4 @@ main_conf: &main_conf
- atc_ne - atc_ne
- doc_feat - doc_feat
exec: *main_conf exec: *debug_conf

12
poetry.lock generated
View File

@ -443,6 +443,16 @@ files = [
{file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"},
] ]
[[package]]
name = "logging"
version = "0.4.9.6"
description = "A logging module for Python"
optional = false
python-versions = "*"
files = [
{file = "logging-0.4.9.6.tar.gz", hash = "sha256:26f6b50773f085042d301085bd1bf5d9f3735704db9f37c1ce6d8b85c38f2417"},
]
[[package]] [[package]]
name = "markupsafe" name = "markupsafe"
version = "2.1.3" version = "2.1.3"
@ -1261,4 +1271,4 @@ test = ["pytest", "pytest-cov"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "c98b7510ac055b667340b52e1b0b0777370e68d325d3149cb1fef42b6f1ec50a" content-hash = "54d9922f6d48a46f554a6b350ce09d668a88755efb1fbf295f8f8a0a411bdef2"

View File

@ -11,6 +11,7 @@ quapy = "^0.1.7"
pandas = "^2.0.3" pandas = "^2.0.3"
jinja2 = "^3.1.2" jinja2 = "^3.1.2"
pyyaml = "^6.0.1" pyyaml = "^6.0.1"
logging = "^0.4.9.6"
[tool.poetry.scripts] [tool.poetry.scripts]
main = "quacc.main:main" main = "quacc.main:main"

View File

@ -39,14 +39,13 @@ class Dataset:
self._target = target self._target = target
self.prevs = None self.prevs = None
self.n_prevs = n_prevalences
if prevs is not None: if prevs is not None:
prevs = np.unique([p for p in prevs if p > 0.0 and p < 1.0]) prevs = np.unique([p for p in prevs if p > 0.0 and p < 1.0])
if prevs.shape[0] > 0: if prevs.shape[0] > 0:
self.prevs = np.sort(prevs) self.prevs = np.sort(prevs)
self.n_prevs = self.prevs.shape[0] self.n_prevs = self.prevs.shape[0]
self.n_prevs = n_prevalences
def __spambase(self): def __spambase(self):
return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
@ -88,7 +87,7 @@ class Dataset:
return DatasetSample(train, val, test) return DatasetSample(train, val, test)
def get(self) -> List[DatasetSample]: def get(self) -> List[DatasetSample]:
all_train, test = { (all_train, test) = {
"spambase": self.__spambase, "spambase": self.__spambase,
"imdb": self.__imdb, "imdb": self.__imdb,
"rcv1": self.__rcv1, "rcv1": self.__rcv1,
@ -108,7 +107,7 @@ class Dataset:
at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs) at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevs)
datasets = [] datasets = []
for p in prevs: for p in 1.0 - prevs:
all_train_sampled = all_train.sampling(at_size, p, random_state=0) all_train_sampled = all_train.sampling(at_size, p, random_state=0)
train, validation = all_train_sampled.split_stratified( train, validation = all_train_sampled.split_stratified(
train_prop=TRAIN_VAL_PROP, random_state=0 train_prop=TRAIN_VAL_PROP, random_state=0
@ -122,10 +121,11 @@ class Dataset:
@property @property
def name(self): def name(self):
if self._name == "rcv1": return (
return f"{self._name}_{self._target}" f"{self._name}_{self._target}_{self.n_prevs}prevs"
else: if self._name == "rcv1"
return self._name else f"{self._name}_{self.n_prevs}prevs"
)
# >>> fetch_rcv1().target_names # >>> fetch_rcv1().target_names

View File

@ -1,46 +1,53 @@
import yaml import collections as C
import copy
from typing import Any
defalut_env = { import yaml
"DATASET_NAME": "rcv1",
"DATASET_TARGET": "CCAT",
"METRICS": ["acc", "f1"],
"COMP_ESTIMATORS": [],
"PLOT_ESTIMATORS": [],
"PLOT_STDEV": False,
"DATASET_N_PREVS": 9,
"DATASET_PREVS": None,
"OUT_DIR_NAME": "output",
"OUT_DIR": None,
"PLOT_DIR_NAME": "plot",
"PLOT_OUT_DIR": None,
"DATASET_DIR_UPDATE": False,
"PROTOCOL_N_PREVS": 21,
"PROTOCOL_REPEATS": 100,
"SAMPLE_SIZE": 1000,
}
class environ: class environ:
_instance = None _instance = None
_default_env = {
"DATASET_NAME": None,
"DATASET_TARGET": None,
"METRICS": [],
"COMP_ESTIMATORS": [],
"DATASET_N_PREVS": 9,
"DATASET_PREVS": None,
"OUT_DIR_NAME": "output",
"OUT_DIR": None,
"PLOT_DIR_NAME": "plot",
"PLOT_OUT_DIR": None,
"DATASET_DIR_UPDATE": False,
"PROTOCOL_N_PREVS": 21,
"PROTOCOL_REPEATS": 100,
"SAMPLE_SIZE": 1000,
"PLOT_ESTIMATORS": [],
"PLOT_STDEV": False,
}
_keys = list(_default_env.keys())
def __init__(self, **kwargs): def __init__(self):
self.exec = [] self.exec = []
self.confs = [] self.confs = []
self._default = kwargs
self.__setdict(kwargs)
self.load_conf() self.load_conf()
self._stack = C.deque([self.__getdict()])
def __setdict(self, d): def __setdict(self, d):
for k, v in d.items(): for k, v in d.items():
self.__setattr__(k, v) super().__setattr__(k, v)
if len(self.PLOT_ESTIMATORS) == 0:
self.PLOT_ESTIMATORS = self.COMP_ESTIMATORS
def __class_getitem__(cls, k): def __getdict(self):
env = cls.get() return {k: self.__getattribute__(k) for k in environ._keys}
return env.__getattribute__(k)
def __setattr__(self, __name: str, __value: Any) -> None:
if __name in environ._keys:
self._stack[-1][__name] = __value
super().__setattr__(__name, __value)
def load_conf(self): def load_conf(self):
self.__setdict(environ._default_env)
with open("conf.yaml", "r") as f: with open("conf.yaml", "r") as f:
confs = yaml.safe_load(f)["exec"] confs = yaml.safe_load(f)["exec"]
@ -50,31 +57,30 @@ class environ:
_estimators = _estimators.union(set(pc["PLOT_ESTIMATORS"])) _estimators = _estimators.union(set(pc["PLOT_ESTIMATORS"]))
_global["COMP_ESTIMATORS"] = list(_estimators) _global["COMP_ESTIMATORS"] = list(_estimators)
self.__setdict(_global)
self.confs = confs["confs"]
self.plot_confs = confs["plot_confs"] self.plot_confs = confs["plot_confs"]
for dataset in confs["datasets"]:
self.confs.append(_global | dataset)
def get_confs(self): def get_confs(self):
self._stack.append(None)
for _conf in self.confs: for _conf in self.confs:
self.__setdict(self._default) self._stack.pop()
self.__setdict(self._stack[-1])
self.__setdict(_conf) self.__setdict(_conf)
if "DATASET_TARGET" not in _conf: self._stack.append(self.__getdict())
self.DATASET_TARGET = None
name = self.DATASET_NAME yield copy.deepcopy(self._stack[-1])
if self.DATASET_TARGET is not None:
name += f"_{self.DATASET_TARGET}"
name += f"_{self.DATASET_N_PREVS}prevs"
yield name self._stack.pop()
def get_plot_confs(self): def get_plot_confs(self):
self._stack.append(None)
for k, pc in self.plot_confs.items(): for k, pc in self.plot_confs.items():
if "PLOT_ESTIMATORS" in pc: self._stack.pop()
self.PLOT_ESTIMATORS = pc["PLOT_ESTIMATORS"] self.__setdict(self._stack[-1])
if "PLOT_STDEV" in pc: self.__setdict(pc)
self.PLOT_STDEV = pc["PLOT_STDEV"] self._stack.append(self.__getdict())
name = self.DATASET_NAME name = self.DATASET_NAME
if self.DATASET_TARGET is not None: if self.DATASET_TARGET is not None:
@ -82,5 +88,31 @@ class environ:
name += f"_{k}" name += f"_{k}"
yield name yield name
self._stack.pop()
env = environ(**defalut_env) @property
def current(self):
return copy.deepcopy(self.__getdict())
env = environ()
if __name__ == "__main__":
stack = C.deque()
stack.append(-1)
def __gen(stack: C.deque):
stack.append(None)
for i in range(5):
stack.pop()
stack.append(i)
yield stack[-1]
stack.pop()
print(stack)
for i in __gen(stack):
print(stack, i)
print(stack)

View File

@ -86,6 +86,7 @@ def atc_mc(
protocol: AbstractStochasticSeededProtocol, protocol: AbstractStochasticSeededProtocol,
predict_method="predict_proba", predict_method="predict_proba",
): ):
"""garg"""
c_model_predict = getattr(c_model, predict_method) c_model_predict = getattr(c_model, predict_method)
## Load ID validation data probs and labels ## Load ID validation data probs and labels
@ -124,6 +125,7 @@ def atc_ne(
protocol: AbstractStochasticSeededProtocol, protocol: AbstractStochasticSeededProtocol,
predict_method="predict_proba", predict_method="predict_proba",
): ):
"""garg"""
c_model_predict = getattr(c_model, predict_method) c_model_predict = getattr(c_model, predict_method)
## Load ID validation data probs and labels ## Load ID validation data probs and labels

View File

@ -1,6 +1,6 @@
import multiprocessing import multiprocessing
import time import time
import traceback from traceback import print_exception as traceback
from typing import List from typing import List
import pandas as pd import pandas as pd
@ -11,11 +11,10 @@ from quacc.environment import env
from quacc.evaluation import baseline, method from quacc.evaluation import baseline, method
from quacc.evaluation.report import CompReport, DatasetReport, EvaluationReport from quacc.evaluation.report import CompReport, DatasetReport, EvaluationReport
from quacc.evaluation.worker import estimate_worker from quacc.evaluation.worker import estimate_worker
from quacc.logging import Logger from quacc.logger import Logger
pd.set_option("display.float_format", "{:.4f}".format) pd.set_option("display.float_format", "{:.4f}".format)
qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
log = Logger.logger()
class CompEstimator: class CompEstimator:
@ -43,6 +42,7 @@ CE = CompEstimator
def evaluate_comparison( def evaluate_comparison(
dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"] dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"]
) -> EvaluationReport: ) -> EvaluationReport:
log = Logger.logger()
# with multiprocessing.Pool(1) as pool: # with multiprocessing.Pool(1) as pool:
with multiprocessing.Pool(len(estimators)) as pool: with multiprocessing.Pool(len(estimators)) as pool:
dr = DatasetReport(dataset.name) dr = DatasetReport(dataset.name)

View File

@ -9,68 +9,48 @@ from quacc.environment import env
from quacc.utils import fmt_line_md from quacc.utils import fmt_line_md
def _get_metric(metric: str):
return slice(None) if metric is None else metric
def _get_estimators(estimators: List[str], cols: np.ndarray):
return slice(None) if estimators is None else cols[np.in1d(cols, estimators)]
class EvaluationReport: class EvaluationReport:
def __init__(self, name=None): def __init__(self, name=None):
self._prevs = [] self.data: pd.DataFrame = None
self._dict = {}
self.fit_score = None self.fit_score = None
self.name = name if name is not None else "default" self.name = name if name is not None else "default"
def append_row(self, basep: np.ndarray | Tuple, **row): def append_row(self, basep: np.ndarray | Tuple, **row):
bp = basep[1] bp = basep[1]
self._prevs.append(bp) _keys, _values = zip(*row.items())
for k, v in row.items(): # _keys = list(row.keys())
if k not in self._dict: # _values = list(row.values())
self._dict[k] = {}
if bp not in self._dict[k]: if self.data is None:
self._dict[k][bp] = [] _idx = 0
self._dict[k][bp] = np.append(self._dict[k][bp], [v]) self.data = pd.DataFrame(
{k: [v] for k, v in row.items()},
index=pd.MultiIndex.from_tuples([(bp, _idx)]),
columns=_keys,
)
return
_idx = len(self.data.loc[(bp,), :]) if (bp,) in self.data.index else 0
not_in_data = np.setdiff1d(list(row.keys()), self.data.columns.unique(0))
self.data.loc[:, not_in_data] = np.nan
self.data.loc[(bp, _idx), :] = row
return
@property @property
def columns(self): def columns(self) -> np.ndarray:
return self._dict.keys() return self.data.columns.unique(0)
@property @property
def prevs(self): def prevs(self):
return np.sort(np.unique([list(self._dict[_k].keys()) for _k in self._dict])) return np.sort(self.data.index.unique(0))
# def group_by_prevs(self, metric: str = None, estimators: List[str] = None):
# if self._g_dict is None:
# self._g_prevs = []
# self._g_dict = {k: [] for k in self._dict.keys()}
# for col, vals in self._dict.items():
# col_grouped = {}
# for bp, v in zip(self._prevs, vals):
# if bp not in col_grouped:
# col_grouped[bp] = []
# col_grouped[bp].append(v)
# self._g_dict[col] = [
# vs
# for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1])
# ]
# self._g_prevs = sorted(
# [(p0, p1) for [p0, p1] in np.unique(self._prevs, axis=0).tolist()],
# key=lambda bp: bp[1],
# )
# fg_dict = _filter_dict(self._g_dict, metric, estimators)
# return self._g_prevs, fg_dict
# def merge(self, other):
# if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)):
# raise ValueError("other has not same base prevalences of self")
# inters_keys = set(self._dict.keys()).intersection(set(other._dict.keys()))
# if len(inters_keys) > 0:
# raise ValueError(f"self and other have matching keys {str(inters_keys)}.")
# report = EvaluationReport()
# report._prevs = self._prevs
# report._dict = self._dict | other._dict
# return report
class CompReport: class CompReport:
@ -82,22 +62,16 @@ class CompReport:
valid_prev=None, valid_prev=None,
times=None, times=None,
): ):
all_prevs = np.array([er.prevs for er in reports]) self._data = (
if not np.all(all_prevs == all_prevs[0, :], axis=0).all(): pd.concat(
raise ValueError( [er.data for er in reports],
"Not all evaluation reports have the same base prevalences" keys=[er.name for er in reports],
axis=1,
) )
uq_names, name_c = np.unique([er.name for er in reports], return_counts=True) .swaplevel(0, 1, axis=1)
if np.sum(name_c) > uq_names.shape[0]: .sort_index(axis=1, level=0, sort_remaining=False)
_matching = uq_names[[c > 1 for c in name_c]] .sort_index(axis=0, level=0)
raise ValueError( )
f"Evaluation reports have matching names: {_matching.tolist()}."
)
all_dicts = [{(k, er.name): v for k, v in er._dict.items()} for er in reports]
self._dict = {}
for d in all_dicts:
self._dict = self._dict | d
self.fit_scores = { self.fit_scores = {
er.name: er.fit_score for er in reports if er.fit_score is not None er.name: er.fit_score for er in reports if er.fit_score is not None
@ -107,177 +81,196 @@ class CompReport:
self.times = times self.times = times
@property @property
def prevs(self): def prevs(self) -> np.ndarray:
return np.sort(np.unique([list(self._dict[_k].keys()) for _k in self._dict])) return np.sort(self._data.index.unique(0))
@property @property
def cprevs(self): def np_prevs(self) -> np.ndarray:
return np.around([(1.0 - p, p) for p in self.prevs], decimals=2) return np.around([(1.0 - p, p) for p in self.prevs], decimals=2)
def data(self, metric: str = None, estimators: List[str] = None) -> dict: def data(self, metric: str = None, estimators: List[str] = None) -> pd.DataFrame:
f_dict = self._dict.copy() _metric = _get_metric(metric)
if metric is not None: _estimators = _get_estimators(estimators, self._data.columns.unique(1))
f_dict = {(c0, c1): ls for ((c0, c1), ls) in f_dict.items() if c0 == metric} f_data: pd.DataFrame = self._data.copy().loc[:, (_metric, _estimators)]
if estimators is not None:
f_dict = {
(c0, c1): ls for ((c0, c1), ls) in f_dict.items() if c1 in estimators
}
if (metric, estimators) != (None, None):
f_dict = {c1: ls for ((c0, c1), ls) in f_dict.items()}
return f_dict if len(f_data.columns.unique(0)) == 1:
f_data = f_data.droplevel(level=0, axis=1)
def group_by_shift(self, metric: str = None, estimators: List[str] = None): return f_data
f_dict = self.data(metric=metric, estimators=estimators)
shift_prevs = np.around( def shift_data(
np.absolute(self.prevs - self.train_prev[1]), decimals=2 self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame:
shift_idx_0 = np.around(
np.abs(
self._data.index.get_level_values(0).to_numpy() - self.train_prev[1]
),
decimals=2,
) )
shift_dict = {col: {sp: [] for sp in shift_prevs} for col in f_dict.keys()}
for col, vals in f_dict.items():
for sp, bp in zip(shift_prevs, self.prevs):
shift_dict[col][sp] = np.concatenate(
[shift_dict[col][sp], f_dict[col][bp]]
)
return np.sort(np.unique(shift_prevs)), shift_dict shift_idx_1 = np.empty(shape=shift_idx_0.shape, dtype="<i4")
for _id in np.unique(shift_idx_0):
_wh = np.where(shift_idx_0 == _id)[0]
shift_idx_1[_wh] = np.arange(_wh.shape[0], dtype="<i4")
def avg_by_prevs(self, metric: str = None, estimators: List[str] = None): shift_data = self._data.copy()
shift_data.index = pd.MultiIndex.from_arrays([shift_idx_0, shift_idx_1])
shift_data.sort_index(axis=0, level=0)
_metric = _get_metric(metric)
_estimators = _get_estimators(estimators, shift_data.columns.unique(1))
shift_data: pd.DataFrame = shift_data.loc[:, (_metric, _estimators)]
if len(shift_data.columns.unique(0)) == 1:
shift_data = shift_data.droplevel(level=0, axis=1)
return shift_data
def avg_by_prevs(
self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame:
f_dict = self.data(metric=metric, estimators=estimators) f_dict = self.data(metric=metric, estimators=estimators)
return { return f_dict.groupby(level=0).mean()
col: np.array([np.mean(vals[bp]) for bp in self.prevs])
for col, vals in f_dict.items()
}
def stdev_by_prevs(self, metric: str = None, estimators: List[str] = None): def stdev_by_prevs(
self, metric: str = None, estimators: List[str] = None
) -> pd.DataFrame:
f_dict = self.data(metric=metric, estimators=estimators) f_dict = self.data(metric=metric, estimators=estimators)
return { return f_dict.groupby(level=0).std()
col: np.array([np.std(vals[bp]) for bp in self.prevs])
for col, vals in f_dict.items()
}
def avg_all(self, metric: str = None, estimators: List[str] = None): def table(self, metric: str = None, estimators: List[str] = None) -> pd.DataFrame:
f_dict = self.data(metric=metric, estimators=estimators) f_data = self.data(metric=metric, estimators=estimators)
return { avg_p = f_data.groupby(level=0).mean()
col: [np.mean(np.concatenate(list(vals.values())))] avg_p.loc["avg", :] = f_data.mean()
for col, vals in f_dict.items() return avg_p
}
def get_dataframe(self, metric="acc", estimators=None):
avg_dict = self.avg_by_prevs(metric=metric, estimators=estimators)
all_dict = self.avg_all(metric=metric, estimators=estimators)
for col in avg_dict.keys():
avg_dict[col] = np.append(avg_dict[col], all_dict[col])
return pd.DataFrame(
avg_dict,
index=self.prevs.tolist() + ["tot"],
columns=avg_dict.keys(),
)
def get_plots( def get_plots(
self, self, mode="delta", metric="acc", estimators=None, conf="default", stdev=False
modes=["delta", "diagonal", "shift"], ) -> List[Tuple[str, Path]]:
metric="acc", if mode == "delta":
estimators=None, avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
conf="default", return plot.plot_delta(
stdev=False, base_prevs=self.np_prevs,
) -> Path: columns=avg_data.columns.to_numpy(),
pps = [] data=avg_data.T.to_numpy(),
for mode in modes: metric=metric,
pp = [] name=conf,
if mode == "delta": train_prev=self.train_prev,
f_dict = self.avg_by_prevs(metric=metric, estimators=estimators) )
_pp0 = plot.plot_delta( elif mode == "delta_stdev":
self.cprevs, avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
f_dict, st_data = self.stdev_by_prevs(metric=metric, estimators=estimators)
metric=metric, return plot.plot_delta(
name=conf, base_prevs=self.np_prevs,
train_prev=self.train_prev, columns=avg_data.columns.to_numpy(),
fit_scores=self.fit_scores, data=avg_data.T.to_numpy(),
) metric=metric,
pp = [(mode, _pp0)] name=conf,
if stdev: train_prev=self.train_prev,
fs_dict = self.stdev_by_prevs(metric=metric, estimators=estimators) stdevs=st_data.T.to_numpy(),
_pp1 = plot.plot_delta( )
self.cprevs, elif mode == "diagonal":
f_dict, f_data = self.data(metric=metric + "_score", estimators=estimators)
metric=metric, ref: pd.Series = f_data.loc[:, "ref"]
name=conf, f_data.drop(columns=["ref"], inplace=True)
train_prev=self.train_prev, return plot.plot_diagonal(
fit_scores=self.fit_scores, reference=ref.to_numpy(),
stdevs=fs_dict, columns=f_data.columns.to_numpy(),
) data=f_data.T.to_numpy(),
pp.append((f"{mode}_stdev", _pp1)) metric=metric,
elif mode == "diagonal": name=conf,
f_dict = { train_prev=self.train_prev,
col: np.concatenate([vals[bp] for bp in self.prevs]) )
for col, vals in self.data( elif mode == "shift":
metric=metric + "_score", estimators=estimators shift_data = (
).items() self.shift_data(metric=metric, estimators=estimators)
} .groupby(level=0)
reference = f_dict["ref"] .mean()
f_dict = {k: v for k, v in f_dict.items() if k != "ref"} )
_pp0 = plot.plot_diagonal( shift_prevs = np.around(
reference, [(1.0 - p, p) for p in np.sort(shift_data.index.unique(0))],
f_dict, decimals=2,
metric=metric, )
name=conf, return plot.plot_shift(
train_prev=self.train_prev, shift_prevs=shift_prevs,
) columns=shift_data.columns.to_numpy(),
pp = [(mode, _pp0)] data=shift_data.T.to_numpy(),
metric=metric,
name=conf,
train_prev=self.train_prev,
)
elif mode == "shift": def to_md(self, conf="default", metric="acc", estimators=None, stdev=False) -> str:
s_prevs, s_dict = self.group_by_shift(
metric=metric, estimators=estimators
)
_pp0 = plot.plot_shift(
np.around([(1.0 - p, p) for p in s_prevs], decimals=2),
{
col: np.array([np.mean(vals[sp]) for sp in s_prevs])
for col, vals in s_dict.items()
},
metric=metric,
name=conf,
train_prev=self.train_prev,
fit_scores=self.fit_scores,
)
pp = [(mode, _pp0)]
pps.extend(pp)
return pps
def to_md(self, conf="default", metric="acc", estimators=None, stdev=False):
res = f"## {int(np.around(self.train_prev, decimals=2)[1]*100)}% positives\n" res = f"## {int(np.around(self.train_prev, decimals=2)[1]*100)}% positives\n"
res += fmt_line_md(f"train: {str(self.train_prev)}") res += fmt_line_md(f"train: {str(self.train_prev)}")
res += fmt_line_md(f"validation: {str(self.valid_prev)}") res += fmt_line_md(f"validation: {str(self.valid_prev)}")
for k, v in self.times.items(): for k, v in self.times.items():
res += fmt_line_md(f"{k}: {v:.3f}s") res += fmt_line_md(f"{k}: {v:.3f}s")
res += "\n" res += "\n"
res += ( res += self.table(metric=metric, estimators=estimators).to_html() + "\n\n"
self.get_dataframe(metric=metric, estimators=estimators).to_html() + "\n\n"
) plot_modes = np.array(["delta", "diagonal", "shift"], dtype="object")
plot_modes = ["delta", "diagonal", "shift"] if stdev:
for mode, op in self.get_plots( whd = np.where(plot_modes == "delta")[0]
modes=plot_modes, if len(whd) > 0:
metric=metric, plot_modes = np.insert(plot_modes, whd + 1, "delta_stdev")
estimators=estimators, for mode in plot_modes:
conf=conf, op = self.get_plots(
stdev=stdev, mode=mode,
): metric=metric,
estimators=estimators,
conf=conf,
stdev=stdev,
)
res += f"![plot_{mode}]({op.relative_to(env.OUT_DIR).as_posix()})\n" res += f"![plot_{mode}]({op.relative_to(env.OUT_DIR).as_posix()})\n"
return res return res
class DatasetReport: class DatasetReport:
def __init__(self, name): def __init__(self, name, crs=None):
self.name = name self.name = name
self._dict = None self.crs: List[CompReport] = [] if crs is None else crs
self.crs: List[CompReport] = []
@property def data(self, metric: str = None, estimators: str = None) -> pd.DataFrame:
def cprevs(self): def _cr_train_prev(cr: CompReport):
return np.around([(1.0 - p, p) for p in self.prevs], decimals=2) return cr.train_prev[1]
def _cr_data(cr: CompReport):
return cr.data(metric, estimators)
_crs_sorted = sorted(
[(_cr_train_prev(cr), _cr_data(cr)) for cr in self.crs],
key=lambda cr: len(cr[1].columns),
reverse=True,
)
_crs_train, _crs_data = zip(*_crs_sorted)
_data = pd.concat(_crs_data, axis=0, keys=_crs_train)
_data = _data.sort_index(axis=0, level=0)
return _data
def shift_data(self, metric: str = None, estimators: str = None) -> pd.DataFrame:
_shift_data: pd.DataFrame = pd.concat(
sorted(
[cr.shift_data(metric, estimators) for cr in self.crs],
key=lambda d: len(d.columns),
reverse=True,
),
axis=0,
)
shift_idx_0 = _shift_data.index.get_level_values(0)
shift_idx_1 = np.empty(shape=shift_idx_0.shape, dtype="<i4")
for _id in np.unique(shift_idx_0):
_wh = np.where(shift_idx_0 == _id)[0]
shift_idx_1[_wh] = np.arange(_wh.shape[0])
_shift_data.index = pd.MultiIndex.from_arrays([shift_idx_0, shift_idx_1])
_shift_data = _shift_data.sort_index(axis=0, level=0)
return _shift_data
def add(self, cr: CompReport): def add(self, cr: CompReport):
if cr is None: if cr is None:
@ -285,57 +278,11 @@ class DatasetReport:
self.crs.append(cr) self.crs.append(cr)
if self._dict is None: def __add__(self, cr: CompReport):
self.prevs = cr.prevs if cr is None:
self._dict = {
col: {bp: vals[bp] for bp in self.prevs}
for col, vals in cr.data().items()
}
self.s_prevs, self.s_dict = cr.group_by_shift()
self.fit_scores = {k: [score] for k, score in cr.fit_scores.items()}
return return
cr_dict = cr.data() return DatasetReport(self.name, crs=self.crs + [cr])
both_prevs = np.array([self.prevs, cr.prevs])
if not np.all(both_prevs == both_prevs[0, :]).all():
raise ValueError("Comp report has incompatible base prevalences")
for col, vals in cr_dict.items():
if col not in self._dict:
self._dict[col] = {}
for bp in self.prevs:
if bp not in self._dict[col]:
self._dict[col][bp] = []
self._dict[col][bp] = np.concatenate(
[self._dict[col][bp], cr_dict[col][bp]]
)
cr_s_prevs, cr_s_dict = cr.group_by_shift()
self.s_prevs = np.sort(np.unique(np.concatenate([self.s_prevs, cr_s_prevs])))
for col, vals in cr_s_dict.items():
if col not in self.s_dict:
self.s_dict[col] = {}
for sp in cr_s_prevs:
if sp not in self.s_dict[col]:
self.s_dict[col][sp] = []
self.s_dict[col][sp] = np.concatenate(
[self.s_dict[col][sp], cr_s_dict[col][sp]]
)
for sp in self.s_prevs:
for col, vals in self.s_dict.items():
if sp not in vals:
vals[sp] = []
for k, score in cr.fit_scores.items():
if k not in self.fit_scores:
self.fit_scores[k] = []
self.fit_scores[k].append(score)
def __add__(self, cr: CompReport):
self.add(cr)
return self
def __iadd__(self, cr: CompReport): def __iadd__(self, cr: CompReport):
self.add(cr) self.add(cr)
@ -346,70 +293,50 @@ class DatasetReport:
for cr in self.crs: for cr in self.crs:
res += f"{cr.to_md(conf, metric=metric, estimators=estimators, stdev=stdev)}\n\n" res += f"{cr.to_md(conf, metric=metric, estimators=estimators, stdev=stdev)}\n\n"
f_dict = { _data = self.data(metric=metric, estimators=estimators)
c1: v _shift_data = self.shift_data(metric=metric, estimators=estimators)
for ((c0, c1), v) in self._dict.items()
if c0 == metric and c1 in estimators avg_x_test = _data.groupby(level=1).mean()
} prevs_x_test = np.sort(avg_x_test.index.unique(0))
s_avg_dict = { stdev_x_test = _data.groupby(level=1).std() if stdev else None
col: np.array([np.mean(vals[sp]) for sp in self.s_prevs]) avg_x_test_tbl = _data.groupby(level=1).mean()
for col, vals in { avg_x_test_tbl.loc["avg", :] = _data.mean()
c1: v
for ((c0, c1), v) in self.s_dict.items() avg_x_shift = _shift_data.groupby(level=0).mean()
if c0 == metric and c1 in estimators prevs_x_shift = np.sort(avg_x_shift.index.unique(0))
}.items()
}
avg_dict = {
col: np.array([np.mean(vals[bp]) for bp in self.prevs])
for col, vals in f_dict.items()
}
if stdev:
stdev_dict = {
col: np.array([np.std(vals[bp]) for bp in self.prevs])
for col, vals in f_dict.items()
}
all_dict = {
col: [np.mean(np.concatenate(list(vals.values())))]
for col, vals in f_dict.items()
}
df = pd.DataFrame(
{col: np.append(avg_dict[col], val) for col, val in all_dict.items()},
index=self.prevs.tolist() + ["tot"],
columns=all_dict.keys(),
)
res += "## avg\n" res += "## avg\n"
res += df.to_html() + "\n\n" res += avg_x_test_tbl.to_html() + "\n\n"
delta_op = plot.plot_delta( delta_op = plot.plot_delta(
np.around([(1.0 - p, p) for p in self.prevs], decimals=2), base_prevs=np.around([(1.0 - p, p) for p in prevs_x_test], decimals=2),
avg_dict, columns=avg_x_test.columns.to_numpy(),
data=avg_x_test.T.to_numpy(),
metric=metric, metric=metric,
name=conf, name=conf,
train_prev=None, train_prev=None,
fit_scores={k: np.mean(vals) for k, vals in self.fit_scores.items()},
) )
res += f"![plot_delta]({delta_op.relative_to(env.OUT_DIR).as_posix()})\n" res += f"![plot_delta]({delta_op.relative_to(env.OUT_DIR).as_posix()})\n"
if stdev: if stdev:
delta_stdev_op = plot.plot_delta( delta_stdev_op = plot.plot_delta(
np.around([(1.0 - p, p) for p in self.prevs], decimals=2), base_prevs=np.around([(1.0 - p, p) for p in prevs_x_test], decimals=2),
avg_dict, columns=avg_x_test.columns.to_numpy(),
data=avg_x_test.T.to_numpy(),
metric=metric, metric=metric,
name=conf, name=conf,
train_prev=None, train_prev=None,
fit_scores={k: np.mean(vals) for k, vals in self.fit_scores.items()}, stdevs=stdev_x_test.T.to_numpy(),
stdevs=stdev_dict,
) )
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(env.OUT_DIR).as_posix()})\n" res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(env.OUT_DIR).as_posix()})\n"
shift_op = plot.plot_shift( shift_op = plot.plot_shift(
np.around([(1.0 - p, p) for p in self.s_prevs], decimals=2), shift_prevs=np.around([(1.0 - p, p) for p in prevs_x_shift], decimals=2),
s_avg_dict, columns=avg_x_shift.columns.to_numpy(),
data=avg_x_shift.T.to_numpy(),
metric=metric, metric=metric,
name=conf, name=conf,
train_prev=None, train_prev=None,
fit_scores={k: np.mean(vals) for k, vals in self.fit_scores.items()},
) )
res += f"![plot_shift]({shift_op.relative_to(env.OUT_DIR).as_posix()})\n" res += f"![plot_shift]({shift_op.relative_to(env.OUT_DIR).as_posix()})\n"
@ -417,3 +344,175 @@ class DatasetReport:
def __iter__(self): def __iter__(self):
return (cr for cr in self.crs) return (cr for cr in self.crs)
def __test():
    """Scratch walkthrough of the pandas MultiIndex operations tried for the report rework.

    Builds small frames indexed by ``(prevalence, sample index)`` and prints the
    outcome of: appending rows, adding columns, concatenating with column keys,
    re-indexing by absolute distance from 0.8, grouping by index level, and
    exporting to dict / NumPy. Nothing is returned; all output goes to stdout.

    The ``f"{expr = }"`` strings are kept verbatim because their output embeds
    the expression text.
    """
    rule = "-" * 100

    # Create the frame with a single (prevalence, index) row.
    df = None
    print(f"{df is None = }")
    if df is None:
        base_prev = 0.75
        row_idx = 0
        row = {"a": 0.0, "b": 0.1}
        df = pd.DataFrame(
            row,
            index=pd.MultiIndex.from_tuples([(base_prev, row_idx)]),
            columns=row.keys(),
        )
    print(df)
    print(rule)

    # Append a second row under an existing first-level key.
    base_prev = 0.75
    row_idx = len(df.loc[base_prev, :])
    df.loc[(base_prev, row_idx), :] = {"a": 0.2, "b": 0.3}
    print(df)
    print(rule)

    # Append a row under a first-level key that is not in the index yet.
    base_prev = 0.90
    row_idx = len(df.loc[base_prev, :]) if base_prev in df.index else 0
    df.loc[(base_prev, row_idx), :] = {"a": 0.2, "b": 0.3}
    print(df)
    print(rule)

    # Append a row that brings columns the frame does not have.
    base_prev = 0.90
    row_idx = len(df.loc[base_prev, :]) if base_prev in df.index else 0
    row = {"a": 0.2, "v": 0.3, "e": 0.4}
    notin = np.setdiff1d(list(row.keys()), df.columns)
    df.loc[:, notin] = np.nan
    df.loc[(base_prev, row_idx), :] = row
    print(df)
    print(rule)

    # Same again: this time no column should be missing.
    base_prev = 0.90
    row_idx = len(df.loc[base_prev, :]) if base_prev in df.index else 0
    row = {"a": 0.3, "v": 0.4, "e": 0.5}
    notin = np.setdiff1d(list(row.keys()), df.columns)
    print(f"{notin = }")
    df.loc[:, notin] = np.nan
    df.loc[(base_prev, row_idx), :] = row
    print(df)
    print(rule)

    # Index inspection: unique first-level keys, partial selection, membership.
    print(f"{np.sort(np.unique(df.index.get_level_values(0))) = }")
    print(rule)
    print(f"{df.loc[(0.75, ),:] = }\n")
    print(f"{df.loc[(slice(None), 1),:] = }")
    print(rule)
    print(f"{(0.75, ) in df.index = }")
    print(f"{(0.7, ) in df.index = }")
    print(rule)

    # Column-wise concatenation with keys, giving two-level columns.
    df1 = pd.DataFrame(
        {
            "a": np.linspace(0.0, 1.0, 6),
            "b": np.linspace(1.0, 2.0, 6),
            "e": np.linspace(2.0, 3.0, 6),
            "v": np.linspace(0.0, 1.0, 6),
        },
        index=pd.MultiIndex.from_product([[0.75, 0.9], [0, 1, 2]]),
        columns=["a", "b", "e", "v"],
    )
    keyed = pd.concat([df, df1], keys=["a", "b"], axis=1)
    df2 = keyed.swaplevel(0, 1, axis=1).sort_index(axis=1, level=0)
    df3 = pd.concat([df1, df], keys=["b", "a"], axis=1)
    print(df)
    print(df1)
    print(df2)
    print(df3)
    df = df3
    print(rule)

    # Selecting on two-level columns.
    print(df.loc[:, ("b", ["e", "v"])])
    print(df.loc[:, (slice(None), ["e", "v"])])
    print(df.loc[:, ("b", slice(None))])
    print(df.loc[:, ("b", slice(None))].droplevel(level=0, axis=1))
    print(df.loc[:, (slice(None), ["e", "v"])].droplevel(level=0, axis=1))
    print(len(df.loc[:, ("b", slice(None))].columns.unique(0)))
    print(rule)

    # Replace the first index level with its absolute distance from 0.8.
    idx_0 = np.around(np.abs(df.index.get_level_values(0).to_numpy() - 0.8), decimals=2)
    shifted_index = pd.MultiIndex.from_arrays([idx_0, df.index.get_level_values(1)])
    print(shifted_index)
    dfs = df.copy()
    dfs.index = shifted_index
    print(df)
    print(dfs)
    print(rule)

    # Add a third first-level key, then renumber the second level within each
    # distance group.
    for k in range(3):
        df.loc[(0.85, k), :] = np.linspace(0, 1, 8)
    idx_0 = np.around(np.abs(df.index.get_level_values(0).to_numpy() - 0.8), decimals=2)
    print(np.where(idx_0 == 0.05))
    idx_1 = np.empty(shape=idx_0.shape, dtype="<i4")
    print(idx_1)
    for shift in np.unique(idx_0):
        positions = np.where(idx_0 == shift)[0]
        idx_1[positions] = np.arange(positions.shape[0])
    shifted_index = pd.MultiIndex.from_arrays([idx_0, idx_1])
    dfs = df.copy()
    dfs.index = shifted_index
    dfs.sort_index(level=0, axis=0, inplace=True)
    print(df)
    print(dfs)
    print(rule)
    print(np.sort(dfs.index.unique(0)))
    print(rule)

    # Per-group means, plus an overall "avg" row.
    print(df.groupby(level=0).mean())
    print(dfs.groupby(level=0).mean())
    print(rule)
    overall_mean = df.mean(axis=0)
    dfa = df.groupby(level=0).mean()
    dfa.loc["avg", :] = overall_mean
    print(dfa)
    print(rule)

    # Check whether a full-slice selection returns the same object.
    print(df)
    dfn = df.loc[:, (slice(None), slice(None))]
    print(dfn)
    print(f"{df is dfn = }")
    print(rule)

    # Insert an element after every occurrence of a value in an object array.
    labels = np.array(["abc", "bcd", "cde", "bcd"], dtype="object")
    print(labels)
    hits = np.where(labels == "bcd")[0]
    if len(hits) > 0:
        labels = np.insert(labels, hits + 1, "pippo")
    print(labels)
    print(rule)

    # Take one top-level column group and export it to dict / Series / ndarray.
    dff: pd.DataFrame = df.loc[:, ("a",)]
    print(dff.to_dict(orient="list"))
    dff = dff.drop(columns=["v"])
    print(dff)
    e_col: pd.Series = dff.loc[:, "e"]
    print(e_col)
    print(e_col.to_numpy())
    print(type(e_col.to_numpy()))
    print(rule)

    # Row-wise concatenation, with and without an extra outer key.
    df3 = pd.concat([df, df], axis=0, keys=[0.5, 0.3]).sort_index(axis=0, level=0)
    print(df3)
    df3n = pd.concat([df, df], axis=0).sort_index(axis=0, level=0)
    print(df3n)
    df = df3
    print(rule)
    print(df.groupby(level=1).mean(), df.groupby(level=1).count())
    print(rule)

    # Iterate the frame one column at a time as NumPy rows.
    print(df)
    for column_values in df.T.to_numpy():
        print(column_values)
    print(rule)


if __name__ == "__main__":
    __test()

View File

@ -1,10 +1,11 @@
import time import time
from traceback import print_exception as traceback
import quapy as qp import quapy as qp
from quapy.protocol import APP from quapy.protocol import APP
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
from quacc.logging import SubLogger from quacc.logger import SubLogger
def estimate_worker(_estimate, train, validation, test, _env=None, q=None): def estimate_worker(_estimate, train, validation, test, _env=None, q=None):
@ -26,6 +27,7 @@ def estimate_worker(_estimate, train, validation, test, _env=None, q=None):
result = _estimate(model, validation, protocol) result = _estimate(model, validation, protocol)
except Exception as e: except Exception as e:
log.warning(f"Method {_estimate.__name__} failed. Exception: {e}") log.warning(f"Method {_estimate.__name__} failed. Exception: {e}")
# traceback(e)
return { return {
"name": _estimate.__name__, "name": _estimate.__name__,
"result": None, "result": None,

View File

@ -47,7 +47,7 @@ class Logger:
qh.setLevel(logging.DEBUG) qh.setLevel(logging.DEBUG)
qh.setFormatter( qh.setFormatter(
logging.Formatter( logging.Formatter(
fmt="%(asctime)s| %(levelname)s: %(message)s", fmt="%(asctime)s| %(levelname)-8s %(message)s",
datefmt="%d/%m/%y %H:%M:%S", datefmt="%d/%m/%y %H:%M:%S",
) )
) )
@ -102,7 +102,7 @@ class SubLogger:
rh.setLevel(logging.DEBUG) rh.setLevel(logging.DEBUG)
rh.setFormatter( rh.setFormatter(
logging.Formatter( logging.Formatter(
fmt="%(asctime)s| %(levelname)s: %(message)s", fmt="%(asctime)s| %(levelname)-12s\t%(message)s",
datefmt="%d/%m/%y %H:%M:%S", datefmt="%d/%m/%y %H:%M:%S",
) )
) )

View File

@ -1,14 +1,12 @@
import traceback
from sys import platform from sys import platform
from traceback import print_exception as traceback
import quacc.evaluation.comp as comp import quacc.evaluation.comp as comp
from quacc.dataset import Dataset from quacc.dataset import Dataset
from quacc.environment import env from quacc.environment import env
from quacc.logging import Logger from quacc.logger import Logger
from quacc.utils import create_dataser_dir from quacc.utils import create_dataser_dir
log = Logger.logger()
def toast(): def toast():
if platform == "win32": if platform == "win32":
@ -18,37 +16,44 @@ def toast():
def estimate_comparison():
    """Evaluate every configured dataset and write its Markdown reports.

    For each configuration yielded by ``env.get_confs()``:

    1. build a ``Dataset`` from the current ``env`` settings and prepare its
       output directory with ``create_dataser_dir``;
    2. run ``comp.evaluate_comparison`` on it;
    3. for every plot configuration and metric, render the report with
       ``to_md`` and write it to ``env.OUT_DIR / f"{plot_conf}_{m}.md"``.

    Errors are logged and their traceback printed rather than propagated.
    A failed evaluation skips the report step for that dataset; a failed
    report skips only that (plot configuration, metric) pair.
    """
    log = Logger.logger()
    # The loop value itself is unused; presumably iterating get_confs()
    # updates ``env`` for the current configuration -- TODO confirm.
    for _conf in env.get_confs():
        dataset = Dataset(
            env.DATASET_NAME,
            target=env.DATASET_TARGET,
            n_prevalences=env.DATASET_N_PREVS,
            prevs=env.DATASET_PREVS,
        )
        create_dataser_dir(dataset.name, update=env.DATASET_DIR_UPDATE)
        try:
            dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
        except Exception as e:
            log.error(f"Evaluation over {dataset.name} failed. Exception: {e}")
            traceback(e)
            # Without a fresh report there is nothing valid to render: ``dr``
            # would be unbound on the first dataset, or still hold the
            # previous dataset's results on later ones.
            continue

        for plot_conf in env.get_plot_confs():
            for m in env.METRICS:
                output_path = env.OUT_DIR / f"{plot_conf}_{m}.md"
                try:
                    # Render before opening the file so that a failure in
                    # to_md does not leave an empty report on disk.
                    _repr = dr.to_md(
                        conf=plot_conf,
                        metric=m,
                        estimators=env.PLOT_ESTIMATORS,
                        stdev=env.PLOT_STDEV,
                    )
                    with open(output_path, "w") as f:
                        f.write(_repr)
                except Exception as e:
                    log.error(
                        f"Failed while saving configuration {plot_conf} of {dataset.name}. Exception: {e}"
                    )
                    traceback(e)

    # print(df.to_latex(float_format="{:.4f}".format))
    # print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))
def main(): def main():
log = Logger.logger()
try: try:
estimate_comparison() estimate_comparison()
except Exception as e: except Exception as e:

View File

@ -19,7 +19,8 @@ def _get_markers(n: int):
def plot_delta( def plot_delta(
base_prevs, base_prevs,
dict_vals, columns,
data,
*, *,
stdevs=None, stdevs=None,
pos_class=1, pos_class=1,
@ -40,14 +41,14 @@ def plot_delta(
ax.set_aspect("auto") ax.set_aspect("auto")
ax.grid() ax.grid()
NUM_COLORS = len(dict_vals) NUM_COLORS = len(data)
cm = plt.get_cmap("tab10") cm = plt.get_cmap("tab10")
if NUM_COLORS > 10: if NUM_COLORS > 10:
cm = plt.get_cmap("tab20") cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
base_prevs = base_prevs[:, pos_class] base_prevs = base_prevs[:, pos_class]
for (method, deltas), _cy in zip(dict_vals.items(), cy): for method, deltas, _cy in zip(columns, data, cy):
ax.plot( ax.plot(
base_prevs, base_prevs,
deltas, deltas,
@ -59,11 +60,17 @@ def plot_delta(
zorder=2, zorder=2,
) )
if stdevs is not None: if stdevs is not None:
stdev = stdevs[method] _col_idx = np.where(columns == method)[0]
stdev = stdevs[_col_idx].flatten()
nn_idx = np.intersect1d(
np.where(deltas != np.nan)[0],
np.where(stdev != np.nan)[0],
)
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
ax.fill_between( ax.fill_between(
base_prevs, _bps,
deltas - stdev, _ds - _st,
deltas + stdev, _ds + _st,
color=_cy["color"], color=_cy["color"],
alpha=0.25, alpha=0.25,
) )
@ -88,7 +95,8 @@ def plot_delta(
def plot_diagonal( def plot_diagonal(
reference, reference,
dict_vals, columns,
data,
*, *,
pos_class=1, pos_class=1,
metric="acc", metric="acc",
@ -107,7 +115,7 @@ def plot_diagonal(
ax.grid() ax.grid()
ax.set_aspect("equal") ax.set_aspect("equal")
NUM_COLORS = len(dict_vals) NUM_COLORS = len(data)
cm = plt.get_cmap("tab10") cm = plt.get_cmap("tab10")
if NUM_COLORS > 10: if NUM_COLORS > 10:
cm = plt.get_cmap("tab20") cm = plt.get_cmap("tab20")
@ -120,7 +128,7 @@ def plot_diagonal(
x_ticks = np.unique(reference) x_ticks = np.unique(reference)
x_ticks.sort() x_ticks.sort()
for (_, deltas), _cy in zip(dict_vals.items(), cy): for deltas, _cy in zip(data, cy):
ax.plot( ax.plot(
reference, reference,
deltas, deltas,
@ -137,7 +145,7 @@ def plot_diagonal(
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)]) _lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
ax.set(xlim=tuple(_lims), ylim=tuple(_lims)) ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
for (method, deltas), _cy in zip(dict_vals.items(), cy): for method, deltas, _cy in zip(columns, data, cy):
slope, interc = np.polyfit(reference, deltas, 1) slope, interc = np.polyfit(reference, deltas, 1)
y_lr = np.array([slope * x + interc for x in _lims]) y_lr = np.array([slope * x + interc for x in _lims])
ax.plot( ax.plot(
@ -171,7 +179,8 @@ def plot_diagonal(
def plot_shift( def plot_shift(
shift_prevs, shift_prevs,
shift_dict, columns,
data,
*, *,
pos_class=1, pos_class=1,
metric="acc", metric="acc",
@ -190,14 +199,14 @@ def plot_shift(
ax.set_aspect("auto") ax.set_aspect("auto")
ax.grid() ax.grid()
NUM_COLORS = len(shift_dict) NUM_COLORS = len(data)
cm = plt.get_cmap("tab10") cm = plt.get_cmap("tab10")
if NUM_COLORS > 10: if NUM_COLORS > 10:
cm = plt.get_cmap("tab20") cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
shift_prevs = shift_prevs[:, pos_class] shift_prevs = shift_prevs[:, pos_class]
for (method, shifts), _cy in zip(shift_dict.items(), cy): for method, shifts, _cy in zip(columns, data, cy):
ax.plot( ax.plot(
shift_prevs, shift_prevs,
shifts, shifts,