"""Lookup helpers for the estimators registered in ``quacc.evaluation``.

``CE.func[...]`` returns registry values and ``CE.name[...]`` returns registry
keys, both resolved against the merged registry
``alt._alts | baseline._baselines | method._methods``.
``_renames`` maps estimator names to alternative (presumably display) labels.
"""

from typing import List

import numpy as np

from quacc.evaluation import baseline, method, alt


class _CompEstimatorView:
    """Shared base for the indexable views exposed by ``CompEstimator``."""

    def __init__(self, ce):
        # The CompEstimator instance whose registry this view reads.
        self.ce = ce

    def _lookup(self, e, get_ref: bool = True) -> dict:
        """Resolve ``e`` through ``CompEstimator.__get`` and return the dict."""
        # ``__get`` is name-mangled on CompEstimator, hence this spelling.
        return self.ce._CompEstimator__get(e, get_ref=get_ref)


class CompEstimatorFunc_(_CompEstimatorView):
    """Index by estimator name(s) to obtain the registered value(s).

    ``CE.func["x"]`` returns the registry value for ``"x"``.  Indexing with a
    list, tuple or ndarray of names returns a list of values in registry
    order, with the ``"ref"`` entry appended when it was not requested.
    """

    def __getitem__(self, e: str | List[str]):
        found = self._lookup(e)
        if isinstance(e, str):
            return list(found.values())[0]
        # Any non-string selector (list/tuple/ndarray) yields a list;
        # unsupported types are rejected by CompEstimator.__get.
        return list(found.values())


class CompEstimatorName_(_CompEstimatorView):
    """Index by estimator name(s) to obtain validated registry key(s)."""

    def __getitem__(self, e: str | List[str]):
        found = self._lookup(e)
        if isinstance(e, str):
            return list(found)[0]
        return list(found)

    def sort(self, e: List[str]):
        """Return the names in ``e`` reordered to registry order, without
        adding ``"ref"`` (and dropping it if it was requested)."""
        return list(self._lookup(e, get_ref=False))

    @property
    def all(self):
        """Every registered estimator name, in registry order."""
        return list(self._lookup("__all"))

    @property
    def baselines(self):
        """The names in ``baseline._baselines`` (plus ``"ref"`` if absent)."""
        return list(self._lookup("__baselines"))


class CompEstimator:
    """Entry point for estimator lookup; use via ``.name`` and ``.func``."""

    def __get(self, e: str | List[str], get_ref: bool = True) -> dict:
        """Select entries of the merged estimator registry.

        Args:
            e: a single estimator name, the special strings ``"__all"`` /
                ``"__baselines"``, or a list / tuple / ndarray of names.
            get_ref: for multi-name selections, append the ``"ref"`` entry
                when it was not requested (True) or remove it when it was
                requested (False).  Ignored for a single plain name.

        Returns:
            A dict ``{name: registry_value}`` in registry order.

        Raises:
            KeyError: if any requested name is not registered.
            TypeError: if ``e`` is not a string or a sequence of names.
        """
        # Dict union: on duplicate keys, later registries take precedence.
        registry = alt._alts | baseline._baselines | method._methods

        if isinstance(e, str):
            if e == "__all":
                e = list(registry.keys())
            elif e == "__baselines":
                e = list(baseline._baselines.keys())

        if isinstance(e, str):
            try:
                return {e: registry[e]}
            except KeyError:
                raise KeyError(
                    f"Invalid estimator: estimator {e} does not exist"
                ) from None

        if not isinstance(e, (list, tuple, np.ndarray)):
            raise TypeError(
                f"Invalid estimator selector of type {type(e).__name__}: "
                "expected a name or a list of names"
            )

        unknown = np.setdiff1d(e, list(registry.keys()))
        if len(unknown) > 0:
            raise KeyError(f"Invalid estimator: estimator {unknown[0]} does not exist")

        # Build the membership set once; ``tolist`` turns numpy string
        # scalars into plain ``str`` so they match registry keys.
        requested = set(np.asarray(e).ravel().tolist())
        selected = {k: fun for k, fun in registry.items() if k in requested}
        if get_ref and "ref" not in requested:
            selected["ref"] = registry["ref"]
        elif not get_ref and "ref" in requested:
            del selected["ref"]

        return selected

    @property
    def name(self):
        """View returning estimator names: ``CE.name["x"]``, ``CE.name.all``."""
        return CompEstimatorName_(self)

    @property
    def func(self):
        """View returning registry values: ``CE.func["x"]``."""
        return CompEstimatorFunc_(self)


CE = CompEstimator()

_renames = {
    "bin_sld_lr": "(2x2)_SLD_LR",
    "mul_sld_lr": "(1x4)_SLD_LR",
    "m3w_sld_lr": "(1x3)_SLD_LR",
    "d_bin_sld_lr": "d_(2x2)_SLD_LR",
    "d_mul_sld_lr": "d_(1x4)_SLD_LR",
    "d_m3w_sld_lr": "d_(1x3)_SLD_LR",
    # NOTE(review): unlike the d_*_sld_lr entries, these labels carry no
    # "d_" prefix — confirm this is intentional.
    "d_bin_sld_rbf": "(2x2)_SLD_RBF",
    "d_mul_sld_rbf": "(1x4)_SLD_RBF",
    "d_m3w_sld_rbf": "(1x3)_SLD_RBF",
    # "sld_lr_gs": "MS_SLD_LR",
    "sld_lr_gs": "QuAcc(SLD)",
    "bin_kde_lr": "(2x2)_KDEy_LR",
    "mul_kde_lr": "(1x4)_KDEy_LR",
    "m3w_kde_lr": "(1x3)_KDEy_LR",
    "d_bin_kde_lr": "d_(2x2)_KDEy_LR",
    "d_mul_kde_lr": "d_(1x4)_KDEy_LR",
    "d_m3w_kde_lr": "d_(1x3)_KDEy_LR",
    "bin_cc_lr": "(2x2)_CC_LR",
    "mul_cc_lr": "(1x4)_CC_LR",
    "m3w_cc_lr": "(1x3)_CC_LR",
    # "kde_lr_gs": "MS_KDEy_LR",
    "kde_lr_gs": "QuAcc(KDEy)",
    # "cc_lr_gs": "MS_CC_LR",
    "cc_lr_gs": "QuAcc(CC)",
    "atc_mc": "ATC",
    "doc": "DoC",
    "mandoline": "Mandoline",
    "rca": "RCA",
    "rca_star": "RCA*",
    "naive": "Naive",
}