launching quanet
This commit is contained in:
parent
482e4453a8
commit
f69eb59eb8
|
@ -41,10 +41,10 @@ def quantification_models():
|
|||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f'Running QuaNet in {device}')
|
||||
#yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE, checkpointdir=args.checkpointdir, device=device), lr_params
|
||||
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE, checkpointdir=args.checkpointdir, device=device), lr_params
|
||||
|
||||
param_mod_sel={'sample_size':settings.SAMPLE_SIZE, 'n_prevpoints':21, 'n_repetitions':5}
|
||||
yield 'epaccmaeptr', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||
#yield 'epaccmaeptr', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||
# yield 'epaccmraeptr', EPACC(newLR(), param_grid=lr_params, optim='mrae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||
# yield 'epaccmae', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='mae', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||
# yield 'epaccmrae', EPACC(newLR(), param_grid=lr_params, optim='mrae', policy='mrae', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||
|
|
|
@ -186,6 +186,7 @@ class ACC(AggregativeQuantifier):
|
|||
to estimate the parameters
|
||||
:return: self
|
||||
"""
|
||||
assert val_split is not None, 'val_split cannot be set to None'
|
||||
if isinstance(val_split, int):
|
||||
# kFCV estimation of parameters
|
||||
y, y_ = [], []
|
||||
|
@ -269,6 +270,7 @@ class PACC(AggregativeProbabilisticQuantifier):
|
|||
to estimate the parameters
|
||||
:return: self
|
||||
"""
|
||||
assert val_split is not None, 'val_split cannot be set to None'
|
||||
if isinstance(val_split, int):
|
||||
# kFCV estimation of parameters
|
||||
y, y_ = [], []
|
||||
|
@ -385,6 +387,7 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
|||
indicating the validation set itself
|
||||
:return: self
|
||||
"""
|
||||
assert val_split is not None, 'val_split cannot be set to None'
|
||||
self._check_binary(data, self.__class__.__name__)
|
||||
self.learner, validation = training_helper(
|
||||
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
|
||||
|
@ -570,10 +573,6 @@ class OneVsAll(AggregativeQuantifier):
|
|||
def _delayed_binary_posteriors(self, c, X):
|
||||
return self.dict_binary_quantifiers[c].posterior_probabilities(X)
|
||||
|
||||
#def _delayed_binary_quantify(self, c, X):
|
||||
# the estimation for the positive class prevalence
|
||||
# return self.dict_binary_quantifiers[c].quantify(X)[1]
|
||||
|
||||
def _delayed_binary_aggregate(self, c, classif_predictions):
|
||||
# the estimation for the positive class prevalence
|
||||
return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:, c])[1]
|
||||
|
|
Loading…
Reference in New Issue