environment reworked
This commit is contained in:
parent
f7c8f69351
commit
13c4e357df
|
@ -1,46 +1,53 @@
|
|||
import yaml
|
||||
import collections as C
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
defalut_env = {
|
||||
"DATASET_NAME": "rcv1",
|
||||
"DATASET_TARGET": "CCAT",
|
||||
"METRICS": ["acc", "f1"],
|
||||
"COMP_ESTIMATORS": [],
|
||||
"PLOT_ESTIMATORS": [],
|
||||
"PLOT_STDEV": False,
|
||||
"DATASET_N_PREVS": 9,
|
||||
"DATASET_PREVS": None,
|
||||
"OUT_DIR_NAME": "output",
|
||||
"OUT_DIR": None,
|
||||
"PLOT_DIR_NAME": "plot",
|
||||
"PLOT_OUT_DIR": None,
|
||||
"DATASET_DIR_UPDATE": False,
|
||||
"PROTOCOL_N_PREVS": 21,
|
||||
"PROTOCOL_REPEATS": 100,
|
||||
"SAMPLE_SIZE": 1000,
|
||||
}
|
||||
import yaml
|
||||
|
||||
|
||||
class environ:
|
||||
_instance = None
|
||||
_default_env = {
|
||||
"DATASET_NAME": None,
|
||||
"DATASET_TARGET": None,
|
||||
"METRICS": [],
|
||||
"COMP_ESTIMATORS": [],
|
||||
"DATASET_N_PREVS": 9,
|
||||
"DATASET_PREVS": None,
|
||||
"OUT_DIR_NAME": "output",
|
||||
"OUT_DIR": None,
|
||||
"PLOT_DIR_NAME": "plot",
|
||||
"PLOT_OUT_DIR": None,
|
||||
"DATASET_DIR_UPDATE": False,
|
||||
"PROTOCOL_N_PREVS": 21,
|
||||
"PROTOCOL_REPEATS": 100,
|
||||
"SAMPLE_SIZE": 1000,
|
||||
"PLOT_ESTIMATORS": [],
|
||||
"PLOT_STDEV": False,
|
||||
}
|
||||
_keys = list(_default_env.keys())
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self):
|
||||
self.exec = []
|
||||
self.confs = []
|
||||
self._default = kwargs
|
||||
self.__setdict(kwargs)
|
||||
self.load_conf()
|
||||
self._stack = C.deque([self.__getdict()])
|
||||
|
||||
def __setdict(self, d):
|
||||
for k, v in d.items():
|
||||
self.__setattr__(k, v)
|
||||
if len(self.PLOT_ESTIMATORS) == 0:
|
||||
self.PLOT_ESTIMATORS = self.COMP_ESTIMATORS
|
||||
super().__setattr__(k, v)
|
||||
|
||||
def __class_getitem__(cls, k):
|
||||
env = cls.get()
|
||||
return env.__getattribute__(k)
|
||||
def __getdict(self):
|
||||
return {k: self.__getattribute__(k) for k in environ._keys}
|
||||
|
||||
def __setattr__(self, __name: str, __value: Any) -> None:
|
||||
if __name in environ._keys:
|
||||
self._stack[-1][__name] = __value
|
||||
super().__setattr__(__name, __value)
|
||||
|
||||
def load_conf(self):
|
||||
self.__setdict(environ._default_env)
|
||||
|
||||
with open("conf.yaml", "r") as f:
|
||||
confs = yaml.safe_load(f)["exec"]
|
||||
|
||||
|
@ -50,31 +57,30 @@ class environ:
|
|||
_estimators = _estimators.union(set(pc["PLOT_ESTIMATORS"]))
|
||||
_global["COMP_ESTIMATORS"] = list(_estimators)
|
||||
|
||||
self.__setdict(_global)
|
||||
|
||||
self.confs = confs["confs"]
|
||||
self.plot_confs = confs["plot_confs"]
|
||||
|
||||
for dataset in confs["datasets"]:
|
||||
self.confs.append(_global | dataset)
|
||||
|
||||
def get_confs(self):
|
||||
self._stack.append(None)
|
||||
for _conf in self.confs:
|
||||
self.__setdict(self._default)
|
||||
self._stack.pop()
|
||||
self.__setdict(self._stack[-1])
|
||||
self.__setdict(_conf)
|
||||
if "DATASET_TARGET" not in _conf:
|
||||
self.DATASET_TARGET = None
|
||||
self._stack.append(self.__getdict())
|
||||
|
||||
name = self.DATASET_NAME
|
||||
if self.DATASET_TARGET is not None:
|
||||
name += f"_{self.DATASET_TARGET}"
|
||||
name += f"_{self.DATASET_N_PREVS}prevs"
|
||||
yield copy.deepcopy(self._stack[-1])
|
||||
|
||||
yield name
|
||||
self._stack.pop()
|
||||
|
||||
def get_plot_confs(self):
|
||||
self._stack.append(None)
|
||||
for k, pc in self.plot_confs.items():
|
||||
if "PLOT_ESTIMATORS" in pc:
|
||||
self.PLOT_ESTIMATORS = pc["PLOT_ESTIMATORS"]
|
||||
if "PLOT_STDEV" in pc:
|
||||
self.PLOT_STDEV = pc["PLOT_STDEV"]
|
||||
self._stack.pop()
|
||||
self.__setdict(self._stack[-1])
|
||||
self.__setdict(pc)
|
||||
self._stack.append(self.__getdict())
|
||||
|
||||
name = self.DATASET_NAME
|
||||
if self.DATASET_TARGET is not None:
|
||||
|
@ -82,5 +88,31 @@ class environ:
|
|||
name += f"_{k}"
|
||||
yield name
|
||||
|
||||
self._stack.pop()
|
||||
|
||||
env = environ(**defalut_env)
|
||||
@property
|
||||
def current(self):
|
||||
return copy.deepcopy(self.__getdict())
|
||||
|
||||
|
||||
env = environ()
|
||||
|
||||
if __name__ == "__main__":
|
||||
stack = C.deque()
|
||||
stack.append(-1)
|
||||
|
||||
def __gen(stack: C.deque):
|
||||
stack.append(None)
|
||||
for i in range(5):
|
||||
stack.pop()
|
||||
stack.append(i)
|
||||
yield stack[-1]
|
||||
|
||||
stack.pop()
|
||||
|
||||
print(stack)
|
||||
|
||||
for i in __gen(stack):
|
||||
print(stack, i)
|
||||
|
||||
print(stack)
|
||||
|
|
|
@ -18,30 +18,36 @@ def toast():
|
|||
def estimate_comparison():
|
||||
log = Logger.logger()
|
||||
for conf in env.get_confs():
|
||||
create_dataser_dir(conf, update=env.DATASET_DIR_UPDATE)
|
||||
dataset = Dataset(
|
||||
env.DATASET_NAME,
|
||||
target=env.DATASET_TARGET,
|
||||
n_prevalences=env.DATASET_N_PREVS,
|
||||
prevs=env.DATASET_PREVS,
|
||||
)
|
||||
create_dataser_dir(dataset.name, update=env.DATASET_DIR_UPDATE)
|
||||
try:
|
||||
dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
|
||||
for plot_conf in env.get_plot_confs():
|
||||
for m in env.METRICS:
|
||||
output_path = env.OUT_DIR / f"{plot_conf}_{m}.md"
|
||||
with open(output_path, "w") as f:
|
||||
f.write(
|
||||
dr.to_md(
|
||||
conf=plot_conf,
|
||||
metric=m,
|
||||
estimators=env.PLOT_ESTIMATORS,
|
||||
stdev=env.PLOT_STDEV,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Configuration {conf} failed. Exception: {e}")
|
||||
log.error(f"Evaluation over {dataset.name} failed. Exception: {e}")
|
||||
traceback(e)
|
||||
for plot_conf in env.get_plot_confs():
|
||||
for m in env.METRICS:
|
||||
log.debug(env.current)
|
||||
output_path = env.OUT_DIR / f"{plot_conf}_{m}.md"
|
||||
try:
|
||||
_repr = dr.to_md(
|
||||
conf=plot_conf,
|
||||
metric=m,
|
||||
estimators=env.PLOT_ESTIMATORS,
|
||||
stdev=env.PLOT_STDEV,
|
||||
)
|
||||
with open(output_path, "w") as f:
|
||||
f.write(_repr)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Failed while saving configuration {plot_conf} of {dataset.name}. Exception: {e}"
|
||||
)
|
||||
traceback(e)
|
||||
|
||||
# print(df.to_latex(float_format="{:.4f}".format))
|
||||
# print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))
|
||||
|
|
Loading…
Reference in New Issue