From 3345514c99f1b44c43a48fc232cefa3f10d6b750 Mon Sep 17 00:00:00 2001
From: Lorenzo Volpi
Date: Fri, 27 Oct 2023 12:37:18 +0200
Subject: [PATCH] diag plot fixed, opts, avg plot, best score added

---
 .gitignore                   |   3 +-
 TODO.html                    |  67 ++++-
 TODO.md                      |  21 +-
 conf.yaml                    | 149 ++++++-----
 poetry.lock                  |  40 ++-
 pyproject.toml               |   1 +
 quacc/dataset.py             |  24 +-
 quacc/environ.py             |  72 -----
 quacc/environment.py         |  85 ++++++
 quacc/estimator.py           | 119 ++++++---
 quacc/evaluation/baseline.py | 138 ++++++----
 quacc/evaluation/comp.py     |  82 +++---
 quacc/evaluation/method.py   | 136 ++++------
 quacc/evaluation/report.py   | 496 +++++++++++++++++++++++++----------
 quacc/main.py                |  56 ++--
 quacc/plot.py                | 170 ++++++------
 quacc/utils.py               |  25 ++
 17 files changed, 1068 insertions(+), 616 deletions(-)
 delete mode 100644 quacc/environ.py
 create mode 100644 quacc/environment.py

diff --git a/.gitignore b/.gitignore
index 1ae9719..f450568 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,5 @@ lipton_bbse/__pycache__/*
 elsahar19_rca/__pycache__/*
 *.coverage
 .coverage
-scp_sync.py
\ No newline at end of file
+scp_sync.py
+out/*
\ No newline at end of file
diff --git a/TODO.html b/TODO.html
index 35dd2f8..d733192 100644
--- a/TODO.html
+++ b/TODO.html
@@ -41,12 +41,67 @@
diff --git a/TODO.md b/TODO.md
index 0bb2cb4..f5ce0a2 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,16 +1,17 @@
 - [x] add table averages
 - [x] plots; 3 types (notes + email + garg)
-- [ ] fix the kfcv baseline
+- [x] fix the kfcv baseline
 - [x] add a method based on CC besides SLD
 - [x] take the most populous class of rcv1, remove negatives until reaching 50/50; then subsample with 9 training prevalences (from 0.1-0.9 to 0.9-0.1)
 - [x] vary the recalibration parameter in SLD
-- [ ] collective plot, with the shift on the x axis, taking all training sets into account and averaging over the 9 cases (each line is a method), for non-optimized and optimized results
-- [ ] recalib variants: bcts, SLD (try exact_train_prev=False)
-- [ ] check what validation size garg uses
-- [ ] for model selection, test the classifier's C parameter, explored over np.logscale(-3,3, 7) or np.logscale(-4, 4, 9), and the class_weight parameter, explored over None or "balanced"; use qp.model_selection.GridSearchQ with mae as the error and UPP as the protocol
-  - qp.train_test_split to get val_train and val_val
+- [x] fix the diagonal plot
+  - seaborn example gallery
+- [x] recalib variants: bcts, SLD (try exact_train_prev=False)
+- [x] check what validation size garg uses
+- [x] for model selection, test the classifier's C parameter, explored over np.logscale(-3,3, 7) or np.logscale(-4, 4, 9), and the class_weight parameter, explored over None or "balanced"; use qp.model_selection.GridSearchQ with mae as the error and UPP as the protocol
+  - qp.train_test_split to get v_train and v_val
 - GridSearchQ(
     model: BaseQuantifier,
     param_grid: {
@@ -24,7 +25,7 @@
     timeout=-1,
     n_jobs=-2,
     verbose=True).fit(V_tr)
-  - save the best score obtained by each application of GridSearchQ
-  - in the bin case, average the two best scores
-
-- seaborn example gallery
\ No newline at end of file
+- [x] collective plot, with the shift on the x axis, taking all training sets into account and averaging over the 9 cases (each line is a method), for non-optimized and optimized results
+- [x] save the best score obtained by each application of GridSearchQ
+  - in the bin case, average the two best scores
+- [x] import baselines
diff --git a/conf.yaml
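Editor's note before the conf.yaml diff that follows: the new configuration groups settings under named anchors (debug_conf, test_conf, main_conf), each with a global section, a datasets list and a set of plot_confs, and selects one of them via exec: *main_conf. Below is a minimal sketch of how quacc/environment.py in this patch appears to resolve such a file; the inline YAML string and the helper name resolve_conf are illustrative assumptions, while yaml.safe_load, the PLOT_ESTIMATORS union and the dict-merge per dataset mirror the load_conf code further down in the patch.

# Sketch, not part of the patch: simplified stand-in for environment.load_conf.
import yaml

SAMPLE_CONF = """
exec:
  global:
    METRICS: [acc, f1]
    DATASET_N_PREVS: 9
  datasets:
    - DATASET_NAME: rcv1
      DATASET_TARGET: CCAT
  plot_confs:
    gs_vs_atc:
      PLOT_ESTIMATORS: [mul_sld_gs, bin_sld_gs, ref, atc_mc, atc_ne]
      PLOT_STDEV: true
"""

def resolve_conf(yaml_text: str):
    # "exec" points (via a YAML anchor in the real file) to the active configuration
    confs = yaml.safe_load(yaml_text)["exec"]
    _global = confs["global"]
    # COMP_ESTIMATORS is built as the union of every PLOT_ESTIMATORS list,
    # so each estimator is evaluated once even if it appears in several plots
    estimators = set()
    for pc in confs["plot_confs"].values():
        estimators |= set(pc["PLOT_ESTIMATORS"])
    _global["COMP_ESTIMATORS"] = sorted(estimators)
    # one merged configuration per dataset entry (dict union, as in the patch)
    return [_global | d for d in confs["datasets"]], confs["plot_confs"]

if __name__ == "__main__":
    per_dataset, plot_confs = resolve_conf(SAMPLE_CONF)
    print(per_dataset[0]["DATASET_NAME"], per_dataset[0]["COMP_ESTIMATORS"])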
b/conf.yaml index 50a5dd0..eef0123 100644 --- a/conf.yaml +++ b/conf.yaml @@ -1,71 +1,102 @@ +debug_conf: &debug_conf + global: + METRICS: + - acc + DATASET_N_PREVS: 1 -exec: [] + datasets: + - DATASET_NAME: rcv1 + DATASET_TARGET: CCAT -commons: - - DATASET_NAME: rcv1 - DATASET_TARGET: CCAT + plot_confs: + debug: + PLOT_ESTIMATORS: + # - mul_sld_bcts + - mul_sld + - ref + - atc_mc + - atc_ne + +test_conf: &test_conf + global: METRICS: - acc - f1 - DATASET_N_PREVS: 9 - - DATASET_NAME: imdb + DATASET_N_PREVS: 3 + + datasets: + - DATASET_NAME: rcv1 + DATASET_TARGET: CCAT + + plot_confs: + best_vs_atc: + PLOT_ESTIMATORS: + - bin_sld + - bin_sld_bcts + - bin_sld_gs + - mul_sld + - mul_sld_bcts + - mul_sld_gs + - ref + - atc_mc + - atc_ne + +main_conf: &main_conf + global: METRICS: - acc - f1 DATASET_N_PREVS: 9 -confs: + datasets: + - DATASET_NAME: rcv1 + DATASET_TARGET: CCAT + datasets_bck: + - DATASET_NAME: rcv1 + DATASET_TARGET: GCAT + - DATASET_NAME: rcv1 + DATASET_TARGET: MCAT + - DATASET_NAME: imdb - all_mul_vs_atc: - COMP_ESTIMATORS: - - our_mul_SLD - - our_mul_SLD_nbvs - - our_mul_SLD_bcts - - our_mul_SLD_ts - - our_mul_SLD_vs - - our_mul_CC - - ref - - atc_mc - - atc_ne + plot_confs: + gs_vs_atc: + PLOT_ESTIMATORS: + - mul_sld_gs + - bin_sld_gs + - ref + - atc_mc + - atc_ne + PLOT_STDEV: true + best_vs_atc: + PLOT_ESTIMATORS: + - mul_sld_bcts + - mul_sld_gs + - bin_sld_bcts + - bin_sld_gs + - ref + - atc_mc + - atc_ne + all_vs_atc: + PLOT_ESTIMATORS: + - bin_sld + - bin_sld_bcts + - bin_sld_gs + - mul_sld + - mul_sld_bcts + - mul_sld_gs + - ref + - atc_mc + - atc_ne + best_vs_all: + PLOT_ESTIMATORS: + - bin_sld_bcts + - bin_sld_gs + - mul_sld_bcts + - mul_sld_gs + - ref + - kfcv + - atc_mc + - atc_ne + - doc_feat - all_bin_vs_atc: - COMP_ESTIMATORS: - - our_bin_SLD - - our_bin_SLD_nbvs - - our_bin_SLD_bcts - - our_bin_SLD_ts - - our_bin_SLD_vs - - our_bin_CC - - ref - - atc_mc - - atc_ne - - best_our_vs_atc: - COMP_ESTIMATORS: - - our_bin_SLD - - our_bin_SLD_bcts - - our_bin_SLD_vs - - our_bin_CC - - our_mul_SLD - - our_mul_SLD_bcts - - our_mul_SLD_vs - - our_mul_CC - - ref - - atc_mc - - atc_ne - - best_our_vs_all: - COMP_ESTIMATORS: - - our_bin_SLD - - our_bin_SLD_bcts - - our_bin_SLD_vs - - our_bin_CC - - our_mul_SLD - - our_mul_SLD_bcts - - our_mul_SLD_vs - - our_mul_CC - - ref - - kfcv - - atc_mc - - atc_ne - - doc_feat +exec: *main_conf \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 7d7365d..c7dcbea 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1204,6 +1204,44 @@ files = [ {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"}, ] +[[package]] +name = "win11toast" +version = "0.32" +description = "Toast notifications for Windows 10 and 11" +optional = false +python-versions = "*" +files = [ + {file = "win11toast-0.32-py3-none-any.whl", hash = "sha256:38ecf6625374cbeebce4f3eda20cef0b2c468fedda23d95d883dfcdac98154a6"}, + {file = "win11toast-0.32.tar.gz", hash = "sha256:640650374285ef51bcad4453a3404f502e5b746e4a7fd7d974064a73ae808e33"}, +] + +[package.dependencies] +winsdk = "*" + +[[package]] +name = "winsdk" +version = "1.0.0b10" +description = "Python bindings for the Windows SDK" +optional = false +python-versions = ">=3.8" +files = [ + {file = "winsdk-1.0.0b10-cp310-cp310-win32.whl", hash = "sha256:90f75c67e166d588a045bcde0117a4631c705904f7af4ac42644479dcf0d8c52"}, + {file = "winsdk-1.0.0b10-cp310-cp310-win_amd64.whl", hash = 
"sha256:c3be3fbf692b8888bac8c0712c490c080ab8976649ef01f9f6365947f4e5a8b1"}, + {file = "winsdk-1.0.0b10-cp310-cp310-win_arm64.whl", hash = "sha256:6ab69dd65d959d94939c21974a33f4f1dfa625106c8784435ecacbd8ff0bf74d"}, + {file = "winsdk-1.0.0b10-cp311-cp311-win32.whl", hash = "sha256:9ea4fdad9ca8a542198aee3c753ac164b8e2f550d760bb88815095d64750e0f5"}, + {file = "winsdk-1.0.0b10-cp311-cp311-win_amd64.whl", hash = "sha256:f12e25bbf0a658270203615677520b8170edf500fba11e0f80359c5dbf090676"}, + {file = "winsdk-1.0.0b10-cp311-cp311-win_arm64.whl", hash = "sha256:e77bce44a9ff151562bd261b2a1a8255e258bb10696d0d31ef63267a27628af1"}, + {file = "winsdk-1.0.0b10-cp312-cp312-win32.whl", hash = "sha256:775a55a71e05ec2aa262c1fd67d80f270d4186bbdbbee2f43c9c412cf76f0761"}, + {file = "winsdk-1.0.0b10-cp312-cp312-win_amd64.whl", hash = "sha256:8231ce5f16e1fc88bb7dda0adf35633b5b26101eae3b0799083ca2177f03e4e5"}, + {file = "winsdk-1.0.0b10-cp312-cp312-win_arm64.whl", hash = "sha256:f4ab469ada19b34ccfc69a148090f98b40a1da1da797b50b9cbba0c090c365a5"}, + {file = "winsdk-1.0.0b10-cp38-cp38-win32.whl", hash = "sha256:786d6b50e4fcb8af2d701d7400c74e1c3f3ab7766ed1dfd516cdd6688072ea87"}, + {file = "winsdk-1.0.0b10-cp38-cp38-win_amd64.whl", hash = "sha256:1d4fdd1f79b41b64fedfbc478a29112edf2076e1a61001eccb536c0568510e74"}, + {file = "winsdk-1.0.0b10-cp39-cp39-win32.whl", hash = "sha256:4f04d3e50eeb8ca5fe4eb2e39785f3fa594199819acdfb23a10aaef4b97699ad"}, + {file = "winsdk-1.0.0b10-cp39-cp39-win_amd64.whl", hash = "sha256:7948bc5d8a02d73b1db043788d32b2988b8e7e29a25e503c21d34478e630eaf1"}, + {file = "winsdk-1.0.0b10-cp39-cp39-win_arm64.whl", hash = "sha256:342b1095cbd937865cee962676e279a1fd28896a0680724fcf9c65157e7ebdb7"}, + {file = "winsdk-1.0.0b10.tar.gz", hash = "sha256:8f39ea759626797449371f857c9085b84bb9f3b6d493dc6525e2cedcb3d15ea2"}, +] + [[package]] name = "xlrd" version = "2.0.1" @@ -1223,4 +1261,4 @@ test = ["pytest", "pytest-cov"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "0ce0e6b058900e7db2939e7eb047a1f868c88de67def370c1c1fa0ba532df0b0" +content-hash = "c98b7510ac055b667340b52e1b0b0777370e68d325d3149cb1fef42b6f1ec50a" diff --git a/pyproject.toml b/pyproject.toml index d9ce79a..336e224 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ pytest = "^7.4.0" pylance = "^0.5.9" pytest-mock = "^3.11.1" pytest-cov = "^4.1.0" +win11toast = "^0.32" [tool.pytest.ini_options] addopts = "--cov=quacc --capture=tee-sys" diff --git a/quacc/dataset.py b/quacc/dataset.py index c63552f..9362da8 100644 --- a/quacc/dataset.py +++ b/quacc/dataset.py @@ -40,21 +40,22 @@ class Dataset: self.n_prevs = n_prevalences def __spambase(self): - return qp.datasets.fetch_reviews("imdb", tfidf=True).train_test + return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test def __imdb(self): - return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test + return qp.datasets.fetch_reviews("imdb", tfidf=True).train_test def __rcv1(self): n_train = 23149 available_targets = ["CCAT", "GCAT", "MCAT"] if self._target is None or self._target not in available_targets: - raise ValueError("Invalid target") + raise ValueError(f"Invalid target {self._target}") dataset = fetch_rcv1() target_index = np.where(dataset.target_names == self._target)[0] - all_train_d, test_d = dataset.data[:n_train, :], dataset.data[n_train:, :] + all_train_d = dataset.data[:n_train, :] + test_d = dataset.data[n_train:, :] labels = dataset.target[:, target_index].toarray().flatten() all_train_l, test_l = labels[:n_train], 
labels[n_train:] all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1]) @@ -62,6 +63,21 @@ class Dataset: return all_train, test + def get_raw(self, validation=True) -> DatasetSample: + all_train, test = { + "spambase": self.__spambase, + "imdb": self.__imdb, + "rcv1": self.__rcv1, + }[self._name]() + + train, val = all_train, None + if validation: + train, val = all_train.split_stratified( + train_prop=TRAIN_VAL_PROP, random_state=0 + ) + + return DatasetSample(train, val, test) + def get(self) -> List[DatasetSample]: all_train, test = { "spambase": self.__spambase, diff --git a/quacc/environ.py b/quacc/environ.py deleted file mode 100644 index 1177964..0000000 --- a/quacc/environ.py +++ /dev/null @@ -1,72 +0,0 @@ -import yaml - -defalut_env = { - "DATASET_NAME": "rcv1", - "DATASET_TARGET": "CCAT", - "METRICS": ["acc", "f1"], - "COMP_ESTIMATORS": [ - "our_bin_SLD", - "our_bin_SLD_nbvs", - "our_bin_SLD_bcts", - "our_bin_SLD_ts", - "our_bin_SLD_vs", - "our_bin_CC", - "our_mul_SLD", - "our_mul_SLD_nbvs", - "our_mul_SLD_bcts", - "our_mul_SLD_ts", - "our_mul_SLD_vs", - "our_mul_CC", - "ref", - "kfcv", - "atc_mc", - "atc_ne", - "doc_feat", - "rca", - "rca_star", - ], - "DATASET_N_PREVS": 9, - "OUT_DIR_NAME": "output", - "PLOT_DIR_NAME": "plot", - "PROTOCOL_N_PREVS": 21, - "PROTOCOL_REPEATS": 100, - "SAMPLE_SIZE": 1000, -} - - -class Environ: - def __init__(self, **kwargs): - self.exec = [] - self.confs = {} - self.__setdict(kwargs) - - def __setdict(self, d): - for k, v in d.items(): - self.__setattr__(k, v) - - def load_conf(self): - with open("conf.yaml", "r") as f: - confs = yaml.safe_load(f) - - for common in confs["commons"]: - name = common["DATASET_NAME"] - if "DATASET_TARGET" in common: - name += "_" + common["DATASET_TARGET"] - for k, d in confs["confs"].items(): - _k = f"{name}_{k}" - self.confs[_k] = common | d - self.exec.append(_k) - - if "exec" in confs: - if len(confs["exec"]) > 0: - self.exec = confs["exec"] - - def __iter__(self): - self.load_conf() - for _conf in self.exec: - if _conf in self.confs: - self.__setdict(self.confs[_conf]) - yield _conf - - -env = Environ(**defalut_env) diff --git a/quacc/environment.py b/quacc/environment.py new file mode 100644 index 0000000..1a7a832 --- /dev/null +++ b/quacc/environment.py @@ -0,0 +1,85 @@ +import yaml + +defalut_env = { + "DATASET_NAME": "rcv1", + "DATASET_TARGET": "CCAT", + "METRICS": ["acc", "f1"], + "COMP_ESTIMATORS": [], + "PLOT_ESTIMATORS": [], + "PLOT_STDEV": False, + "DATASET_N_PREVS": 9, + "OUT_DIR_NAME": "output", + "OUT_DIR": None, + "PLOT_DIR_NAME": "plot", + "PLOT_OUT_DIR": None, + "DATASET_DIR_UPDATE": False, + "PROTOCOL_N_PREVS": 21, + "PROTOCOL_REPEATS": 100, + "SAMPLE_SIZE": 1000, +} + + +class environ: + _instance = None + + def __init__(self, **kwargs): + self.exec = [] + self.confs = [] + self._default = kwargs + self.__setdict(kwargs) + self.load_conf() + + def __setdict(self, d): + for k, v in d.items(): + self.__setattr__(k, v) + if len(self.PLOT_ESTIMATORS) == 0: + self.PLOT_ESTIMATORS = self.COMP_ESTIMATORS + + def __class_getitem__(cls, k): + env = cls.get() + return env.__getattribute__(k) + + def load_conf(self): + with open("conf.yaml", "r") as f: + confs = yaml.safe_load(f)["exec"] + + _global = confs["global"] + _estimators = set() + for pc in confs["plot_confs"].values(): + _estimators = _estimators.union(set(pc["PLOT_ESTIMATORS"])) + _global["COMP_ESTIMATORS"] = list(_estimators) + + self.plot_confs = confs["plot_confs"] + + for dataset in confs["datasets"]: + 
self.confs.append(_global | dataset) + + def get_confs(self): + for _conf in self.confs: + self.__setdict(self._default) + self.__setdict(_conf) + if "DATASET_TARGET" not in _conf: + self.DATASET_TARGET = None + + name = self.DATASET_NAME + if self.DATASET_TARGET is not None: + name += f"_{self.DATASET_TARGET}" + name += f"_{self.DATASET_N_PREVS}prevs" + + yield name + + def get_plot_confs(self): + for k, pc in self.plot_confs.items(): + if "PLOT_ESTIMATORS" in pc: + self.PLOT_ESTIMATORS = pc["PLOT_ESTIMATORS"] + if "PLOT_STDEV" in pc: + self.PLOT_STDEV = pc["PLOT_STDEV"] + + name = self.DATASET_NAME + if self.DATASET_TARGET is not None: + name += f"_{self.DATASET_TARGET}" + name += f"_{k}" + yield name + + +env = environ(**defalut_env) diff --git a/quacc/estimator.py b/quacc/estimator.py index 2f9a92c..216b8a1 100644 --- a/quacc/estimator.py +++ b/quacc/estimator.py @@ -2,8 +2,11 @@ import math from abc import abstractmethod import numpy as np +import quapy as qp from quapy.data import LabelledCollection from quapy.method.aggregative import CC, SLD +from quapy.model_selection import GridSearchQ +from quapy.protocol import UPP from sklearn.base import BaseEstimator from sklearn.linear_model import LogisticRegression from sklearn.model_selection import cross_val_predict @@ -12,6 +15,24 @@ from quacc.data import ExtendedCollection class AccuracyEstimator: + def __init__(self): + self.fit_score = None + + def _gs_params(self, t_val: LabelledCollection): + return { + "param_grid": { + "classifier__C": np.logspace(-3, 3, 7), + "classifier__class_weight": [None, "balanced"], + "recalib": [None, "bcts"], + }, + "protocol": UPP(t_val, repeats=1000), + "error": qp.error.mae, + "refit": False, + "timeout": -1, + "n_jobs": None, + "verbose": True, + } + def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection: if not pred_proba: pred_proba = self.c_model.predict_proba(base.X) @@ -26,17 +47,17 @@ class AccuracyEstimator: ... 
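Editor's note: the _gs_params grid introduced above implements the model-selection TODO item (classifier C over np.logspace(-3, 3, 7), class_weight in [None, "balanced"], optional bcts recalibration), scored with MAE under a UPP protocol and with the resulting best_score_ kept as fit_score. A standalone sketch of that grid search is shown below; the synthetic data, sample size and reduced repeats are illustrative assumptions, while the grid, protocol and error mirror the patch.

# Sketch, not part of the patch: grid search as in _gs_params, on toy data.
import numpy as np
import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.aggregative import SLD
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

qp.environ["SAMPLE_SIZE"] = 100  # protocols need a sample size, as comp.py sets globally

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
data = LabelledCollection(X, y, classes=[0, 1])
t_train, t_val = data.split_stratified(train_prop=0.6, random_state=0)

gs = GridSearchQ(
    SLD(LogisticRegression()),
    param_grid={
        "classifier__C": np.logspace(-3, 3, 7),
        "classifier__class_weight": [None, "balanced"],
        "recalib": [None, "bcts"],
    },
    protocol=UPP(t_val, repeats=100),  # the patch uses repeats=1000
    error=qp.error.mae,
    refit=False,
    verbose=True,
).fit(t_train)

print(gs.best_score_)  # this is the value the estimators store as fit_score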
-class MulticlassAccuracyEstimator(AccuracyEstimator): - def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs): - self.c_model = c_model - if q_model == "SLD": - available_args = ["recalib"] - sld_args = {k: v for k, v in kwargs.items() if k in available_args} - self.q_model = SLD(LogisticRegression(), **sld_args) - elif q_model == "CC": - self.q_model = CC(LogisticRegression()) +AE = AccuracyEstimator + +class MulticlassAccuracyEstimator(AccuracyEstimator): + def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None): + super().__init__() + self.c_model = c_model + self._q_model_name = q_model.upper() self.e_train = None + self.gs = gs + self.recalib = recalib def fit(self, train: LabelledCollection | ExtendedCollection): # check if model is fit @@ -45,12 +66,26 @@ class MulticlassAccuracyEstimator(AccuracyEstimator): pred_prob_train = cross_val_predict( self.c_model, *train.Xy, method="predict_proba" ) - self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train) else: self.e_train = train - self.q_model.fit(self.e_train) + if self._q_model_name == "SLD": + if self.gs: + t_train, t_val = self.e_train.split_stratified(0.6, random_state=0) + gs_params = self._gs_params(t_val) + self.q_model = GridSearchQ( + SLD(LogisticRegression()), + **gs_params, + ) + self.q_model.fit(t_train) + self.fit_score = self.q_model.best_score_ + else: + self.q_model = SLD(LogisticRegression(), recalib=self.recalib) + self.q_model.fit(self.e_train) + elif self._q_model_name == "CC": + self.q_model = CC(LogisticRegression()) + self.q_model.fit(self.e_train) def estimate(self, instances, ext=False): if not ext: @@ -62,10 +97,14 @@ class MulticlassAccuracyEstimator(AccuracyEstimator): estim_prev = self.q_model.quantify(e_inst) return self._check_prevalence_classes( - self.e_train.classes_, self.q_model.classes_, estim_prev + self.e_train.classes_, self.q_model, estim_prev ) - def _check_prevalence_classes(self, true_classes, estim_classes, estim_prev): + def _check_prevalence_classes(self, true_classes, q_model, estim_prev): + if isinstance(q_model, GridSearchQ): + estim_classes = q_model.best_model().classes_ + else: + estim_classes = q_model.classes_ for _cls in true_classes: if _cls not in estim_classes: estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0) @@ -73,17 +112,13 @@ class MulticlassAccuracyEstimator(AccuracyEstimator): class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): - def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs): + def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None): + super().__init__() self.c_model = c_model - if q_model == "SLD": - available_args = ["recalib"] - sld_args = {k: v for k, v in kwargs.items() if k in available_args} - self.q_model_0 = SLD(LogisticRegression(), **sld_args) - self.q_model_1 = SLD(LogisticRegression(), **sld_args) - elif q_model == "CC": - self.q_model_0 = CC(LogisticRegression()) - self.q_model_1 = CC(LogisticRegression()) - + self._q_model_name = q_model.upper() + self.q_models = [] + self.gs = gs + self.recalib = recalib self.e_train = None def fit(self, train: LabelledCollection | ExtendedCollection): @@ -99,10 +134,34 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): self.e_train = train self.n_classes = self.e_train.n_classes - [e_train_0, e_train_1] = self.e_train.split_by_pred() + e_trains = self.e_train.split_by_pred() - self.q_model_0.fit(e_train_0) - self.q_model_1.fit(e_train_1) + if self._q_model_name == "SLD": + fit_scores = 
[] + for e_train in e_trains: + if self.gs: + t_train, t_val = e_train.split_stratified(0.6, random_state=0) + gs_params = self._gs_params(t_val) + q_model = GridSearchQ( + SLD(LogisticRegression()), + **gs_params, + ) + q_model.fit(t_train) + fit_scores.append(q_model.best_score_) + self.q_models.append(q_model) + else: + q_model = SLD(LogisticRegression(), recalib=self.recalib) + q_model.fit(e_train) + self.q_models.append(q_model) + + if self.gs: + self.fit_score = np.mean(fit_scores) + + elif self._q_model_name == "CC": + for e_train in e_trains: + q_model = CC(LogisticRegression()) + q_model.fit(e_train) + self.q_models.append(q_model) def estimate(self, instances, ext=False): # TODO: test @@ -114,15 +173,13 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): _ncl = int(math.sqrt(self.n_classes)) s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst) - [estim_prev_0, estim_prev_1] = [ + estim_prevs = [ self._quantify_helper(inst, norm, q_model) - for (inst, norm, q_model) in zip( - s_inst, norms, [self.q_model_0, self.q_model_1] - ) + for (inst, norm, q_model) in zip(s_inst, norms, self.q_models) ] estim_prev = [] - for prev_row in zip(estim_prev_0, estim_prev_1): + for prev_row in zip(*estim_prevs): for prev in prev_row: estim_prev.append(prev) diff --git a/quacc/evaluation/baseline.py b/quacc/evaluation/baseline.py index e36a492..8ec32d4 100644 --- a/quacc/evaluation/baseline.py +++ b/quacc/evaluation/baseline.py @@ -1,23 +1,35 @@ +from functools import wraps from statistics import mean import numpy as np import sklearn.metrics as metrics from quapy.data import LabelledCollection -from quapy.protocol import ( - AbstractStochasticSeededProtocol, - OnLabelledCollectionProtocol, -) +from quapy.protocol import AbstractStochasticSeededProtocol +from scipy.sparse import issparse from sklearn.base import BaseEstimator from sklearn.model_selection import cross_validate -import elsahar19_rca.rca as rca -import garg22_ATC.ATC_helper as atc -import guillory21_doc.doc as doc -import jiang18_trustscore.trustscore as trustscore +import baselines.atc as atc +import baselines.doc as doc +import baselines.impweight as iw +import baselines.rca as rcalib from .report import EvaluationReport +_baselines = {} + +def baseline(func): + @wraps(func) + def wrapper(c_model, validation, protocol): + return func(c_model, validation, protocol) + + _baselines[func.__name__] = wrapper + + return wrapper + + +@baseline def kfcv( c_model: BaseEstimator, validation: LabelledCollection, @@ -31,9 +43,6 @@ def kfcv( acc_score = mean(scores["test_accuracy"]) f1_score = mean(scores["test_f1_macro"]) - # ensure that the protocol returns a LabelledCollection for each iteration - protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(name="kfcv") for test in protocol(): test_preds = c_model_predict(test.X) @@ -50,12 +59,12 @@ def kfcv( return report -def reference( +@baseline +def ref( c_model: BaseEstimator, validation: LabelledCollection, protocol: AbstractStochasticSeededProtocol, ): - protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") c_model_predict = getattr(c_model, "predict_proba") report = EvaluationReport(name="ref") for test in protocol(): @@ -70,6 +79,7 @@ def reference( return report +@baseline def atc_mc( c_model: BaseEstimator, validation: LabelledCollection, @@ -86,9 +96,6 @@ def atc_mc( val_preds = np.argmax(val_probs, axis=-1) _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == 
val_preds) - # ensure that the protocol returns a LabelledCollection for each iteration - protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(name="atc_mc") for test in protocol(): ## Load OOD test data probs @@ -110,6 +117,7 @@ def atc_mc( return report +@baseline def atc_ne( c_model: BaseEstimator, validation: LabelledCollection, @@ -126,9 +134,6 @@ def atc_ne( val_preds = np.argmax(val_probs, axis=-1) _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds) - # ensure that the protocol returns a LabelledCollection for each iteration - protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(name="atc_ne") for test in protocol(): ## Load OOD test data probs @@ -150,22 +155,7 @@ def atc_ne( return report -def trust_score( - c_model: BaseEstimator, - validation: LabelledCollection, - test: LabelledCollection, - predict_method="predict", -): - c_model_predict = getattr(c_model, predict_method) - - test_pred = c_model_predict(test.X) - - trust_model = trustscore.TrustScore() - trust_model.fit(validation.X, validation.y) - - return trust_model.get_score(test.X, test_pred) - - +@baseline def doc_feat( c_model: BaseEstimator, validation: LabelledCollection, @@ -179,9 +169,6 @@ def doc_feat( val_preds = np.argmax(val_probs, axis=-1) v1acc = np.mean(val_preds == val_labels) * 100 - # ensure that the protocol returns a LabelledCollection for each iteration - protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(name="doc_feat") for test in protocol(): test_probs = c_model_predict(test.X) @@ -194,26 +181,25 @@ def doc_feat( return report -def rca_score( +@baseline +def rca( c_model: BaseEstimator, validation: LabelledCollection, protocol: AbstractStochasticSeededProtocol, predict_method="predict", ): + """elsahar19""" c_model_predict = getattr(c_model, predict_method) val_pred1 = c_model_predict(validation.X) - # ensure that the protocol returns a LabelledCollection for each iteration - protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(name="rca") for test in protocol(): try: test_pred = c_model_predict(test.X) - c_model2 = rca.clone_fit(c_model, test.X, test_pred) + c_model2 = rcalib.clone_fit(c_model, test.X, test_pred) c_model2_predict = getattr(c_model2, predict_method) val_pred2 = c_model2_predict(validation.X) - rca_score = 1.0 - rca.get_score(val_pred1, val_pred2, validation.y) + rca_score = 1.0 - rcalib.get_score(val_pred1, val_pred2, validation.y) meta_score = abs(rca_score - metrics.accuracy_score(test.y, test_pred)) report.append_row(test.prevalence(), acc=meta_score, acc_score=rca_score) except ValueError: @@ -224,32 +210,33 @@ def rca_score( return report -def rca_star_score( +@baseline +def rca_star( c_model: BaseEstimator, validation: LabelledCollection, protocol: AbstractStochasticSeededProtocol, predict_method="predict", ): + """elsahar19""" c_model_predict = getattr(c_model, predict_method) validation1, validation2 = validation.split_stratified( train_prop=0.5, random_state=0 ) val1_pred = c_model_predict(validation1.X) - c_model1 = rca.clone_fit(c_model, validation1.X, val1_pred) + c_model1 = rcalib.clone_fit(c_model, validation1.X, val1_pred) c_model1_predict = getattr(c_model1, predict_method) val2_pred1 = c_model1_predict(validation2.X) - # ensure that the protocol returns a LabelledCollection for each iteration - 
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - report = EvaluationReport(name="rca_star") for test in protocol(): try: test_pred = c_model_predict(test.X) - c_model2 = rca.clone_fit(c_model, test.X, test_pred) + c_model2 = rcalib.clone_fit(c_model, test.X, test_pred) c_model2_predict = getattr(c_model2, predict_method) val2_pred2 = c_model2_predict(validation2.X) - rca_star_score = 1.0 - rca.get_score(val2_pred1, val2_pred2, validation2.y) + rca_star_score = 1.0 - rcalib.get_score( + val2_pred1, val2_pred2, validation2.y + ) meta_score = abs(rca_star_score - metrics.accuracy_score(test.y, test_pred)) report.append_row( test.prevalence(), acc=meta_score, acc_score=rca_star_score @@ -260,3 +247,52 @@ def rca_star_score( ) return report + + +@baseline +def logreg( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, + predict_method="predict", +): + c_model_predict = getattr(c_model, predict_method) + + val_preds = c_model_predict(validation.X) + + report = EvaluationReport(name="logreg") + for test in protocol(): + wx = iw.logreg(validation.X, validation.y, test.X) + test_preds = c_model_predict(test.X) + estim_acc = iw.get_acc(val_preds, validation.y, wx) + true_acc = metrics.accuracy_score(test.y, test_preds) + meta_score = abs(estim_acc - true_acc) + report.append_row(test.prevalence(), acc=meta_score, acc_score=estim_acc) + + return report + + +@baseline +def kdex2( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, + predict_method="predict", +): + c_model_predict = getattr(c_model, predict_method) + + val_preds = c_model_predict(validation.X) + log_likelihood_val = iw.kdex2_lltr(validation.X) + Xval = validation.X.toarray() if issparse(validation.X) else validation.X + + report = EvaluationReport(name="kdex2") + for test in protocol(): + Xte = test.X.toarray() if issparse(test.X) else test.X + wx = iw.kdex2_weights(Xval, Xte, log_likelihood_val) + test_preds = c_model_predict(Xte) + estim_acc = iw.get_acc(val_preds, validation.y, wx) + true_acc = metrics.accuracy_score(test.y, test_preds) + meta_score = abs(estim_acc - true_acc) + report.append_row(test.prevalence(), acc=meta_score, acc_score=estim_acc) + + return report diff --git a/quacc/evaluation/comp.py b/quacc/evaluation/comp.py index b8c403b..e5da34d 100644 --- a/quacc/evaluation/comp.py +++ b/quacc/evaluation/comp.py @@ -1,17 +1,18 @@ +import logging as log import multiprocessing import time -import traceback from typing import List +import numpy as np import pandas as pd import quapy as qp from quapy.protocol import APP from sklearn.linear_model import LogisticRegression from quacc.dataset import Dataset -from quacc.environ import env +from quacc.environment import env from quacc.evaluation import baseline, method -from quacc.evaluation.report import DatasetReport, EvaluationReport +from quacc.evaluation.report import CompReport, DatasetReport, EvaluationReport qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE @@ -19,27 +20,7 @@ pd.set_option("display.float_format", "{:.4f}".format) class CompEstimator: - __dict = { - "our_bin_SLD": method.evaluate_bin_sld, - "our_mul_SLD": method.evaluate_mul_sld, - "our_bin_SLD_nbvs": method.evaluate_bin_sld_nbvs, - "our_mul_SLD_nbvs": method.evaluate_mul_sld_nbvs, - "our_bin_SLD_bcts": method.evaluate_bin_sld_bcts, - "our_mul_SLD_bcts": method.evaluate_mul_sld_bcts, - "our_bin_SLD_ts": method.evaluate_bin_sld_ts, - "our_mul_SLD_ts": 
method.evaluate_mul_sld_ts, - "our_bin_SLD_vs": method.evaluate_bin_sld_vs, - "our_mul_SLD_vs": method.evaluate_mul_sld_vs, - "our_bin_CC": method.evaluate_bin_cc, - "our_mul_CC": method.evaluate_mul_cc, - "ref": baseline.reference, - "kfcv": baseline.kfcv, - "atc_mc": baseline.atc_mc, - "atc_ne": baseline.atc_ne, - "doc_feat": baseline.doc_feat, - "rca": baseline.rca_score, - "rca_star": baseline.rca_star_score, - } + __dict = method._methods | baseline._baselines def __class_getitem__(cls, e: str | List[str]): if isinstance(e, str): @@ -48,30 +29,34 @@ class CompEstimator: except KeyError: raise KeyError(f"Invalid estimator: estimator {e} does not exist") elif isinstance(e, list): - try: - return [cls.__dict[est] for est in e] - except KeyError as ke: + _subtr = [k for k in e if k not in cls.__dict] + if len(_subtr) > 0: raise KeyError( - f"Invalid estimator: estimator {ke.args[0]} does not exist" + f"Invalid estimator: estimator {_subtr[0]} does not exist" ) + return [fun for k, fun in cls.__dict.items() if k in e] + CE = CompEstimator -def fit_and_estimate(_estimate, train, validation, test): +def fit_and_estimate(_estimate, train, validation, test, _env=None): + _env = env if _env is None else _env model = LogisticRegression() model.fit(*train.Xy) protocol = APP( - test, n_prevalences=env.PROTOCOL_N_PREVS, repeats=env.PROTOCOL_REPEATS + test, + n_prevalences=_env.PROTOCOL_N_PREVS, + repeats=_env.PROTOCOL_REPEATS, + return_type="labelled_collection", ) start = time.time() try: result = _estimate(model, validation, protocol) except Exception as e: - print(f"Method {_estimate.__name__} failed.") - traceback(e) + log.error(f"Method {_estimate.__name__} failed. Exception: {e}") return { "name": _estimate.__name__, "result": None, @@ -79,7 +64,7 @@ def fit_and_estimate(_estimate, train, validation, test): } end = time.time() - print(f"{_estimate.__name__}: {end-start:.2f}s") + log.info(f"{_estimate.__name__} finished [took {end-start:.4f}s]") return { "name": _estimate.__name__, @@ -91,13 +76,17 @@ def fit_and_estimate(_estimate, train, validation, test): def evaluate_comparison( dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"] ) -> EvaluationReport: + # with multiprocessing.Pool(1) as pool: with multiprocessing.Pool(len(estimators)) as pool: dr = DatasetReport(dataset.name) + log.info(f"dataset {dataset.name}") for d in dataset(): - print(f"train prev.: {d.train_prev}") - start = time.time() + log.info(f"train prev.: {np.around(d.train_prev, decimals=2)}") + tstart = time.time() tasks = [(estim, d.train, d.validation, d.test) for estim in CE[estimators]] - results = [pool.apply_async(fit_and_estimate, t) for t in tasks] + results = [ + pool.apply_async(fit_and_estimate, t, {"_env": env}) for t in tasks + ] results_got = [] for _r in results: @@ -106,19 +95,22 @@ def evaluate_comparison( if r["result"] is not None: results_got.append(r) except Exception as e: - print(e) + log.error( + f"Dataset sample {d.train[1]:.2f} of dataset {dataset.name} failed. 
Exception: {e}" + ) - er = EvaluationReport.combine_reports( - *[r["result"] for r in results_got], + tend = time.time() + times = {r["name"]: r["time"] for r in results_got} + times["tot"] = tend - tstart + log.info( + f"Dataset sample {d.train[1]:.2f} of dataset {dataset.name} finished [took {times['tot']:.4f}s" + ) + dr += CompReport( + [r["result"] for r in results_got], name=dataset.name, train_prev=d.train_prev, valid_prev=d.validation_prev, + times=times, ) - times = {r["name"]: r["time"] for r in results_got} - end = time.time() - times["tot"] = end - start - er.times = times - dr.add(er) - print() return dr diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py index 67f8878..b15990a 100644 --- a/quacc/evaluation/method.py +++ b/quacc/evaluation/method.py @@ -1,10 +1,9 @@ +from functools import wraps + import numpy as np import sklearn.metrics as metrics from quapy.data import LabelledCollection -from quapy.protocol import ( - AbstractStochasticSeededProtocol, - OnLabelledCollectionProtocol, -) +from quapy.protocol import AbstractStochasticSeededProtocol from sklearn.base import BaseEstimator import quacc.error as error @@ -16,14 +15,23 @@ from ..estimator import ( MulticlassAccuracyEstimator, ) +_methods = {} + + +def method(func): + @wraps(func) + def wrapper(c_model, validation, protocol): + return func(c_model, validation, protocol) + + _methods[func.__name__] = wrapper + + return wrapper + def estimate( estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol, ): - # ensure that the protocol returns a LabelledCollection for each iteration - protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") - base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], [] for sample in protocol(): e_sample, pred_proba = estimator.extend(sample) @@ -61,6 +69,8 @@ def evaluation_report( f1=abs(error.f1(true_prev) - f1_score), ) + report.fit_score = estimator.fit_score + return report @@ -75,105 +85,51 @@ def evaluate( estimator: AccuracyEstimator = { "bin": BinaryQuantifierAccuracyEstimator, "mul": MulticlassAccuracyEstimator, - }[method](c_model, q_model=q_model, **kwargs) + }[method](c_model, q_model=q_model.upper(), **kwargs) estimator.fit(validation) _method = f"{method}_{q_model}" - for k, v in kwargs.items(): - _method += f"_{v}" + if "recalib" in kwargs: + _method += f"_{kwargs['recalib']}" + if ("gs", True) in kwargs.items(): + _method += "_gs" return evaluation_report(estimator, protocol, _method) -def evaluate_bin_sld( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin", "SLD") +@method +def bin_sld(c_model, validation, protocol) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "bin", "sld") -def evaluate_mul_sld( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "mul", "SLD") +@method +def mul_sld(c_model, validation, protocol) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "mul", "sld") -def evaluate_bin_sld_nbvs( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="nbvs") +@method +def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport: + return 
evaluate(c_model, validation, protocol, "bin", "sld", recalib="bcts") -def evaluate_mul_sld_nbvs( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="nbvs") +@method +def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "mul", "sld", recalib="bcts") -def evaluate_bin_sld_bcts( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="bcts") +@method +def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "bin", "sld", gs=True) -def evaluate_mul_sld_bcts( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="bcts") +@method +def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "mul", "sld", gs=True) -def evaluate_bin_sld_ts( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="ts") +@method +def bin_cc(c_model, validation, protocol) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "bin", "cc") -def evaluate_mul_sld_ts( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="ts") - - -def evaluate_bin_sld_vs( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="vs") - - -def evaluate_mul_sld_vs( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="vs") - - -def evaluate_bin_cc( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "bin", "CC") - - -def evaluate_mul_cc( - c_model: BaseEstimator, - validation: LabelledCollection, - protocol: AbstractStochasticSeededProtocol, -) -> EvaluationReport: - return evaluate(c_model, validation, protocol, "mul", "CC") +@method +def mul_cc(c_model, validation, protocol) -> EvaluationReport: + return evaluate(c_model, validation, protocol, "mul", "cc") diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index 3d14203..56019a9 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd from quacc import plot -from quacc.environ import env +from quacc.environment import env from quacc.utils import fmt_line_md @@ -13,191 +13,399 @@ class EvaluationReport: def __init__(self, name=None): self._prevs = [] self._dict = {} - self._g_prevs = None - self._g_dict = None + self.fit_score = None self.name = name if name is not None else "default" - self.times = {} - self.train_prev = None - self.valid_prev = None - self.target = "default" - 
def append_row(self, base: np.ndarray | Tuple, **row): - if isinstance(base, np.ndarray): - base = tuple(base.tolist()) - self._prevs.append(base) + def append_row(self, basep: np.ndarray | Tuple, **row): + bp = basep[1] + self._prevs.append(bp) for k, v in row.items(): - if (k, self.name) in self._dict: - self._dict[(k, self.name)].append(v) - else: - self._dict[(k, self.name)] = [v] - self._g_prevs = None + if k not in self._dict: + self._dict[k] = {} + if bp not in self._dict[k]: + self._dict[k][bp] = [] + self._dict[k][bp] = np.append(self._dict[k][bp], [v]) @property def columns(self): return self._dict.keys() - def group_by_prevs(self, metric: str = None): - if self._g_dict is None: - self._g_prevs = [] - self._g_dict = {k: [] for k in self._dict.keys()} + @property + def prevs(self): + return np.sort(np.unique([list(self._dict[_k].keys()) for _k in self._dict])) - for col, vals in self._dict.items(): - col_grouped = {} - for bp, v in zip(self._prevs, vals): - if bp not in col_grouped: - col_grouped[bp] = [] - col_grouped[bp].append(v) + # def group_by_prevs(self, metric: str = None, estimators: List[str] = None): + # if self._g_dict is None: + # self._g_prevs = [] + # self._g_dict = {k: [] for k in self._dict.keys()} - self._g_dict[col] = [ - vs - for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1]) - ] + # for col, vals in self._dict.items(): + # col_grouped = {} + # for bp, v in zip(self._prevs, vals): + # if bp not in col_grouped: + # col_grouped[bp] = [] + # col_grouped[bp].append(v) - self._g_prevs = sorted( - [(p0, p1) for [p0, p1] in np.unique(self._prevs, axis=0).tolist()], - key=lambda bp: bp[1], + # self._g_dict[col] = [ + # vs + # for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1]) + # ] + + # self._g_prevs = sorted( + # [(p0, p1) for [p0, p1] in np.unique(self._prevs, axis=0).tolist()], + # key=lambda bp: bp[1], + # ) + + # fg_dict = _filter_dict(self._g_dict, metric, estimators) + # return self._g_prevs, fg_dict + + # def merge(self, other): + # if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)): + # raise ValueError("other has not same base prevalences of self") + + # inters_keys = set(self._dict.keys()).intersection(set(other._dict.keys())) + # if len(inters_keys) > 0: + # raise ValueError(f"self and other have matching keys {str(inters_keys)}.") + + # report = EvaluationReport() + # report._prevs = self._prevs + # report._dict = self._dict | other._dict + # return report + + +class CompReport: + def __init__( + self, + reports: List[EvaluationReport], + name="default", + train_prev=None, + valid_prev=None, + times=None, + ): + all_prevs = np.array([er.prevs for er in reports]) + if not np.all(all_prevs == all_prevs[0, :], axis=0).all(): + raise ValueError( + "Not all evaluation reports have the same base prevalences" + ) + uq_names, name_c = np.unique([er.name for er in reports], return_counts=True) + if np.sum(name_c) > uq_names.shape[0]: + _matching = uq_names[[c > 1 for c in name_c]] + raise ValueError( + f"Evaluation reports have matching names: {_matching.tolist()}." 
) - # last_end = 0 - # for ind, bp in enumerate(self._prevs): - # if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]: - # continue + all_dicts = [{(k, er.name): v for k, v in er._dict.items()} for er in reports] + self._dict = {} + for d in all_dicts: + self._dict = self._dict | d - # self._g_prevs.append(bp) - # for col in self._dict.keys(): - # self._g_dict[col].append( - # stats.mean(self._dict[col][last_end : ind + 1]) - # ) + self.fit_scores = { + er.name: er.fit_score for er in reports if er.fit_score is not None + } + self.train_prev = train_prev + self.valid_prev = valid_prev + self.times = times - # last_end = ind + 1 + @property + def prevs(self): + return np.sort(np.unique([list(self._dict[_k].keys()) for _k in self._dict])) - filtered_g_dict = self._g_dict + @property + def cprevs(self): + return np.around([(1.0 - p, p) for p in self.prevs], decimals=2) + + def data(self, metric: str = None, estimators: List[str] = None) -> dict: + f_dict = self._dict.copy() if metric is not None: - filtered_g_dict = { - c1: ls for ((c0, c1), ls) in self._g_dict.items() if c0 == metric + f_dict = {(c0, c1): ls for ((c0, c1), ls) in f_dict.items() if c0 == metric} + if estimators is not None: + f_dict = { + (c0, c1): ls for ((c0, c1), ls) in f_dict.items() if c1 in estimators } + if (metric, estimators) != (None, None): + f_dict = {c1: ls for ((c0, c1), ls) in f_dict.items()} - return self._g_prevs, filtered_g_dict + return f_dict - def avg_by_prevs(self, metric: str = None): - g_prevs, g_dict = self.group_by_prevs(metric=metric) - - a_dict = {} - for col, vals in g_dict.items(): - a_dict[col] = [np.mean(vs) for vs in vals] - - return g_prevs, a_dict - - def avg_all(self, metric: str = None): - f_dict = self._dict - if metric is not None: - f_dict = {c1: ls for ((c0, c1), ls) in self._dict.items() if c0 == metric} - - a_dict = {} + def group_by_shift(self, metric: str = None, estimators: List[str] = None): + f_dict = self.data(metric=metric, estimators=estimators) + shift_prevs = np.around( + np.absolute(self.prevs - self.train_prev[1]), decimals=2 + ) + shift_dict = {col: {sp: [] for sp in shift_prevs} for col in f_dict.keys()} for col, vals in f_dict.items(): - a_dict[col] = [np.mean(vals)] + for sp, bp in zip(shift_prevs, self.prevs): + shift_dict[col][sp] = np.concatenate( + [shift_dict[col][sp], f_dict[col][bp]] + ) - return a_dict + return np.sort(np.unique(shift_prevs)), shift_dict - def get_dataframe(self, metric="acc"): - g_prevs, g_dict = self.avg_by_prevs(metric=metric) - a_dict = self.avg_all(metric=metric) - for col in g_dict.keys(): - g_dict[col].extend(a_dict[col]) + def avg_by_prevs(self, metric: str = None, estimators: List[str] = None): + f_dict = self.data(metric=metric, estimators=estimators) + return { + col: np.array([np.mean(vals[bp]) for bp in self.prevs]) + for col, vals in f_dict.items() + } + + def stdev_by_prevs(self, metric: str = None, estimators: List[str] = None): + f_dict = self.data(metric=metric, estimators=estimators) + return { + col: np.array([np.std(vals[bp]) for bp in self.prevs]) + for col, vals in f_dict.items() + } + + def avg_all(self, metric: str = None, estimators: List[str] = None): + f_dict = self.data(metric=metric, estimators=estimators) + return { + col: [np.mean(np.concatenate(list(vals.values())))] + for col, vals in f_dict.items() + } + + def get_dataframe(self, metric="acc", estimators=None): + avg_dict = self.avg_by_prevs(metric=metric, estimators=estimators) + all_dict = self.avg_all(metric=metric, estimators=estimators) + for 
col in avg_dict.keys(): + avg_dict[col] = np.append(avg_dict[col], all_dict[col]) return pd.DataFrame( - g_dict, - index=g_prevs + ["tot"], - columns=g_dict.keys(), + avg_dict, + index=self.prevs.tolist() + ["tot"], + columns=avg_dict.keys(), ) - def get_plot(self, mode="delta", metric="acc") -> Path: - if mode == "delta": - g_prevs, g_dict = self.group_by_prevs(metric=metric) - return plot.plot_delta( - g_prevs, - g_dict, - metric=metric, - name=self.name, - train_prev=self.train_prev, - ) - elif mode == "diagonal": - _, g_dict = self.avg_by_prevs(metric=metric + "_score") - f_dict = {k: v for k, v in g_dict.items() if k != "ref"} - referece = g_dict["ref"] - return plot.plot_diagonal( - referece, - f_dict, - metric=metric, - name=self.name, - train_prev=self.train_prev, - ) - elif mode == "shift": - g_prevs, g_dict = self.avg_by_prevs(metric=metric) - return plot.plot_shift( - g_prevs, - g_dict, - metric=metric, - name=self.name, - train_prev=self.train_prev, - ) + def get_plots( + self, + modes=["delta", "diagonal", "shift"], + metric="acc", + estimators=None, + conf="default", + stdev=False, + ) -> Path: + pps = [] + for mode in modes: + pp = [] + if mode == "delta": + f_dict = self.avg_by_prevs(metric=metric, estimators=estimators) + _pp0 = plot.plot_delta( + self.cprevs, + f_dict, + metric=metric, + name=conf, + train_prev=self.train_prev, + fit_scores=self.fit_scores, + ) + pp = [(mode, _pp0)] + if stdev: + fs_dict = self.stdev_by_prevs(metric=metric, estimators=estimators) + _pp1 = plot.plot_delta( + self.cprevs, + f_dict, + metric=metric, + name=conf, + train_prev=self.train_prev, + fit_scores=self.fit_scores, + stdevs=fs_dict, + ) + pp.append((f"{mode}_stdev", _pp1)) + elif mode == "diagonal": + f_dict = { + col: np.concatenate([vals[bp] for bp in self.prevs]) + for col, vals in self.data( + metric=metric + "_score", estimators=estimators + ).items() + } + reference = f_dict["ref"] + f_dict = {k: v for k, v in f_dict.items() if k != "ref"} + _pp0 = plot.plot_diagonal( + reference, + f_dict, + metric=metric, + name=conf, + train_prev=self.train_prev, + ) + pp = [(mode, _pp0)] - def to_md(self, *metrics): - res = "" + elif mode == "shift": + s_prevs, s_dict = self.group_by_shift( + metric=metric, estimators=estimators + ) + _pp0 = plot.plot_shift( + np.around([(1.0 - p, p) for p in s_prevs], decimals=2), + { + col: np.array([np.mean(vals[sp]) for sp in s_prevs]) + for col, vals in s_dict.items() + }, + metric=metric, + name=conf, + train_prev=self.train_prev, + fit_scores=self.fit_scores, + ) + pp = [(mode, _pp0)] + + pps.extend(pp) + + return pps + + def to_md(self, conf="default", metric="acc", estimators=None, stdev=False): + res = f"## {int(np.around(self.train_prev, decimals=2)[1]*100)}% positives\n" res += fmt_line_md(f"train: {str(self.train_prev)}") res += fmt_line_md(f"validation: {str(self.valid_prev)}") for k, v in self.times.items(): res += fmt_line_md(f"{k}: {v:.3f}s") res += "\n" - for m in metrics: - res += self.get_dataframe(metric=m).to_html() + "\n\n" - op_delta = self.get_plot(mode="delta", metric=m) - res += f"![plot_delta]({str(op_delta.relative_to(env.OUT_DIR))})\n" - op_diag = self.get_plot(mode="diagonal", metric=m) - res += f"![plot_diagonal]({str(op_diag.relative_to(env.OUT_DIR))})\n" - op_shift = self.get_plot(mode="shift", metric=m) - res += f"![plot_shift]({str(op_shift.relative_to(env.OUT_DIR))})\n" + res += ( + self.get_dataframe(metric=metric, estimators=estimators).to_html() + "\n\n" + ) + plot_modes = ["delta", "diagonal", "shift"] + for mode, op 
in self.get_plots( + modes=plot_modes, + metric=metric, + estimators=estimators, + conf=conf, + stdev=stdev, + ): + res += f"![plot_{mode}]({op.relative_to(env.OUT_DIR).as_posix()})\n" return res - def merge(self, other): - if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)): - raise ValueError("other has not same base prevalences of self") - - inters_keys = set(self._dict.keys()).intersection(set(other._dict.keys())) - if len(inters_keys) > 0: - raise ValueError(f"self and other have matching keys {str(inters_keys)}.") - - report = EvaluationReport() - report._prevs = self._prevs - report._dict = self._dict | other._dict - return report - - @staticmethod - def combine_reports(*args, name="default", train_prev=None, valid_prev=None): - er = args[0] - for r in args[1:]: - er = er.merge(r) - - er.name = name - er.train_prev = train_prev - er.valid_prev = valid_prev - return er - class DatasetReport: def __init__(self, name): self.name = name - self.ers: List[EvaluationReport] = [] + self._dict = None + self.crs: List[CompReport] = [] - def add(self, er: EvaluationReport): - self.ers.append(er) + @property + def cprevs(self): + return np.around([(1.0 - p, p) for p in self.prevs], decimals=2) - def to_md(self, *metrics): - res = f"{self.name}\n\n" - for er in self.ers: - res += f"{er.to_md(*metrics)}\n\n" + def add(self, cr: CompReport): + self.crs.append(cr) + + if self._dict is None: + self.prevs = cr.prevs + self._dict = { + col: {bp: vals[bp] for bp in self.prevs} + for col, vals in cr.data().items() + } + self.s_prevs, self.s_dict = cr.group_by_shift() + self.fit_scores = {k: [score] for k, score in cr.fit_scores.items()} + return + + cr_dict = cr.data() + both_prevs = np.array([self.prevs, cr.prevs]) + if not np.all(both_prevs == both_prevs[0, :]).all(): + raise ValueError("Comp report has incompatible base prevalences") + + for col, vals in cr_dict.items(): + if col not in self._dict: + self._dict[col] = {} + for bp in self.prevs: + if bp not in self._dict[col]: + self._dict[col][bp] = [] + self._dict[col][bp] = np.concatenate( + [self._dict[col][bp], cr_dict[col][bp]] + ) + + cr_s_prevs, cr_s_dict = cr.group_by_shift() + self.s_prevs = np.sort(np.unique(np.concatenate([self.s_prevs, cr_s_prevs]))) + + for col, vals in cr_s_dict.items(): + if col not in self.s_dict: + self.s_dict[col] = {} + for sp in cr_s_prevs: + if sp not in self.s_dict[col]: + self.s_dict[col][sp] = [] + self.s_dict[col][sp] = np.concatenate( + [self.s_dict[col][sp], cr_s_dict[col][sp]] + ) + + for k, score in cr.fit_scores.items(): + if k not in self.fit_scores: + self.fit_scores[k] = [] + self.fit_scores[k].append(score) + + def __add__(self, cr: CompReport): + self.add(cr) + return self + + def __iadd__(self, cr: CompReport): + self.add(cr) + return self + + def to_md(self, conf="default", metric="acc", estimators=[], stdev=False): + res = f"# {self.name}\n\n" + for cr in self.crs: + res += f"{cr.to_md(conf, metric=metric, estimators=estimators, stdev=stdev)}\n\n" + + f_dict = { + c1: v + for ((c0, c1), v) in self._dict.items() + if c0 == metric and c1 in estimators + } + s_avg_dict = { + col: np.array([np.mean(vals[sp]) for sp in self.s_prevs]) + for col, vals in { + c1: v + for ((c0, c1), v) in self.s_dict.items() + if c0 == metric and c1 in estimators + }.items() + } + avg_dict = { + col: np.array([np.mean(vals[bp]) for bp in self.prevs]) + for col, vals in f_dict.items() + } + if stdev: + stdev_dict = { + col: np.array([np.std(vals[bp]) for bp in self.prevs]) + for col, vals in f_dict.items() + 
} + all_dict = { + col: [np.mean(np.concatenate(list(vals.values())))] + for col, vals in f_dict.items() + } + df = pd.DataFrame( + {col: np.append(avg_dict[col], val) for col, val in all_dict.items()}, + index=self.prevs.tolist() + ["tot"], + columns=all_dict.keys(), + ) + + res += "## avg\n" + res += df.to_html() + "\n\n" + + delta_op = plot.plot_delta( + np.around([(1.0 - p, p) for p in self.prevs], decimals=2), + avg_dict, + metric=metric, + name=conf, + train_prev=None, + fit_scores={k: np.mean(vals) for k, vals in self.fit_scores.items()}, + ) + res += f"![plot_delta]({delta_op.relative_to(env.OUT_DIR).as_posix()})\n" + + if stdev: + delta_stdev_op = plot.plot_delta( + np.around([(1.0 - p, p) for p in self.prevs], decimals=2), + avg_dict, + metric=metric, + name=conf, + train_prev=None, + fit_scores={k: np.mean(vals) for k, vals in self.fit_scores.items()}, + stdevs=stdev_dict, + ) + res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(env.OUT_DIR).as_posix()})\n" + + shift_op = plot.plot_shift( + np.around([(1.0 - p, p) for p in self.s_prevs], decimals=2), + s_avg_dict, + metric=metric, + name=conf, + train_prev=None, + fit_scores={k: np.mean(vals) for k, vals in self.fit_scores.items()}, + ) + res += f"![plot_shift]({shift_op.relative_to(env.OUT_DIR).as_posix()})\n" return res def __iter__(self): - return (er for er in self.ers) + return (cr for cr in self.crs) diff --git a/quacc/main.py b/quacc/main.py index 61d699f..4a35dee 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -1,49 +1,59 @@ -import os -import shutil -from pathlib import Path +import logging as log +import traceback +from sys import platform import quacc.evaluation.comp as comp from quacc.dataset import Dataset -from quacc.environ import env +from quacc.environment import env +from quacc.utils import create_dataser_dir -def create_out_dir(dir_name): - base_out_dir = Path(env.OUT_DIR_NAME) - if not base_out_dir.exists(): - os.mkdir(base_out_dir) - dir_path = base_out_dir / dir_name - env.OUT_DIR = dir_path - shutil.rmtree(dir_path, ignore_errors=True) - os.mkdir(dir_path) - plot_dir_path = dir_path / "plot" - env.PLOT_OUT_DIR = plot_dir_path - os.mkdir(plot_dir_path) +def toast(): + if platform == "win32": + import win11toast + + win11toast.notify("Comp", "Completed Execution") def estimate_comparison(): - for conf in env: - create_out_dir(conf) + for conf in env.get_confs(): + create_dataser_dir(conf, update=env.DATASET_DIR_UPDATE) dataset = Dataset( env.DATASET_NAME, target=env.DATASET_TARGET, n_prevalences=env.DATASET_N_PREVS, ) - output_path = env.OUT_DIR / f"{dataset.name}.md" try: dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS) - for m in env.METRICS: - output_path = env.OUT_DIR / f"{conf}_{m}.md" - with open(output_path, "w") as f: - f.write(dr.to_md(m)) + for plot_conf in env.get_plot_confs(): + for m in env.METRICS: + output_path = env.OUT_DIR / f"{plot_conf}_{m}.md" + with open(output_path, "w") as f: + f.write( + dr.to_md( + conf=plot_conf, + metric=m, + estimators=env.PLOT_ESTIMATORS, + stdev=env.PLOT_STDEV, + ) + ) except Exception as e: - print(f"Configuration {conf} failed. {e}") + log.error(f"Configuration {conf} failed. 
Exception: {e}") + traceback(e) # print(df.to_latex(float_format="{:.4f}".format)) # print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format)) def main(): + log.basicConfig( + filename="quacc.log", + filemode="a", + format="%(asctime)s| %(levelname)s: %(message)s", + datefmt="%d/%m/%y %H:%M:%S", + ) estimate_comparison() + toast() if __name__ == "__main__": diff --git a/quacc/plot.py b/quacc/plot.py index 2aa195d..2b65c36 100644 --- a/quacc/plot.py +++ b/quacc/plot.py @@ -1,54 +1,40 @@ from pathlib import Path +import matplotlib import matplotlib.pyplot as plt import numpy as np +from cycler import cycler -from quacc.environ import env +from quacc.environment import env + +matplotlib.use("agg") def _get_markers(n: int): - ls = [ - "o", - "v", - "x", - "+", - "s", - "D", - "p", - "h", - "*", - "^", - "1", - "2", - "3", - "4", - "X", - ">", - "<", - ".", - "P", - "d", - ] + ls = "ovx+sDph*^1234X><.Pd" if n > len(ls): ls = ls * (n / len(ls) + 1) - return ls[:n] + return list(ls)[:n] def plot_delta( base_prevs, dict_vals, *, + stdevs=None, pos_class=1, metric="acc", name="default", train_prev=None, + fit_scores=None, legend=True, ) -> Path: + _base_title = "delta_stdev" if stdevs is not None else "delta" if train_prev is not None: t_prev_pos = int(round(train_prev[pos_class] * 100)) - title = f"delta_{name}_{t_prev_pos}_{metric}" + title = f"{_base_title}_{name}_{t_prev_pos}_{metric}" else: - title = f"delta_{name}_{metric}" + title = f"{_base_title}_{name}_avg_{metric}" fig, ax = plt.subplots() ax.set_aspect("auto") @@ -58,24 +44,37 @@ def plot_delta( cm = plt.get_cmap("tab10") if NUM_COLORS > 10: cm = plt.get_cmap("tab20") - ax.set_prop_cycle( - color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)], - ) + cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) - base_prevs = [bp[pos_class] for bp in base_prevs] - for method, deltas in dict_vals.items(): - avg = np.array([np.mean(d, axis=-1) for d in deltas]) - # std = np.array([np.std(d, axis=-1) for d in deltas]) + base_prevs = base_prevs[:, pos_class] + for (method, deltas), _cy in zip(dict_vals.items(), cy): ax.plot( base_prevs, - avg, + deltas, label=method, + color=_cy["color"], linestyle="-", marker="o", markersize=3, zorder=2, ) - # ax.fill_between(base_prevs, avg - std, avg + std, alpha=0.25) + if stdevs is not None: + stdev = stdevs[method] + ax.fill_between( + base_prevs, + deltas - stdev, + deltas + stdev, + color=_cy["color"], + alpha=0.25, + ) + if fit_scores is not None and method in fit_scores: + ax.plot( + base_prevs, + np.repeat(fit_scores[method], base_prevs.shape[0]), + color=_cy["color"], + linestyle="--", + markersize=0, + ) ax.set(xlabel="test prevalence", ylabel=metric, title=title) @@ -106,42 +105,62 @@ def plot_diagonal( fig, ax = plt.subplots() ax.set_aspect("auto") ax.grid() + ax.set_aspect("equal") NUM_COLORS = len(dict_vals) cm = plt.get_cmap("tab10") - ax.set_prop_cycle( - marker=_get_markers(NUM_COLORS) * 2, - color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)] * 2, + if NUM_COLORS > 10: + cm = plt.get_cmap("tab20") + cy = cycler( + color=[cm(i) for i in range(NUM_COLORS)], + marker=_get_markers(NUM_COLORS), ) reference = np.array(reference) x_ticks = np.unique(reference) x_ticks.sort() - for _, deltas in dict_vals.items(): - deltas = np.array(deltas) + for (_, deltas), _cy in zip(dict_vals.items(), cy): ax.plot( reference, deltas, + color=_cy["color"], linestyle="None", + marker=_cy["marker"], markersize=3, zorder=2, + alpha=0.25, ) - for method, deltas in dict_vals.items(): - deltas 
= np.array(deltas) - x_interp = x_ticks[[0, -1]] - y_interp = np.interp(x_interp, reference, deltas) + # ensure limits are equal for both axes + _alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1) + _lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)]) + ax.set(xlim=tuple(_lims), ylim=tuple(_lims)) + + for (method, deltas), _cy in zip(dict_vals.items(), cy): + slope, interc = np.polyfit(reference, deltas, 1) + y_lr = np.array([slope * x + interc for x in _lims]) ax.plot( - x_interp, - y_interp, + _lims, + y_lr, label=method, + color=_cy["color"], linestyle="-", markersize="0", zorder=1, ) - ax.set(xlabel="test prevalence", ylabel=metric, title=title) + # plot reference line + ax.plot( + _lims, + _lims, + color="black", + linestyle="--", + markersize=0, + zorder=1, + ) + + ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title) if legend: ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) @@ -151,62 +170,55 @@ def plot_diagonal( def plot_shift( - base_prevs, - dict_vals, + shift_prevs, + shift_dict, *, pos_class=1, metric="acc", name="default", train_prev=None, + fit_scores=None, legend=True, ) -> Path: - if train_prev is None: - raise AttributeError("train_prev cannot be None.") - - train_prev = train_prev[pos_class] - t_prev_pos = int(round(train_prev * 100)) - title = f"shift_{name}_{t_prev_pos}_{metric}" + if train_prev is not None: + t_prev_pos = int(round(train_prev[pos_class] * 100)) + title = f"shift_{name}_{t_prev_pos}_{metric}" + else: + title = f"shift_{name}_avg_{metric}" fig, ax = plt.subplots() ax.set_aspect("auto") ax.grid() - NUM_COLORS = len(dict_vals) + NUM_COLORS = len(shift_dict) cm = plt.get_cmap("tab10") if NUM_COLORS > 10: cm = plt.get_cmap("tab20") - ax.set_prop_cycle( - color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)], - ) - - base_prevs = np.around( - [abs(bp[pos_class] - train_prev) for bp in base_prevs], decimals=2 - ) - for method, deltas in dict_vals.items(): - delta_bins = {} - for bp, delta in zip(base_prevs, deltas): - if bp not in delta_bins: - delta_bins[bp] = [] - delta_bins[bp].append(delta) - - bp_unique, delta_avg = zip( - *sorted( - {k: np.mean(v) for k, v in delta_bins.items()}.items(), - key=lambda db: db[0], - ) - ) + cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) + shift_prevs = shift_prevs[:, pos_class] + for (method, shifts), _cy in zip(shift_dict.items(), cy): ax.plot( - bp_unique, - delta_avg, + shift_prevs, + shifts, label=method, + color=_cy["color"], linestyle="-", marker="o", markersize=3, zorder=2, ) - ax.set(xlabel="test prevalence", ylabel=metric, title=title) + if fit_scores is not None and method in fit_scores: + ax.plot( + shift_prevs, + np.repeat(fit_scores[method], shift_prevs.shape[0]), + color=_cy["color"], + linestyle="--", + markersize=0, + ) + + ax.set(xlabel="dataset shift", ylabel=metric, title=title) if legend: ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) diff --git a/quacc/utils.py b/quacc/utils.py index d2b61f0..7989154 100644 --- a/quacc/utils.py +++ b/quacc/utils.py @@ -1,7 +1,12 @@ import functools +import os +import shutil +from pathlib import Path import pandas as pd +from quacc.environment import env + def combine_dataframes(dfs, df_index=[]) -> pd.DataFrame: if len(dfs) < 1: @@ -32,3 +37,23 @@ def avg_group_report(df: pd.DataFrame) -> pd.DataFrame: def fmt_line_md(s): return f"> {s} \n" + + +def create_dataser_dir(dir_name, update=False): + base_out_dir = Path(env.OUT_DIR_NAME) + if not base_out_dir.exists(): + os.mkdir(base_out_dir) + + 
dataset_dir = base_out_dir / dir_name
+    env.OUT_DIR = dataset_dir
+    if update:
+        # update mode: keep any previous output and only create the dir if missing
+        if not dataset_dir.exists():
+            os.mkdir(dataset_dir)
+    else:
+        # fresh run: discard previous output for this dataset and start clean
+        shutil.rmtree(dataset_dir, ignore_errors=True)
+        os.mkdir(dataset_dir)
+
+    plot_dir_path = dataset_dir / "plot"
+    env.PLOT_OUT_DIR = plot_dir_path
+    if not plot_dir_path.exists():
+        os.mkdir(plot_dir_path)
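
Note (appended for illustration, not part of the patch): a small sketch of how the new "## avg" section of DatasetReport.to_md aggregates scores, reproduced on synthetic data. The dict layout mimics DatasetReport._dict after a couple of CompReports have been added (estimator name -> base prevalence -> concatenated per-sample scores); the estimator names, the three prevalences, and the 10-sample arrays are made-up values.

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
prevs = [0.1, 0.5, 0.9]  # positive-class base prevalences (synthetic)
f_dict = {
    "mul_sld": {bp: rng.random(10) for bp in prevs},
    "atc_mc": {bp: rng.random(10) for bp in prevs},
}

# mean per base prevalence: one row of the "## avg" table per prevalence
avg_dict = {
    col: np.array([np.mean(vals[bp]) for bp in prevs]) for col, vals in f_dict.items()
}
# grand mean over every sample of every prevalence: the extra "tot" row
all_dict = {
    col: [np.mean(np.concatenate(list(vals.values())))] for col, vals in f_dict.items()
}

df = pd.DataFrame(
    {col: np.append(avg_dict[col], val) for col, val in all_dict.items()},
    index=prevs + ["tot"],
)
print(df)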
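
A second sketch, also appended for illustration only: driving the reworked plot_delta once a dataset directory has been prepared. It is grounded in the signature visible in the hunks above; the estimator names, the synthetic error values, the plot name "avg_example", and the assumption that plot_delta writes its figure under env.PLOT_OUT_DIR (initialized here by create_dataser_dir, which in turn relies on env.OUT_DIR_NAME having a default) are all illustrative assumptions, not part of the patch.

import numpy as np

import quacc.plot as plot
from quacc.utils import create_dataser_dir

# assumption: sets env.OUT_DIR and env.PLOT_OUT_DIR, where plot_delta is expected to save
create_dataser_dir("rcv1_CCAT_example", update=False)

# nine (1-p, p) base prevalences, mirroring DatasetReport.cprevs
base_prevs = np.around([(1.0 - p, p) for p in np.linspace(0.1, 0.9, 9)], decimals=2)

# per-method average error across the nine prevalences, plus a constant stdev (synthetic numbers)
avg = {"mul_sld": np.linspace(0.05, 0.20, 9), "atc_mc": np.linspace(0.08, 0.25, 9)}
std = {name: np.full(9, 0.02) for name in avg}

out = plot.plot_delta(
    base_prevs,
    avg,
    stdevs=std,                    # draws the shaded +/- stdev band per method
    metric="acc",
    name="avg_example",            # hypothetical plot name
    train_prev=None,               # None selects the "..._avg_..." title, as in DatasetReport.to_md
    fit_scores={"mul_sld": 0.07},  # dashed horizontal reference line for methods with a best score
)
print(out)

Since train_prev is None and stdevs is given, the figure title follows the new "delta_stdev_<name>_avg_<metric>" pattern, matching how DatasetReport.to_md names the averaged plots.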