From 13c4e357df72c3c89d6c3329033f4cef4589b138 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Tue, 31 Oct 2023 14:53:00 +0100 Subject: [PATCH] environment reworked --- quacc/environment.py | 120 +++++++++++++++++++++++++++---------------- quacc/main.py | 34 +++++++----- 2 files changed, 96 insertions(+), 58 deletions(-) diff --git a/quacc/environment.py b/quacc/environment.py index d5447e3..f22f162 100644 --- a/quacc/environment.py +++ b/quacc/environment.py @@ -1,46 +1,53 @@ -import yaml +import collections as C +import copy +from typing import Any -defalut_env = { - "DATASET_NAME": "rcv1", - "DATASET_TARGET": "CCAT", - "METRICS": ["acc", "f1"], - "COMP_ESTIMATORS": [], - "PLOT_ESTIMATORS": [], - "PLOT_STDEV": False, - "DATASET_N_PREVS": 9, - "DATASET_PREVS": None, - "OUT_DIR_NAME": "output", - "OUT_DIR": None, - "PLOT_DIR_NAME": "plot", - "PLOT_OUT_DIR": None, - "DATASET_DIR_UPDATE": False, - "PROTOCOL_N_PREVS": 21, - "PROTOCOL_REPEATS": 100, - "SAMPLE_SIZE": 1000, -} +import yaml class environ: _instance = None + _default_env = { + "DATASET_NAME": None, + "DATASET_TARGET": None, + "METRICS": [], + "COMP_ESTIMATORS": [], + "DATASET_N_PREVS": 9, + "DATASET_PREVS": None, + "OUT_DIR_NAME": "output", + "OUT_DIR": None, + "PLOT_DIR_NAME": "plot", + "PLOT_OUT_DIR": None, + "DATASET_DIR_UPDATE": False, + "PROTOCOL_N_PREVS": 21, + "PROTOCOL_REPEATS": 100, + "SAMPLE_SIZE": 1000, + "PLOT_ESTIMATORS": [], + "PLOT_STDEV": False, + } + _keys = list(_default_env.keys()) - def __init__(self, **kwargs): + def __init__(self): self.exec = [] self.confs = [] - self._default = kwargs - self.__setdict(kwargs) self.load_conf() + self._stack = C.deque([self.__getdict()]) def __setdict(self, d): for k, v in d.items(): - self.__setattr__(k, v) - if len(self.PLOT_ESTIMATORS) == 0: - self.PLOT_ESTIMATORS = self.COMP_ESTIMATORS + super().__setattr__(k, v) - def __class_getitem__(cls, k): - env = cls.get() - return env.__getattribute__(k) + def __getdict(self): + return {k: self.__getattribute__(k) for k in environ._keys} + + def __setattr__(self, __name: str, __value: Any) -> None: + if __name in environ._keys: + self._stack[-1][__name] = __value + super().__setattr__(__name, __value) def load_conf(self): + self.__setdict(environ._default_env) + with open("conf.yaml", "r") as f: confs = yaml.safe_load(f)["exec"] @@ -50,31 +57,30 @@ class environ: _estimators = _estimators.union(set(pc["PLOT_ESTIMATORS"])) _global["COMP_ESTIMATORS"] = list(_estimators) + self.__setdict(_global) + + self.confs = confs["confs"] self.plot_confs = confs["plot_confs"] - for dataset in confs["datasets"]: - self.confs.append(_global | dataset) - def get_confs(self): + self._stack.append(None) for _conf in self.confs: - self.__setdict(self._default) + self._stack.pop() + self.__setdict(self._stack[-1]) self.__setdict(_conf) - if "DATASET_TARGET" not in _conf: - self.DATASET_TARGET = None + self._stack.append(self.__getdict()) - name = self.DATASET_NAME - if self.DATASET_TARGET is not None: - name += f"_{self.DATASET_TARGET}" - name += f"_{self.DATASET_N_PREVS}prevs" + yield copy.deepcopy(self._stack[-1]) - yield name + self._stack.pop() def get_plot_confs(self): + self._stack.append(None) for k, pc in self.plot_confs.items(): - if "PLOT_ESTIMATORS" in pc: - self.PLOT_ESTIMATORS = pc["PLOT_ESTIMATORS"] - if "PLOT_STDEV" in pc: - self.PLOT_STDEV = pc["PLOT_STDEV"] + self._stack.pop() + self.__setdict(self._stack[-1]) + self.__setdict(pc) + self._stack.append(self.__getdict()) name = self.DATASET_NAME if self.DATASET_TARGET is not None: @@ -82,5 +88,31 @@ class environ: name += f"_{k}" yield name + self._stack.pop() -env = environ(**defalut_env) + @property + def current(self): + return copy.deepcopy(self.__getdict()) + + +env = environ() + +if __name__ == "__main__": + stack = C.deque() + stack.append(-1) + + def __gen(stack: C.deque): + stack.append(None) + for i in range(5): + stack.pop() + stack.append(i) + yield stack[-1] + + stack.pop() + + print(stack) + + for i in __gen(stack): + print(stack, i) + + print(stack) diff --git a/quacc/main.py b/quacc/main.py index b69c5d1..b7fd62a 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -18,30 +18,36 @@ def toast(): def estimate_comparison(): log = Logger.logger() for conf in env.get_confs(): - create_dataser_dir(conf, update=env.DATASET_DIR_UPDATE) dataset = Dataset( env.DATASET_NAME, target=env.DATASET_TARGET, n_prevalences=env.DATASET_N_PREVS, prevs=env.DATASET_PREVS, ) + create_dataser_dir(dataset.name, update=env.DATASET_DIR_UPDATE) try: dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS) - for plot_conf in env.get_plot_confs(): - for m in env.METRICS: - output_path = env.OUT_DIR / f"{plot_conf}_{m}.md" - with open(output_path, "w") as f: - f.write( - dr.to_md( - conf=plot_conf, - metric=m, - estimators=env.PLOT_ESTIMATORS, - stdev=env.PLOT_STDEV, - ) - ) except Exception as e: - log.error(f"Configuration {conf} failed. Exception: {e}") + log.error(f"Evaluation over {dataset.name} failed. Exception: {e}") traceback(e) + for plot_conf in env.get_plot_confs(): + for m in env.METRICS: + log.debug(env.current) + output_path = env.OUT_DIR / f"{plot_conf}_{m}.md" + try: + _repr = dr.to_md( + conf=plot_conf, + metric=m, + estimators=env.PLOT_ESTIMATORS, + stdev=env.PLOT_STDEV, + ) + with open(output_path, "w") as f: + f.write(_repr) + except Exception as e: + log.error( + f"Failed while saving configuration {plot_conf} of {dataset.name}. Exception: {e}" + ) + traceback(e) # print(df.to_latex(float_format="{:.4f}".format)) # print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))