launching quanet

This commit is contained in:
Alejandro Moreo Fernandez 2021-01-20 09:01:04 +01:00
parent 482e4453a8
commit f69eb59eb8
2 changed files with 5 additions and 6 deletions

View File

@ -41,10 +41,10 @@ def quantification_models():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Running QuaNet in {device}')
#yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE, checkpointdir=args.checkpointdir, device=device), lr_params
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE, checkpointdir=args.checkpointdir, device=device), lr_params
param_mod_sel={'sample_size':settings.SAMPLE_SIZE, 'n_prevpoints':21, 'n_repetitions':5}
yield 'epaccmaeptr', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
#yield 'epaccmaeptr', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
# yield 'epaccmraeptr', EPACC(newLR(), param_grid=lr_params, optim='mrae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
# yield 'epaccmae', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='mae', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
# yield 'epaccmrae', EPACC(newLR(), param_grid=lr_params, optim='mrae', policy='mrae', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None

View File

@ -186,6 +186,7 @@ class ACC(AggregativeQuantifier):
to estimate the parameters
:return: self
"""
assert val_split is not None, 'val_split cannot be set to None'
if isinstance(val_split, int):
# kFCV estimation of parameters
y, y_ = [], []
@ -269,6 +270,7 @@ class PACC(AggregativeProbabilisticQuantifier):
to estimate the parameters
:return: self
"""
assert val_split is not None, 'val_split cannot be set to None'
if isinstance(val_split, int):
# kFCV estimation of parameters
y, y_ = [], []
@ -385,6 +387,7 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
indicating the validation set itself
:return: self
"""
assert val_split is not None, 'val_split cannot be set to None'
self._check_binary(data, self.__class__.__name__)
self.learner, validation = training_helper(
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
@ -570,10 +573,6 @@ class OneVsAll(AggregativeQuantifier):
def _delayed_binary_posteriors(self, c, X):
return self.dict_binary_quantifiers[c].posterior_probabilities(X)
#def _delayed_binary_quantify(self, c, X):
# the estimation for the positive class prevalence
# return self.dict_binary_quantifiers[c].quantify(X)[1]
def _delayed_binary_aggregate(self, c, classif_predictions):
# the estimation for the positive class prevalence
return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:, c])[1]