86 lines
2.3 KiB
Python
86 lines
2.3 KiB
Python
import yaml
|
|
|
|
defalut_env = {
|
|
"DATASET_NAME": "rcv1",
|
|
"DATASET_TARGET": "CCAT",
|
|
"METRICS": ["acc", "f1"],
|
|
"COMP_ESTIMATORS": [],
|
|
"PLOT_ESTIMATORS": [],
|
|
"PLOT_STDEV": False,
|
|
"DATASET_N_PREVS": 9,
|
|
"OUT_DIR_NAME": "output",
|
|
"OUT_DIR": None,
|
|
"PLOT_DIR_NAME": "plot",
|
|
"PLOT_OUT_DIR": None,
|
|
"DATASET_DIR_UPDATE": False,
|
|
"PROTOCOL_N_PREVS": 21,
|
|
"PROTOCOL_REPEATS": 100,
|
|
"SAMPLE_SIZE": 1000,
|
|
}
|
|
|
|
|
|
class environ:
|
|
_instance = None
|
|
|
|
def __init__(self, **kwargs):
|
|
self.exec = []
|
|
self.confs = []
|
|
self._default = kwargs
|
|
self.__setdict(kwargs)
|
|
self.load_conf()
|
|
|
|
def __setdict(self, d):
|
|
for k, v in d.items():
|
|
self.__setattr__(k, v)
|
|
if len(self.PLOT_ESTIMATORS) == 0:
|
|
self.PLOT_ESTIMATORS = self.COMP_ESTIMATORS
|
|
|
|
def __class_getitem__(cls, k):
|
|
env = cls.get()
|
|
return env.__getattribute__(k)
|
|
|
|
def load_conf(self):
|
|
with open("conf.yaml", "r") as f:
|
|
confs = yaml.safe_load(f)["exec"]
|
|
|
|
_global = confs["global"]
|
|
_estimators = set()
|
|
for pc in confs["plot_confs"].values():
|
|
_estimators = _estimators.union(set(pc["PLOT_ESTIMATORS"]))
|
|
_global["COMP_ESTIMATORS"] = list(_estimators)
|
|
|
|
self.plot_confs = confs["plot_confs"]
|
|
|
|
for dataset in confs["datasets"]:
|
|
self.confs.append(_global | dataset)
|
|
|
|
def get_confs(self):
|
|
for _conf in self.confs:
|
|
self.__setdict(self._default)
|
|
self.__setdict(_conf)
|
|
if "DATASET_TARGET" not in _conf:
|
|
self.DATASET_TARGET = None
|
|
|
|
name = self.DATASET_NAME
|
|
if self.DATASET_TARGET is not None:
|
|
name += f"_{self.DATASET_TARGET}"
|
|
name += f"_{self.DATASET_N_PREVS}prevs"
|
|
|
|
yield name
|
|
|
|
def get_plot_confs(self):
|
|
for k, pc in self.plot_confs.items():
|
|
if "PLOT_ESTIMATORS" in pc:
|
|
self.PLOT_ESTIMATORS = pc["PLOT_ESTIMATORS"]
|
|
if "PLOT_STDEV" in pc:
|
|
self.PLOT_STDEV = pc["PLOT_STDEV"]
|
|
|
|
name = self.DATASET_NAME
|
|
if self.DATASET_TARGET is not None:
|
|
name += f"_{self.DATASET_TARGET}"
|
|
name += f"_{k}"
|
|
yield name
|
|
|
|
|
|
env = environ(**defalut_env)
|