diff --git a/.vscode/launch.json b/.vscode/launch.json
index 91cb6a4..f6c8bea 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -5,6 +5,7 @@
     "version": "0.2.0",
     "configurations": [
+        {
             "name": "main",
             "type": "python",
diff --git a/conf.yaml b/conf.yaml
index c109fad..44920d0 100644
--- a/conf.yaml
+++ b/conf.yaml
@@ -2,17 +2,17 @@ debug_conf: &debug_conf
   global:
     METRICS:
       - acc
-    DATASET_N_PREVS: 1
+    DATASET_N_PREVS: 5
+    DATASET_PREVS:
+      - 0.5
 
   datasets:
-    - DATASET_NAME: rcv1
-      DATASET_TARGET: CCAT
     - DATASET_NAME: imdb
 
   plot_confs:
     debug:
       PLOT_ESTIMATORS:
-        # - mul_sld_bcts
+        - mul_sld_bcts
         - mul_sld
         - ref
         - atc_mc
@@ -23,11 +23,15 @@ test_conf: &test_conf
     METRICS:
       - acc
       - f1
-    DATASET_N_PREVS: 3
+    DATASET_N_PREVS: 2
+    DATASET_PREVS:
+      - 0.5
+      - 0.1
 
   datasets:
-    - DATASET_NAME: rcv1
-      DATASET_TARGET: CCAT
+    # - DATASET_NAME: rcv1
+    #   DATASET_TARGET: CCAT
+    - DATASET_NAME: imdb
 
   plot_confs:
     best_vs_atc:
@@ -100,4 +104,4 @@ main_conf: &main_conf
       - atc_ne
       - doc_feat
 
-exec: *main_conf
\ No newline at end of file
+exec: *test_conf
\ No newline at end of file
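
A note on the conf.yaml change above: exec holds a YAML alias, so pointing it at *test_conf instead of *main_conf makes the whole test_conf anchor the active configuration. A minimal sketch of how that resolves (assumes PyYAML and a conf.yaml in the working directory; this is not the project's actual loader):

    import yaml

    with open("conf.yaml") as f:
        conf = yaml.safe_load(f)

    # safe_load expands the alias, so conf["exec"] is the full test_conf mapping
    active = conf["exec"]
    print(active["global"]["METRICS"])        # expected: ['acc', 'f1']
    print(active["global"]["DATASET_PREVS"])  # expected: [0.5, 0.1]
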
diff --git a/poetry.lock b/poetry.lock
index c7dcbea..c8da77f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -443,6 +443,16 @@ files = [
     {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"},
 ]
 
+[[package]]
+name = "logging"
+version = "0.4.9.6"
+description = "A logging module for Python"
+optional = false
+python-versions = "*"
+files = [
+    {file = "logging-0.4.9.6.tar.gz", hash = "sha256:26f6b50773f085042d301085bd1bf5d9f3735704db9f37c1ce6d8b85c38f2417"},
+]
+
 [[package]]
 name = "markupsafe"
 version = "2.1.3"
@@ -1261,4 +1271,4 @@ test = ["pytest", "pytest-cov"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "c98b7510ac055b667340b52e1b0b0777370e68d325d3149cb1fef42b6f1ec50a"
+content-hash = "54d9922f6d48a46f554a6b350ce09d668a88755efb1fbf295f8f8a0a411bdef2"
diff --git a/pyproject.toml b/pyproject.toml
index 336e224..ee62e17 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,6 +11,7 @@ quapy = "^0.1.7"
 pandas = "^2.0.3"
 jinja2 = "^3.1.2"
 pyyaml = "^6.0.1"
+logging = "^0.4.9.6"
 
 [tool.poetry.scripts]
 main = "quacc.main:main"
diff --git a/quacc/evaluation/baseline.py b/quacc/evaluation/baseline.py
index 8ec32d4..9a5fc5d 100644
--- a/quacc/evaluation/baseline.py
+++ b/quacc/evaluation/baseline.py
@@ -86,6 +86,7 @@ def atc_mc(
     protocol: AbstractStochasticSeededProtocol,
     predict_method="predict_proba",
 ):
+    """ATC with maximum-confidence score (Garg et al., 2022)."""
     c_model_predict = getattr(c_model, predict_method)
 
     ## Load ID validation data probs and labels
@@ -124,6 +125,7 @@ def atc_ne(
     protocol: AbstractStochasticSeededProtocol,
     predict_method="predict_proba",
 ):
+    """ATC with negative-entropy score (Garg et al., 2022)."""
     c_model_predict = getattr(c_model, predict_method)
 
     ## Load ID validation data probs and labels
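
The two baselines gaining docstrings above implement ATC, Average Thresholded Confidence (Garg et al., ICLR 2022): fit a threshold on a confidence score so that the fraction of in-distribution validation points scoring above it equals validation accuracy, then estimate test accuracy as the fraction of test points above that threshold. A minimal sketch of the idea (hypothetical helper names, not the code in baseline.py):

    import numpy as np

    def atc_fit(val_scores: np.ndarray, val_correct: np.ndarray) -> float:
        # choose t so that P(score >= t) on validation matches validation accuracy
        return float(np.quantile(val_scores, 1.0 - val_correct.mean()))

    def atc_estimate(t: float, test_scores: np.ndarray) -> float:
        # estimated accuracy on the unlabelled test sample
        return float((test_scores >= t).mean())

    # atc_mc scores each point by its max probability: probs.max(axis=1)
    # atc_ne scores by negative entropy: (probs * np.log(probs)).sum(axis=1)
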
diff --git a/quacc/evaluation/comp.py b/quacc/evaluation/comp.py
index c0a5eba..2f45343 100644
--- a/quacc/evaluation/comp.py
+++ b/quacc/evaluation/comp.py
@@ -1,6 +1,6 @@
 import multiprocessing
 import time
-import traceback
+from traceback import print_exception as traceback
 from typing import List
 
 import pandas as pd
@@ -11,11 +11,10 @@ from quacc.environment import env
 from quacc.evaluation import baseline, method
 from quacc.evaluation.report import CompReport, DatasetReport, EvaluationReport
 from quacc.evaluation.worker import estimate_worker
-from quacc.logging import Logger
+from quacc.logger import Logger
 
 pd.set_option("display.float_format", "{:.4f}".format)
 qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
-log = Logger.logger()
 
 
 class CompEstimator:
@@ -43,6 +42,7 @@ CE = CompEstimator
 def evaluate_comparison(
     dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"]
 ) -> EvaluationReport:
+    log = Logger.logger()
     # with multiprocessing.Pool(1) as pool:
     with multiprocessing.Pool(len(estimators)) as pool:
         dr = DatasetReport(dataset.name)
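
Note the pattern in evaluate_comparison above (repeated in main.py further down): log = Logger.logger() moved from module import time into the function body. A plausible reason is the multiprocessing.Pool; a logger created as an import side effect would be re-created in every worker process that imports the module. A stdlib-only analogue of the lazy-acquisition pattern (an illustration, not the quacc Logger):

    import logging
    import multiprocessing


    def _square(x: int) -> int:
        return x * x  # worker function; no logging set-up here


    def evaluate() -> None:
        log = logging.getLogger("quacc")  # acquired where needed, not at import
        with multiprocessing.Pool(2) as pool:
            log.info("results: %s", pool.map(_square, range(4)))


    if __name__ == "__main__":
        logging.basicConfig(level=logging.INFO)
        evaluate()
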
diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py
index 50ff5ad..18697d3 100644
--- a/quacc/evaluation/report.py
+++ b/quacc/evaluation/report.py
@@ -9,68 +9,48 @@ from quacc.environment import env
 from quacc.utils import fmt_line_md
 
 
+def _get_metric(metric: str):
+    return slice(None) if metric is None else metric
+
+
+def _get_estimators(estimators: List[str], cols: np.ndarray):
+    return slice(None) if estimators is None else cols[np.in1d(cols, estimators)]
+
+
 class EvaluationReport:
     def __init__(self, name=None):
-        self._prevs = []
-        self._dict = {}
+        self.data: pd.DataFrame = None
         self.fit_score = None
         self.name = name if name is not None else "default"
 
     def append_row(self, basep: np.ndarray | Tuple, **row):
         bp = basep[1]
-        self._prevs.append(bp)
-        for k, v in row.items():
-            if k not in self._dict:
-                self._dict[k] = {}
-            if bp not in self._dict[k]:
-                self._dict[k][bp] = []
-            self._dict[k][bp] = np.append(self._dict[k][bp], [v])
+        _keys, _values = zip(*row.items())
+        # _keys = list(row.keys())
+        # _values = list(row.values())
+
+        if self.data is None:
+            _idx = 0
+            self.data = pd.DataFrame(
+                {k: [v] for k, v in row.items()},
+                index=pd.MultiIndex.from_tuples([(bp, _idx)]),
+                columns=_keys,
+            )
+            return
+
+        _idx = len(self.data.loc[(bp,), :]) if (bp,) in self.data.index else 0
+        not_in_data = np.setdiff1d(list(row.keys()), self.data.columns.unique(0))
+        self.data.loc[:, not_in_data] = np.nan
+        self.data.loc[(bp, _idx), :] = row
+        return
 
     @property
-    def columns(self):
-        return self._dict.keys()
+    def columns(self) -> np.ndarray:
+        return self.data.columns.unique(0)
 
     @property
     def prevs(self):
-        return np.sort(np.unique([list(self._dict[_k].keys()) for _k in self._dict]))
-
-    # def group_by_prevs(self, metric: str = None, estimators: List[str] = None):
-    #     if self._g_dict is None:
-    #         self._g_prevs = []
-    #         self._g_dict = {k: [] for k in self._dict.keys()}
-
-    #         for col, vals in self._dict.items():
-    #             col_grouped = {}
-    #             for bp, v in zip(self._prevs, vals):
-    #                 if bp not in col_grouped:
-    #                     col_grouped[bp] = []
-    #                 col_grouped[bp].append(v)
-
-    #             self._g_dict[col] = [
-    #                 vs
-    #                 for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1])
-    #             ]
-
-    #         self._g_prevs = sorted(
-    #             [(p0, p1) for [p0, p1] in np.unique(self._prevs, axis=0).tolist()],
-    #             key=lambda bp: bp[1],
-    #         )
-
-    #     fg_dict = _filter_dict(self._g_dict, metric, estimators)
-    #     return self._g_prevs, fg_dict
-
-    # def merge(self, other):
-    #     if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)):
-    #         raise ValueError("other has not same base prevalences of self")
-
-    #     inters_keys = set(self._dict.keys()).intersection(set(other._dict.keys()))
-    #     if len(inters_keys) > 0:
-    #         raise ValueError(f"self and other have matching keys {str(inters_keys)}.")
-
-    #     report = EvaluationReport()
-    #     report._prevs = self._prevs
-    #     report._dict = self._dict | other._dict
-    #     return report
+        return np.sort(self.data.index.unique(0))
@@ -82,22 +62,16 @@ class CompReport:
     def __init__(
         self,
         valid_prev=None,
         times=None,
     ):
-        all_prevs = np.array([er.prevs for er in reports])
-        if not np.all(all_prevs == all_prevs[0, :], axis=0).all():
-            raise ValueError(
-                "Not all evaluation reports have the same base prevalences"
+        self._data = (
+            pd.concat(
+                [er.data for er in reports],
+                keys=[er.name for er in reports],
+                axis=1,
             )
-        uq_names, name_c = np.unique([er.name for er in reports], return_counts=True)
-        if np.sum(name_c) > uq_names.shape[0]:
-            _matching = uq_names[[c > 1 for c in name_c]]
-            raise ValueError(
-                f"Evaluation reports have matching names: {_matching.tolist()}."
-            )
-
-        all_dicts = [{(k, er.name): v for k, v in er._dict.items()} for er in reports]
-        self._dict = {}
-        for d in all_dicts:
-            self._dict = self._dict | d
+            .swaplevel(0, 1, axis=1)
+            .sort_index(axis=1, level=0, sort_remaining=False)
+            .sort_index(axis=0, level=0)
+        )
 
         self.fit_scores = {
             er.name: er.fit_score for er in reports if er.fit_score is not None
@@ -107,177 +81,195 @@ class CompReport:
         self.times = times
 
     @property
-    def prevs(self):
-        return np.sort(np.unique([list(self._dict[_k].keys()) for _k in self._dict]))
+    def prevs(self) -> np.ndarray:
+        return np.sort(self._data.index.unique(0))
 
     @property
-    def cprevs(self):
+    def np_prevs(self) -> np.ndarray:
         return np.around([(1.0 - p, p) for p in self.prevs], decimals=2)
 
-    def data(self, metric: str = None, estimators: List[str] = None) -> dict:
-        f_dict = self._dict.copy()
-        if metric is not None:
-            f_dict = {(c0, c1): ls for ((c0, c1), ls) in f_dict.items() if c0 == metric}
-        if estimators is not None:
-            f_dict = {
-                (c0, c1): ls for ((c0, c1), ls) in f_dict.items() if c1 in estimators
-            }
-        if (metric, estimators) != (None, None):
-            f_dict = {c1: ls for ((c0, c1), ls) in f_dict.items()}
+    def data(self, metric: str = None, estimators: List[str] = None) -> pd.DataFrame:
+        _metric = _get_metric(metric)
+        _estimators = _get_estimators(estimators, self._data.columns.unique(1))
+        f_data: pd.DataFrame = self._data.copy().loc[:, (_metric, _estimators)]
 
-        return f_dict
+        if len(f_data.columns.unique(0)) == 1:
+            f_data = f_data.droplevel(level=0, axis=1)
 
-    def group_by_shift(self, metric: str = None, estimators: List[str] = None):
-        f_dict = self.data(metric=metric, estimators=estimators)
-        shift_prevs = np.around(
-            np.absolute(self.prevs - self.train_prev[1]), decimals=2
+        return f_data
+
+    def shift_data(
+        self, metric: str = None, estimators: List[str] = None
+    ) -> pd.DataFrame:
+        shift_idx_0 = np.around(
+            np.abs(
+                self._data.index.get_level_values(0).to_numpy() - self.train_prev[1]
+            ),
+            decimals=2,
         )
-        shift_dict = {col: {sp: [] for sp in shift_prevs} for col in f_dict.keys()}
-        for col, vals in f_dict.items():
-            for sp, bp in zip(shift_prevs, self.prevs):
-                shift_dict[col][sp] = np.concatenate(
-                    [shift_dict[col][sp], f_dict[col][bp]]
-                )
 
-        return np.sort(np.unique(shift_prevs)), shift_dict
+        shift_idx_1 = np.empty(shape=shift_idx_0.shape, dtype="<i4")
+        for _id, _count in zip(*np.unique(shift_idx_0, return_counts=True)):
+            shift_idx_1[shift_idx_0 == _id] = np.arange(_count, dtype="<i4")
+
+        shift_data = self._data.copy()
+        shift_data.index = pd.MultiIndex.from_arrays([shift_idx_0, shift_idx_1])
+        shift_data = shift_data.sort_index(axis=0, level=0)
+
+        _metric = _get_metric(metric)
+        _estimators = _get_estimators(estimators, shift_data.columns.unique(1))
+        shift_data = shift_data.loc[:, (_metric, _estimators)]
+
+        if len(shift_data.columns.unique(0)) == 1:
+            shift_data = shift_data.droplevel(level=0, axis=1)
+
+        return shift_data
 
-    def avg_by_prevs(self, metric: str = None, estimators: List[str] = None):
+    def avg_by_prevs(
+        self, metric: str = None, estimators: List[str] = None
+    ) -> pd.DataFrame:
         f_dict = self.data(metric=metric, estimators=estimators)
-        return {
-            col: np.array([np.mean(vals[bp]) for bp in self.prevs])
-            for col, vals in f_dict.items()
-        }
+        return f_dict.groupby(level=0).mean()
 
-    def stdev_by_prevs(self, metric: str = None, estimators: List[str] = None):
+    def stdev_by_prevs(
+        self, metric: str = None, estimators: List[str] = None
+    ) -> pd.DataFrame:
         f_dict = self.data(metric=metric, estimators=estimators)
-        return {
-            col: np.array([np.std(vals[bp]) for bp in self.prevs])
-            for col, vals in f_dict.items()
-        }
+        return f_dict.groupby(level=0).std()
 
-    def avg_all(self, metric: str = None, estimators: List[str] = None):
-        f_dict = self.data(metric=metric, estimators=estimators)
-        return {
-            col: [np.mean(np.concatenate(list(vals.values())))]
-            for col, vals in f_dict.items()
-        }
-
-    def get_dataframe(self, metric="acc", estimators=None):
-        avg_dict = self.avg_by_prevs(metric=metric, estimators=estimators)
-        all_dict = self.avg_all(metric=metric, estimators=estimators)
-        for col in avg_dict.keys():
-            avg_dict[col] = np.append(avg_dict[col], all_dict[col])
-        return pd.DataFrame(
-            avg_dict,
-            index=self.prevs.tolist() + ["tot"],
-            columns=avg_dict.keys(),
-        )
+    def table(self, metric: str = None, estimators: List[str] = None) -> pd.DataFrame:
+        f_data = self.data(metric=metric, estimators=estimators)
+        avg_p = f_data.groupby(level=0).mean()
+        avg_p.loc["avg", :] = f_data.mean()
+        return avg_p
 
     def get_plots(
-        self,
-        modes=["delta", "diagonal", "shift"],
-        metric="acc",
-        estimators=None,
-        conf="default",
-        stdev=False,
-    ) -> Path:
-        pps = []
-        for mode in modes:
-            pp = []
-            if mode == "delta":
-                f_dict = self.avg_by_prevs(metric=metric, estimators=estimators)
-                _pp0 = plot.plot_delta(
-                    self.cprevs,
-                    f_dict,
-                    metric=metric,
-                    name=conf,
-                    train_prev=self.train_prev,
-                    fit_scores=self.fit_scores,
-                )
-                pp = [(mode, _pp0)]
-                if stdev:
-                    fs_dict = self.stdev_by_prevs(metric=metric, estimators=estimators)
-                    _pp1 = plot.plot_delta(
-                        self.cprevs,
-                        f_dict,
-                        metric=metric,
-                        name=conf,
-                        train_prev=self.train_prev,
-                        fit_scores=self.fit_scores,
-                        stdevs=fs_dict,
-                    )
-                    pp.append((f"{mode}_stdev", _pp1))
-            elif mode == "diagonal":
-                f_dict = {
-                    col: np.concatenate([vals[bp] for bp in self.prevs])
-                    for col, vals in self.data(
-                        metric=metric + "_score", estimators=estimators
-                    ).items()
-                }
-                reference = f_dict["ref"]
-                f_dict = {k: v for k, v in f_dict.items() if k != "ref"}
-                _pp0 = plot.plot_diagonal(
-                    reference,
-                    f_dict,
-                    metric=metric,
-                    name=conf,
-                    train_prev=self.train_prev,
-                )
-                pp = [(mode, _pp0)]
+        self, mode="delta", metric="acc", estimators=None, conf="default", stdev=False
+    ) -> Path:
+        if mode == "delta":
+            avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
+            return plot.plot_delta(
+                base_prevs=self.np_prevs,
+                columns=avg_data.columns.to_numpy(),
+                data=avg_data.T.to_numpy(),
+                metric=metric,
+                name=conf,
+                train_prev=self.train_prev,
+            )
+        elif mode == "delta_stdev":
+            avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
+            st_data = self.stdev_by_prevs(metric=metric, estimators=estimators)
+            return plot.plot_delta(
+                base_prevs=self.np_prevs,
+                columns=avg_data.columns.to_numpy(),
+                data=avg_data.T.to_numpy(),
+                metric=metric,
+                name=conf,
+                train_prev=self.train_prev,
+                stdevs=st_data.T.to_numpy(),
+            )
+        elif mode == "diagonal":
+            f_data = self.data(metric=metric + "_score", estimators=estimators)
+            ref: pd.Series = f_data.loc[:, "ref"]
+            f_data.drop(columns=["ref"], inplace=True)
+            return plot.plot_diagonal(
+                reference=ref.to_numpy(),
+                columns=f_data.columns.to_numpy(),
+                data=f_data.T.to_numpy(),
+                metric=metric,
+                name=conf,
+                train_prev=self.train_prev,
+            )
+        elif mode == "shift":
+            shift_data = (
+                self.shift_data(metric=metric, estimators=estimators)
+                .groupby(level=0)
+                .mean()
+            )
+            shift_prevs = np.around(
+                [(1.0 - p, p) for p in np.sort(shift_data.index.unique(0))],
+                decimals=2,
+            )
+            return plot.plot_shift(
+                shift_prevs=shift_prevs,
+                columns=shift_data.columns.to_numpy(),
+                data=shift_data.T.to_numpy(),
+                metric=metric,
+                name=conf,
+                train_prev=self.train_prev,
+            )
 
-            elif mode == "shift":
-                s_prevs, s_dict = self.group_by_shift(
-                    metric=metric, estimators=estimators
-                )
-                _pp0 = plot.plot_shift(
-                    np.around([(1.0 - p, p) for p in s_prevs], decimals=2),
-                    {
-                        col: np.array([np.mean(vals[sp]) for sp in s_prevs])
-                        for col, vals in s_dict.items()
-                    },
-                    metric=metric,
-                    name=conf,
-                    train_prev=self.train_prev,
-                    fit_scores=self.fit_scores,
-                )
-                pp = [(mode, _pp0)]
-
-            pps.extend(pp)
-
-        return pps
-
-    def to_md(self, conf="default", metric="acc", estimators=None, stdev=False):
+    def to_md(self, conf="default", metric="acc", estimators=None, stdev=False) -> str:
         res = f"## {int(np.around(self.train_prev, decimals=2)[1]*100)}% positives\n"
         res += fmt_line_md(f"train: {str(self.train_prev)}")
         res += fmt_line_md(f"validation: {str(self.valid_prev)}")
         for k, v in self.times.items():
             res += fmt_line_md(f"{k}: {v:.3f}s")
         res += "\n"
-        res += (
-            self.get_dataframe(metric=metric, estimators=estimators).to_html() + "\n\n"
-        )
-        plot_modes = ["delta", "diagonal", "shift"]
-        for mode, op in self.get_plots(
-            modes=plot_modes,
-            metric=metric,
-            estimators=estimators,
-            conf=conf,
-            stdev=stdev,
-        ):
+        res += self.table(metric=metric, estimators=estimators).to_html() + "\n\n"
+
+        plot_modes = np.array(["delta", "diagonal", "shift"], dtype="object")
+        whd = np.where(plot_modes == "delta")[0]
+        if len(whd) > 0:
+            plot_modes = np.insert(plot_modes, whd + 1, "delta_stdev")
+        for mode in plot_modes:
+            op = self.get_plots(
+                mode=mode,
+                metric=metric,
+                estimators=estimators,
+                conf=conf,
+                stdev=stdev,
+            )
             res += f"![plot_{mode}]({op.relative_to(env.OUT_DIR).as_posix()})\n"
 
         return res
 
 
 class DatasetReport:
-    def __init__(self, name):
+    def __init__(self, name, crs=None):
         self.name = name
-        self._dict = None
-        self.crs: List[CompReport] = []
+        self.crs: List[CompReport] = [] if crs is None else crs
 
-    @property
-    def cprevs(self):
-        return np.around([(1.0 - p, p) for p in self.prevs], decimals=2)
+    def data(self, metric: str = None, estimators: str = None) -> pd.DataFrame:
+        def _cr_train_prev(cr: CompReport):
+            return cr.train_prev[1]
+
+        def _cr_data(cr: CompReport):
+            return cr.data(metric, estimators)
+
+        _crs_sorted = sorted(
+            [(_cr_train_prev(cr), _cr_data(cr)) for cr in self.crs],
+            key=lambda cr: len(cr[1].columns),
+            reverse=True,
+        )
+        _crs_train, _crs_data = zip(*_crs_sorted)
+
+        _data = pd.concat(_crs_data, axis=0, keys=_crs_train)
+        _data = _data.sort_index(axis=0, level=0)
+        return _data
+
+    def shift_data(self, metric: str = None, estimators: str = None) -> pd.DataFrame:
+        _shift_data: pd.DataFrame = pd.concat(
+            sorted(
+                [cr.shift_data(metric, estimators) for cr in self.crs],
+                key=lambda d: len(d.columns),
+                reverse=True,
+            ),
+            axis=0,
+        )
+
+        shift_idx_0 = _shift_data.index.get_level_values(0)
+
+        shift_idx_1 = np.empty(shape=shift_idx_0.shape, dtype="<i4")
+        for _id, _count in zip(*np.unique(shift_idx_0, return_counts=True)):
+            shift_idx_1[shift_idx_0 == _id] = np.arange(_count, dtype="<i4")
+
+        _shift_data.index = pd.MultiIndex.from_arrays([shift_idx_0, shift_idx_1])
+        _shift_data = _shift_data.sort_index(axis=0, level=0)
+        return _shift_data
+
+
+def __test():
+    df = pd.DataFrame(
+        np.arange(12).reshape((3, 4)),
+        columns=pd.MultiIndex.from_product([["a", "b"], ["e", "v"]]),
+    )
+    print(df)
+    print("-" * 100)
+
+    a = df.columns.unique(0).to_numpy()
+    whb = np.where(a == "b")[0]
+    if len(whb) > 0:
+        a = np.insert(a, whb + 1, "pippo")
+    print(a)
+    print("-" * 100)
+
+    dff: pd.DataFrame = df.loc[:, ("a",)]
+    print(dff.to_dict(orient="list"))
+    dff = dff.drop(columns=["v"])
+    print(dff)
+    s: pd.Series = dff.loc[:, "e"]
+    print(s)
+    print(s.to_numpy())
+    print(type(s.to_numpy()))
+    print("-" * 100)
+
+    df3 = pd.concat([df, df], axis=0, keys=[0.5, 0.3]).sort_index(axis=0, level=0)
+    print(df3)
+    df3n = pd.concat([df, df], axis=0).sort_index(axis=0, level=0)
+    print(df3n)
+    df = df3
+    print("-" * 100)
+
+    print(df.groupby(level=1).mean(), df.groupby(level=1).count())
+    print("-" * 100)
+
+    print(df)
+    for ls in df.T.to_numpy():
+        print(ls)
+    print("-" * 100)
+
+
+if __name__ == "__main__":
+    __test()
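
Summing up the report.py rewrite above: each report is now one wide DataFrame whose rows are indexed by (base prevalence, repetition) and whose columns are indexed by (metric, estimator), so the old nested-dict bookkeeping becomes index slicing plus groupby. A tiny self-contained illustration with made-up numbers:

    import numpy as np
    import pandas as pd

    rows = pd.MultiIndex.from_tuples([(0.1, 0), (0.1, 1), (0.5, 0), (0.5, 1)])
    cols = pd.MultiIndex.from_product([["acc"], ["ref", "mul_sld"]])
    data = pd.DataFrame(np.random.rand(4, 2), index=rows, columns=cols)

    print(data.groupby(level=0).mean())  # per-prevalence means, cf. avg_by_prevs
    print(data.groupby(level=0).std())   # per-prevalence stdevs, cf. stdev_by_prevs
    print(data.mean())                   # global mean, the "avg" row of table()
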
diff --git a/quacc/evaluation/worker.py b/quacc/evaluation/worker.py
index 0ab75e2..1a96a5f 100644
--- a/quacc/evaluation/worker.py
+++ b/quacc/evaluation/worker.py
@@ -1,10 +1,11 @@
 import time
+from traceback import print_exception as traceback
 
 import quapy as qp
 from quapy.protocol import APP
 from sklearn.linear_model import LogisticRegression
 
-from quacc.logging import SubLogger
+from quacc.logger import SubLogger
 
 
 def estimate_worker(_estimate, train, validation, test, _env=None, q=None):
@@ -26,6 +27,7 @@ def estimate_worker(_estimate, train, validation, test, _env=None, q=None):
         result = _estimate(model, validation, protocol)
     except Exception as e:
         log.warning(f"Method {_estimate.__name__} failed. Exception: {e}")
+        # traceback(e)
         return {
             "name": _estimate.__name__,
             "result": None,
diff --git a/quacc/logging.py b/quacc/logger.py
similarity index 97%
rename from quacc/logging.py
rename to quacc/logger.py
index efa41af..c4cced8 100644
--- a/quacc/logging.py
+++ b/quacc/logger.py
@@ -102,7 +102,7 @@ class SubLogger:
             rh.setLevel(logging.DEBUG)
             rh.setFormatter(
                 logging.Formatter(
-                    fmt="%(asctime)s| %(levelname)s: %(message)s",
+                    fmt="%(asctime)s| %(levelname)s:\t%(message)s",
                     datefmt="%d/%m/%y %H:%M:%S",
                 )
             )
diff --git a/quacc/main.py b/quacc/main.py
index 8a46e2a..b69c5d1 100644
--- a/quacc/main.py
+++ b/quacc/main.py
@@ -1,14 +1,12 @@
-import traceback
 from sys import platform
+from traceback import print_exception as traceback
 
 import quacc.evaluation.comp as comp
 from quacc.dataset import Dataset
 from quacc.environment import env
-from quacc.logging import Logger
+from quacc.logger import Logger
 from quacc.utils import create_dataser_dir
 
-log = Logger.logger()
-
 
 def toast():
     if platform == "win32":
@@ -18,6 +16,7 @@ def toast():
 
 
 def estimate_comparison():
+    log = Logger.logger()
     for conf in env.get_confs():
         create_dataser_dir(conf, update=env.DATASET_DIR_UPDATE)
         dataset = Dataset(
@@ -49,6 +48,7 @@ def estimate_comparison():
 
 
 def main():
+    log = Logger.logger()
     try:
         estimate_comparison()
     except Exception as e:
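
About the recurring import swap above: from traceback import print_exception as traceback keeps old call sites of the form traceback(e) working after "import traceback" was dropped. The single-argument form of print_exception exists since Python 3.10, which the project's ^3.11 pin satisfies. A minimal check:

    from traceback import print_exception as traceback

    try:
        1 / 0
    except Exception as e:
        traceback(e)  # equivalent to print_exception(type(e), e, e.__traceback__)
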
diff --git a/quacc/plot.py b/quacc/plot.py
index 2b65c36..15876d2 100644
--- a/quacc/plot.py
+++ b/quacc/plot.py
@@ -19,7 +19,8 @@ def _get_markers(n: int):
 
 def plot_delta(
     base_prevs,
-    dict_vals,
+    columns,
+    data,
     *,
     stdevs=None,
     pos_class=1,
@@ -40,14 +41,14 @@ def plot_delta(
     ax.set_aspect("auto")
     ax.grid()
 
-    NUM_COLORS = len(dict_vals)
+    NUM_COLORS = len(data)
     cm = plt.get_cmap("tab10")
     if NUM_COLORS > 10:
         cm = plt.get_cmap("tab20")
     cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
 
     base_prevs = base_prevs[:, pos_class]
-    for (method, deltas), _cy in zip(dict_vals.items(), cy):
+    for method, deltas, _cy in zip(columns, data, cy):
         ax.plot(
             base_prevs,
             deltas,
@@ -59,11 +60,17 @@ def plot_delta(
             zorder=2,
         )
         if stdevs is not None:
-            stdev = stdevs[method]
+            _col_idx = np.where(columns == method)[0]
+            stdev = stdevs[_col_idx].flatten()
+            nn_idx = np.intersect1d(
+                np.where(~np.isnan(deltas))[0],
+                np.where(~np.isnan(stdev))[0],
+            )
+            _bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
             ax.fill_between(
-                base_prevs,
-                deltas - stdev,
-                deltas + stdev,
+                _bps,
+                _ds - _st,
+                _ds + _st,
                 color=_cy["color"],
                 alpha=0.25,
             )
@@ -88,7 +95,8 @@ def plot_delta(
 
 def plot_diagonal(
     reference,
-    dict_vals,
+    columns,
+    data,
     *,
     pos_class=1,
     metric="acc",
@@ -107,7 +115,7 @@ def plot_diagonal(
     ax.grid()
     ax.set_aspect("equal")
 
-    NUM_COLORS = len(dict_vals)
+    NUM_COLORS = len(data)
     cm = plt.get_cmap("tab10")
     if NUM_COLORS > 10:
         cm = plt.get_cmap("tab20")
@@ -120,7 +128,7 @@ def plot_diagonal(
     x_ticks = np.unique(reference)
     x_ticks.sort()
 
-    for (_, deltas), _cy in zip(dict_vals.items(), cy):
+    for deltas, _cy in zip(data, cy):
         ax.plot(
             reference,
             deltas,
@@ -137,7 +145,7 @@ def plot_diagonal(
     _lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
     ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
 
-    for (method, deltas), _cy in zip(dict_vals.items(), cy):
+    for method, deltas, _cy in zip(columns, data, cy):
         slope, interc = np.polyfit(reference, deltas, 1)
         y_lr = np.array([slope * x + interc for x in _lims])
         ax.plot(
@@ -171,7 +179,8 @@ def plot_diagonal(
 
 def plot_shift(
     shift_prevs,
-    shift_dict,
+    columns,
+    data,
     *,
     pos_class=1,
     metric="acc",
@@ -190,14 +199,14 @@ def plot_shift(
     ax.set_aspect("auto")
     ax.grid()
 
-    NUM_COLORS = len(shift_dict)
+    NUM_COLORS = len(data)
    cm = plt.get_cmap("tab10")
     if NUM_COLORS > 10:
         cm = plt.get_cmap("tab20")
     cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
 
     shift_prevs = shift_prevs[:, pos_class]
-    for (method, shifts), _cy in zip(shift_dict.items(), cy):
+    for method, shifts, _cy in zip(columns, data, cy):
         ax.plot(
             shift_prevs,
             shifts,