113 lines
3.2 KiB
Python
113 lines
3.2 KiB
Python
from typing import List
|
|
|
|
import numpy as np
|
|
|
|
from quacc.evaluation import baseline, method, alt
|
|
|
|
|
|
class CompEstimatorFunc_:
|
|
def __init__(self, ce):
|
|
self.ce = ce
|
|
|
|
def __getitem__(self, e: str | List[str]):
|
|
if isinstance(e, str):
|
|
return list(self.ce._CompEstimator__get(e).values())[0]
|
|
elif isinstance(e, list):
|
|
return list(self.ce._CompEstimator__get(e).values())
|
|
|
|
|
|
class CompEstimatorName_:
|
|
def __init__(self, ce):
|
|
self.ce = ce
|
|
|
|
def __getitem__(self, e: str | List[str]):
|
|
if isinstance(e, str):
|
|
return list(self.ce._CompEstimator__get(e).keys())[0]
|
|
elif isinstance(e, list):
|
|
return list(self.ce._CompEstimator__get(e).keys())
|
|
|
|
def sort(self, e: List[str]):
|
|
return list(self.ce._CompEstimator__get(e, get_ref=False).keys())
|
|
|
|
@property
|
|
def all(self):
|
|
return list(self.ce._CompEstimator__get("__all").keys())
|
|
|
|
@property
|
|
def baselines(self):
|
|
return list(self.ce._CompEstimator__get("__baselines").keys())
|
|
|
|
|
|
class CompEstimator:
|
|
def __get(cls, e: str | List[str], get_ref=True):
|
|
_dict = alt._alts | baseline._baselines | method._methods
|
|
|
|
if isinstance(e, str) and e == "__all":
|
|
e = list(_dict.keys())
|
|
if isinstance(e, str) and e == "__baselines":
|
|
e = list(baseline._baselines.keys())
|
|
|
|
if isinstance(e, str):
|
|
try:
|
|
return {e: _dict[e]}
|
|
except KeyError:
|
|
raise KeyError(f"Invalid estimator: estimator {e} does not exist")
|
|
elif isinstance(e, list) or isinstance(e, np.ndarray):
|
|
_subtr = np.setdiff1d(e, list(_dict.keys()))
|
|
if len(_subtr) > 0:
|
|
raise KeyError(
|
|
f"Invalid estimator: estimator {_subtr[0]} does not exist"
|
|
)
|
|
|
|
e_fun = {k: fun for k, fun in _dict.items() if k in e}
|
|
if get_ref and "ref" not in e:
|
|
e_fun["ref"] = _dict["ref"]
|
|
elif not get_ref and "ref" in e:
|
|
del e_fun["ref"]
|
|
|
|
return e_fun
|
|
|
|
@property
|
|
def name(self):
|
|
return CompEstimatorName_(self)
|
|
|
|
@property
|
|
def func(self):
|
|
return CompEstimatorFunc_(self)
|
|
|
|
|
|
CE = CompEstimator()
|
|
|
|
_renames = {
|
|
"bin_sld_lr": "(2x2)_SLD_LR",
|
|
"mul_sld_lr": "(1x4)_SLD_LR",
|
|
"m3w_sld_lr": "(1x3)_SLD_LR",
|
|
"d_bin_sld_lr": "d_(2x2)_SLD_LR",
|
|
"d_mul_sld_lr": "d_(1x4)_SLD_LR",
|
|
"d_m3w_sld_lr": "d_(1x3)_SLD_LR",
|
|
"d_bin_sld_rbf": "(2x2)_SLD_RBF",
|
|
"d_mul_sld_rbf": "(1x4)_SLD_RBF",
|
|
"d_m3w_sld_rbf": "(1x3)_SLD_RBF",
|
|
# "sld_lr_gs": "MS_SLD_LR",
|
|
"sld_lr_gs": "QuAcc(SLD)",
|
|
"bin_kde_lr": "(2x2)_KDEy_LR",
|
|
"mul_kde_lr": "(1x4)_KDEy_LR",
|
|
"m3w_kde_lr": "(1x3)_KDEy_LR",
|
|
"d_bin_kde_lr": "d_(2x2)_KDEy_LR",
|
|
"d_mul_kde_lr": "d_(1x4)_KDEy_LR",
|
|
"d_m3w_kde_lr": "d_(1x3)_KDEy_LR",
|
|
"bin_cc_lr": "(2x2)_CC_LR",
|
|
"mul_cc_lr": "(1x4)_CC_LR",
|
|
"m3w_cc_lr": "(1x3)_CC_LR",
|
|
# "kde_lr_gs": "MS_KDEy_LR",
|
|
"kde_lr_gs": "QuAcc(KDEy)",
|
|
# "cc_lr_gs": "MS_CC_LR",
|
|
"cc_lr_gs": "QuAcc(CC)",
|
|
"atc_mc": "ATC",
|
|
"doc": "DoC",
|
|
"mandoline": "Mandoline",
|
|
"rca": "RCA",
|
|
"rca_star": "RCA*",
|
|
"naive": "Naive",
|
|
}
|