launching quanet
This commit is contained in:
parent
482e4453a8
commit
f69eb59eb8
|
@ -41,10 +41,10 @@ def quantification_models():
|
||||||
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
print(f'Running QuaNet in {device}')
|
print(f'Running QuaNet in {device}')
|
||||||
#yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE, checkpointdir=args.checkpointdir, device=device), lr_params
|
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE, checkpointdir=args.checkpointdir, device=device), lr_params
|
||||||
|
|
||||||
param_mod_sel={'sample_size':settings.SAMPLE_SIZE, 'n_prevpoints':21, 'n_repetitions':5}
|
param_mod_sel={'sample_size':settings.SAMPLE_SIZE, 'n_prevpoints':21, 'n_repetitions':5}
|
||||||
yield 'epaccmaeptr', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
#yield 'epaccmaeptr', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||||
# yield 'epaccmraeptr', EPACC(newLR(), param_grid=lr_params, optim='mrae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
# yield 'epaccmraeptr', EPACC(newLR(), param_grid=lr_params, optim='mrae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||||
# yield 'epaccmae', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='mae', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
# yield 'epaccmae', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='mae', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||||
# yield 'epaccmrae', EPACC(newLR(), param_grid=lr_params, optim='mrae', policy='mrae', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
# yield 'epaccmrae', EPACC(newLR(), param_grid=lr_params, optim='mrae', policy='mrae', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||||
|
|
|
@ -186,6 +186,7 @@ class ACC(AggregativeQuantifier):
|
||||||
to estimate the parameters
|
to estimate the parameters
|
||||||
:return: self
|
:return: self
|
||||||
"""
|
"""
|
||||||
|
assert val_split is not None, 'val_split cannot be set to None'
|
||||||
if isinstance(val_split, int):
|
if isinstance(val_split, int):
|
||||||
# kFCV estimation of parameters
|
# kFCV estimation of parameters
|
||||||
y, y_ = [], []
|
y, y_ = [], []
|
||||||
|
@ -269,6 +270,7 @@ class PACC(AggregativeProbabilisticQuantifier):
|
||||||
to estimate the parameters
|
to estimate the parameters
|
||||||
:return: self
|
:return: self
|
||||||
"""
|
"""
|
||||||
|
assert val_split is not None, 'val_split cannot be set to None'
|
||||||
if isinstance(val_split, int):
|
if isinstance(val_split, int):
|
||||||
# kFCV estimation of parameters
|
# kFCV estimation of parameters
|
||||||
y, y_ = [], []
|
y, y_ = [], []
|
||||||
|
@ -385,6 +387,7 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
indicating the validation set itself
|
indicating the validation set itself
|
||||||
:return: self
|
:return: self
|
||||||
"""
|
"""
|
||||||
|
assert val_split is not None, 'val_split cannot be set to None'
|
||||||
self._check_binary(data, self.__class__.__name__)
|
self._check_binary(data, self.__class__.__name__)
|
||||||
self.learner, validation = training_helper(
|
self.learner, validation = training_helper(
|
||||||
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
|
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
|
||||||
|
@ -570,10 +573,6 @@ class OneVsAll(AggregativeQuantifier):
|
||||||
def _delayed_binary_posteriors(self, c, X):
|
def _delayed_binary_posteriors(self, c, X):
|
||||||
return self.dict_binary_quantifiers[c].posterior_probabilities(X)
|
return self.dict_binary_quantifiers[c].posterior_probabilities(X)
|
||||||
|
|
||||||
#def _delayed_binary_quantify(self, c, X):
|
|
||||||
# the estimation for the positive class prevalence
|
|
||||||
# return self.dict_binary_quantifiers[c].quantify(X)[1]
|
|
||||||
|
|
||||||
def _delayed_binary_aggregate(self, c, classif_predictions):
|
def _delayed_binary_aggregate(self, c, classif_predictions):
|
||||||
# the estimation for the positive class prevalence
|
# the estimation for the positive class prevalence
|
||||||
return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:, c])[1]
|
return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:, c])[1]
|
||||||
|
|
Loading…
Reference in New Issue