fixing fit_learner=False case in QuaNet

This commit is contained in:
Alejandro Moreo Fernandez 2021-06-21 11:13:14 +02:00
parent 8239947746
commit a1cdc9ef43
4 changed files with 30 additions and 35 deletions

View File

@ -3,6 +3,7 @@ Packaging:
Documentation with sphinx
Document methods with paper references
unit-tests
clean wiki_examples!
Refactor:
==========================================
@ -31,6 +32,9 @@ OVR I believe is currently tied to aggregative methods. We should provide a gene
Currently, being "binary" only adds one checker; we should figure out how to impose the check to be automatically performed
Add random seed management to support replicability (see temp_seed in util.py).
GridSearchQ is not trully parallelized. It only parallelizes on the predictions.
In the context of a quantifier (e.g., QuaNet or CC), the parameters of the learner should be prefixed with "estimator__",
in QuaNet this is resolved with a __check_params_colision, but this should be improved. It might be cumbersome to
impose the "estimator__" prefix for, e.g., quantifiers like CC though... This should be changed everywhere...
Improvements:
==========================================

View File

@ -18,7 +18,7 @@ from quapy.util import EarlyStop
class NeuralClassifierTrainer:
def __init__(self,
net, # TextClassifierNet
net: 'TextClassifierNet',
lr=1e-3,
weight_decay=0,
patience=10,

View File

@ -138,12 +138,16 @@ def training_helper(learner,
if isinstance(learner, BaseQuantifier):
learner.fit(train)
else:
learner.fit(train.instances, train.labels)
learner.fit(*train.Xy)
else:
if ensure_probabilistic:
if not hasattr(learner, 'predict_proba'):
raise AssertionError('error: the learner cannot be calibrated since fit_learner is set to False')
unused = data
unused = None
if val_split.__class__.__name__ == LabelledCollection.__name__:
unused = val_split
if data is not None:
unused = unused+data
return learner, unused
@ -193,6 +197,8 @@ class ACC(AggregativeQuantifier):
if val_split is None:
val_split = self.val_split
if isinstance(val_split, int):
assert fit_learner == True, \
'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
# kFCV estimation of parameters
y, y_ = [], []
kfcv = StratifiedKFold(n_splits=val_split)
@ -280,6 +286,8 @@ class PACC(AggregativeProbabilisticQuantifier):
val_split = self.val_split
if isinstance(val_split, int):
assert fit_learner == True, \
'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
# kFCV estimation of parameters
y, y_ = [], []
kfcv = StratifiedKFold(n_splits=val_split)
@ -529,6 +537,8 @@ class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier):
if val_split is None:
val_split = self.val_split
if isinstance(val_split, int):
assert fit_learner == True, \
'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
# kFCV estimation of parameters
y, probabilities = [], []
kfcv = StratifiedKFold(n_splits=val_split)

View File

@ -69,19 +69,19 @@ class QuaNetTrainer(BaseQuantifier):
:return: self
"""
self._classes_ = data.classes_
classifier_data, unused_data = data.split_stratified(0.4)
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
print('Classifier data: ', len(classifier_data))
print('Q-Training data: ', len(train_data))
print('Q-Valid data: ', len(valid_data))
os.makedirs(self.checkpointdir, exist_ok=True)
if fit_learner:
classifier_data, unused_data = data.split_stratified(0.4)
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
self.learner.fit(*classifier_data.Xy)
else:
classifier_data = None
train_data, valid_data = data.split_stratified(0.66)
# estimate the hard and soft stats tpr and fpr of the classifier
self.tr_prev = data.prevalence()
self.learner.fit(*classifier_data.Xy)
# compute the posterior probabilities of the instances
valid_posteriors = self.learner.predict_proba(valid_data.instances)
train_posteriors = self.learner.predict_proba(train_data.instances)
@ -132,7 +132,6 @@ class QuaNetTrainer(BaseQuantifier):
print(f'training ended by patience exhausted; loading best model parameters in {checkpoint} '
f'for epoch {early_stop.best_epoch}')
self.quanet.load_state_dict(torch.load(checkpoint))
#self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=True)
break
return self
@ -144,9 +143,7 @@ class QuaNetTrainer(BaseQuantifier):
predictions = posteriors if quantifier.probabilistic else label_predictions
prevs_estim.extend(quantifier.aggregate(predictions))
# add the class-conditional predictions P(y'i|yj) from ACC and PACC
# prevs_estim.extend(self.quantifiers['acc'].Pte_cond_estim_.flatten())
# prevs_estim.extend(self.quantifiers['pacc'].Pte_cond_estim_.flatten())
# there is no real need for adding static estims like the TPR or FPR from training since those are constant
return prevs_estim
@ -164,8 +161,6 @@ class QuaNetTrainer(BaseQuantifier):
def epoch(self, data: LabelledCollection, posteriors, iterations, epoch, early_stop, train):
mse_loss = MSELoss()
# prevpoints = F.get_nprevpoints_approximation(iterations, self.quanet.n_classes)
# iterations = F.num_prevalence_combinations(prevpoints, self.quanet.n_classes)
self.quanet.train(mode=train)
losses = []
@ -176,11 +171,10 @@ class QuaNetTrainer(BaseQuantifier):
with qp.util.temp_seed(0):
sampling_index_gen = data.artificial_sampling_index_generator(self.sample_size, prevpoints)
else:
# sampling_index_gen = data.artificial_sampling_index_generator(self.sample_size, prevpoints)
sampling_index_gen = [data.sampling_index(self.sample_size, *prev) for prev in F.uniform_simplex_sampling(data.n_classes, iterations)]
sampling_index_gen = [data.sampling_index(self.sample_size, *prev) for prev in
F.uniform_simplex_sampling(data.n_classes, iterations)]
pbar = tqdm(sampling_index_gen, total=iterations) if train else sampling_index_gen
rand_it_show = np.random.randint(iterations)
for it, index in enumerate(pbar):
sample_data = data.sampling_from_index(index)
sample_posteriors = posteriors[index]
@ -218,27 +212,14 @@ class QuaNetTrainer(BaseQuantifier):
f'val-mseloss={self.status["va-loss"]:.5f} val-maeloss={self.status["va-mae"]:.5f} '
f'patience={early_stop.patience}/{early_stop.PATIENCE_LIMIT}')
# if it==rand_it_show:
# print()
# print('='*100)
# print('Training: ' if train else 'Validation:')
# print('=' * 100)
# print('True: ', ptrue.cpu().numpy().flatten())
# print('Estim: ', phat.detach().cpu().numpy().flatten())
# for pred, name in zip(np.asarray(quant_estims).reshape(-1,data.n_classes),
# ['cc', 'acc', 'pcc', 'pacc', 'emq', 'Pte[acc]','','','Pte[pacc]','','']):
# print(name, pred)
def get_params(self, deep=True):
return {**self.learner.get_params(), **self.quanet_params}
def set_params(self, **parameters):
learner_params={}
learner_params = {}
for key, val in parameters.items():
if key in self.quanet_params:
self.quanet_params[key]=val
self.quanet_params[key] = val
else:
learner_params[key] = val
self.learner.set_params(**learner_params)