From 2318541f496a72e535722a386feb12544a3d4151 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Sun, 26 Nov 2023 16:34:19 +0100 Subject: [PATCH] improved implementation, contextmanager added --- quacc/environment.py | 115 +++++++++++++------------------------------ 1 file changed, 34 insertions(+), 81 deletions(-) diff --git a/quacc/environment.py b/quacc/environment.py index 92a039a..d0fa683 100644 --- a/quacc/environment.py +++ b/quacc/environment.py @@ -1,12 +1,9 @@ -import collections as C -import copy -from typing import Any +from contextlib import contextmanager import yaml class environ: - _instance = None _default_env = { "DATASET_NAME": None, "DATASET_TARGET": None, @@ -22,97 +19,53 @@ class environ: "PROTOCOL_N_PREVS": 21, "PROTOCOL_REPEATS": 100, "SAMPLE_SIZE": 1000, - "PLOT_ESTIMATORS": [], + # "PLOT_ESTIMATORS": [], "PLOT_STDEV": False, + "_R_SEED": 0, } _keys = list(_default_env.keys()) def __init__(self): - self.exec = [] - self.confs = [] - self.load_conf() - self._stack = C.deque([self.__getdict()]) + self.__load_file() - def __setdict(self, d): - for k, v in d.items(): - super().__setattr__(k, v) - - def __getdict(self): - return {k: self.__getattribute__(k) for k in environ._keys} - - def __setattr__(self, __name: str, __value: Any) -> None: - if __name in environ._keys: - self._stack[-1][__name] = __value - super().__setattr__(__name, __value) - - def load_conf(self): - self.__setdict(environ._default_env) + def __load_file(self): + _state = environ._default_env.copy() with open("conf.yaml", "r") as f: confs = yaml.safe_load(f)["exec"] - _global = confs["global"] - _estimators = set() - for pc in confs["plot_confs"].values(): - _estimators = _estimators.union(set(pc["PLOT_ESTIMATORS"])) - _global["COMP_ESTIMATORS"] = list(_estimators) + _state = _state | confs["global"] + self.__setdict(_state) + self._confs = confs["confs"] - self.__setdict(_global) + def __setdict(self, d: dict): + for k, v in d.items(): + super().__setattr__(k, v) - self.confs = confs["confs"] - self.plot_confs = confs["plot_confs"] - - def get_confs(self): - self._stack.append(None) - for _conf in self.confs: - self._stack.pop() - self.__setdict(self._stack[-1]) - self.__setdict(_conf) - self._stack.append(self.__getdict()) - - yield copy.deepcopy(self._stack[-1]) - - self._stack.pop() - - def get_plot_confs(self): - self._stack.append(None) - for k, pc in self.plot_confs.items(): - self._stack.pop() - self.__setdict(self._stack[-1]) - self.__setdict(pc) - self._stack.append(self.__getdict()) - - name = self.DATASET_NAME - if self.DATASET_TARGET is not None: - name += f"_{self.DATASET_TARGET}" - name += f"_{k}" - yield name - - self._stack.pop() + def __getdict(self) -> dict: + return {k: self.__getattribute__(k) for k in environ._keys} @property - def current(self): - return copy.deepcopy(self.__getdict()) + def confs(self): + return self._confs.copy() + + @contextmanager + def load(self, conf): + __current = self.__getdict() + if conf is not None: + if isinstance(conf, dict): + self.__setdict(conf) + elif isinstance(conf, environ): + self.__setdict(conf.__getdict()) + try: + yield + finally: + self.__setdict(__current) + + def load_confs(self): + for c in self.confs: + with self.load(c): + yield c env = environ() - -if __name__ == "__main__": - stack = C.deque() - stack.append(-1) - - def __gen(stack: C.deque): - stack.append(None) - for i in range(5): - stack.pop() - stack.append(i) - yield stack[-1] - - stack.pop() - - print(stack) - - for i in __gen(stack): - print(stack, i) - - print(stack)