fixing fit_learner=False case in QuaNet

Alejandro Moreo Fernandez 2021-06-21 11:13:14 +02:00
parent 8239947746
commit a1cdc9ef43
4 changed files with 30 additions and 35 deletions

View File

@@ -3,6 +3,7 @@ Packaging:
 Documentation with sphinx
 Document methods with paper references
 unit-tests
+clean wiki_examples!
 Refactor:
 ==========================================
@@ -31,6 +32,9 @@ OVR I believe is currently tied to aggregative methods. We should provide a gene
 Currently, being "binary" only adds one checker; we should figure out how to impose the check to be automatically performed
 Add random seed management to support replicability (see temp_seed in util.py).
 GridSearchQ is not trully parallelized. It only parallelizes on the predictions.
+In the context of a quantifier (e.g., QuaNet or CC), the parameters of the learner should be prefixed with "estimator__",
+in QuaNet this is resolved with a __check_params_colision, but this should be improved. It might be cumbersome to
+impose the "estimator__" prefix for, e.g., quantifiers like CC though... This should be changed everywhere...
 Improvements:
 ==========================================
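The "estimator__" entry added in the hunk above refers to scikit-learn's double-underscore convention for nested hyperparameters, which avoids exactly the kind of name collision that __check_params_colision currently guards against. A minimal sketch of that convention follows (the ToyQuantifier wrapper and its parameters are hypothetical, not QuaPy code):

# Hypothetical sketch of the sklearn-style nested-parameter convention the TODO refers to.
# Keys prefixed with "estimator__" are routed to the wrapped learner; unprefixed keys stay
# with the wrapper itself, so the two parameter namespaces can never collide.
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression

class ToyQuantifier(BaseEstimator):
    def __init__(self, estimator=None, n_bins=10):
        self.estimator = estimator
        self.n_bins = n_bins

quant = ToyQuantifier(estimator=LogisticRegression())
quant.set_params(n_bins=20, estimator__C=0.5)   # BaseEstimator routes the prefixed key
print(quant.n_bins, quant.estimator.C)          # -> 20 0.5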

View File

@@ -18,7 +18,7 @@ from quapy.util import EarlyStop
 class NeuralClassifierTrainer:
     def __init__(self,
-                 net,  # TextClassifierNet
+                 net: 'TextClassifierNet',
                  lr=1e-3,
                  weight_decay=0,
                  patience=10,

View File

@@ -138,12 +138,16 @@ def training_helper(learner,
         if isinstance(learner, BaseQuantifier):
             learner.fit(train)
         else:
-            learner.fit(train.instances, train.labels)
+            learner.fit(*train.Xy)
     else:
         if ensure_probabilistic:
             if not hasattr(learner, 'predict_proba'):
                 raise AssertionError('error: the learner cannot be calibrated since fit_learner is set to False')
-        unused = data
+        unused = None
+        if val_split.__class__.__name__ == LabelledCollection.__name__:
+            unused = val_split
+        if data is not None:
+            unused = unused+data
     return learner, unused
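The rewritten fit_learner=False branch now hands back as "unused" everything that was not consumed for training. A standalone sketch of that bookkeeping (illustrative only, not the QuaPy function itself; plain lists stand in for LabelledCollection):

# Illustrative sketch: with a pre-fitted learner nothing is consumed for training, so the
# explicit validation collection (if one was passed) and the incoming data are both handed
# back as "unused" so the caller can reuse them, e.g. to estimate adjustment parameters.
def unused_data_when_not_fitting(data, val_split):
    unused = None
    if val_split is not None and not isinstance(val_split, (int, float)):
        unused = val_split                      # an explicit held-out collection was given
    if data is not None:
        unused = data if unused is None else unused + data   # '+' concatenates collections
    return unused

print(unused_data_when_not_fitting(data=[1, 2, 3], val_split=[4, 5]))  # -> [4, 5, 1, 2, 3]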
@@ -193,6 +197,8 @@ class ACC(AggregativeQuantifier):
         if val_split is None:
             val_split = self.val_split
         if isinstance(val_split, int):
+            assert fit_learner == True, \
+                'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
             # kFCV estimation of parameters
             y, y_ = [], []
             kfcv = StratifiedKFold(n_splits=val_split)
@@ -280,6 +286,8 @@ class PACC(AggregativeProbabilisticQuantifier):
             val_split = self.val_split
         if isinstance(val_split, int):
+            assert fit_learner == True, \
+                'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
             # kFCV estimation of parameters
             y, y_ = [], []
             kfcv = StratifiedKFold(n_splits=val_split)
@@ -529,6 +537,8 @@ class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier):
         if val_split is None:
             val_split = self.val_split
         if isinstance(val_split, int):
+            assert fit_learner == True, \
+                'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
             # kFCV estimation of parameters
             y, probabilities = [], []
             kfcv = StratifiedKFold(n_splits=val_split)
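These three assertions make the constraint explicit: kFCV (an integer val_split) needs to refit the learner, so a pre-fitted classifier can only be adjusted against an explicit held-out collection. A hedged usage sketch along those lines (the dataset-loading call is only indicative):

# Hedged usage sketch: with fit_learner=False, val_split must be a LabelledCollection.
import quapy as qp
from quapy.method.aggregative import ACC
from sklearn.linear_model import LogisticRegression

data = qp.datasets.fetch_reviews('kindle', tfidf=True)    # indicative dataset
train, val = data.training.split_stratified(0.6)

classifier = LogisticRegression().fit(*train.Xy)          # learner fitted outside QuaPy

acc = ACC(classifier)
acc.fit(train, fit_learner=False, val_split=val)          # ok: explicit validation collection
# acc.fit(train, fit_learner=False, val_split=5)          # now fails: kFCV would need to refit
estim_prev = acc.quantify(data.test.instances)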

View File

@@ -69,19 +69,19 @@ class QuaNetTrainer(BaseQuantifier):
         :return: self
         """
         self._classes_ = data.classes_
-        classifier_data, unused_data = data.split_stratified(0.4)
-        train_data, valid_data = unused_data.split_stratified(0.66)  # 0.66 split of 60% makes 40% and 20%
-        print('Classifier data: ', len(classifier_data))
-        print('Q-Training data: ', len(train_data))
-        print('Q-Valid data: ', len(valid_data))
         os.makedirs(self.checkpointdir, exist_ok=True)
+        if fit_learner:
+            classifier_data, unused_data = data.split_stratified(0.4)
+            train_data, valid_data = unused_data.split_stratified(0.66)  # 0.66 split of 60% makes 40% and 20%
+            self.learner.fit(*classifier_data.Xy)
+        else:
+            classifier_data = None
+            train_data, valid_data = data.split_stratified(0.66)
         # estimate the hard and soft stats tpr and fpr of the classifier
         self.tr_prev = data.prevalence()
-        self.learner.fit(*classifier_data.Xy)
         # compute the posterior probabilities of the instances
         valid_posteriors = self.learner.predict_proba(valid_data.instances)
         train_posteriors = self.learner.predict_proba(train_data.instances)
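With this branching, QuaNet can now be handed an already-trained classifier. A hedged sketch of that usage, modelled on the library's standard QuaNet example (the sample_size value and the dataset-loading calls are indicative only):

# Hedged sketch: fit_learner=False assumes the classifier is already trained, so `data` is
# split 0.66/0.34 into QuaNet-training and QuaNet-validation samples instead of first
# reserving a further 40% of it for classifier training.
import quapy as qp
from quapy.classification.neural import NeuralClassifierTrainer, CNNnet
from quapy.method.neural import QuaNetTrainer

dataset = qp.datasets.fetch_reviews('kindle', pickle=True)
qp.data.preprocessing.index(dataset, min_df=5, inplace=True)

classifier = NeuralClassifierTrainer(CNNnet(dataset.vocabulary_size, dataset.n_classes))
classifier.fit(*dataset.training.Xy)                       # trained beforehand

quanet = QuaNetTrainer(classifier, sample_size=100)
quanet.fit(dataset.training, fit_learner=False)            # reuses the pre-trained classifier
estim_prevalence = quanet.quantify(dataset.test.instances)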
@@ -132,7 +132,6 @@ class QuaNetTrainer(BaseQuantifier):
                 print(f'training ended by patience exhausted; loading best model parameters in {checkpoint} '
                       f'for epoch {early_stop.best_epoch}')
                 self.quanet.load_state_dict(torch.load(checkpoint))
-                #self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=True)
                 break
         return self
@@ -144,9 +143,7 @@ class QuaNetTrainer(BaseQuantifier):
             predictions = posteriors if quantifier.probabilistic else label_predictions
             prevs_estim.extend(quantifier.aggregate(predictions))
-        # add the class-conditional predictions P(y'i|yj) from ACC and PACC
-        # prevs_estim.extend(self.quantifiers['acc'].Pte_cond_estim_.flatten())
-        # prevs_estim.extend(self.quantifiers['pacc'].Pte_cond_estim_.flatten())
+        # there is no real need for adding static estims like the TPR or FPR from training since those are constant
         return prevs_estim
@@ -164,8 +161,6 @@ class QuaNetTrainer(BaseQuantifier):
     def epoch(self, data: LabelledCollection, posteriors, iterations, epoch, early_stop, train):
         mse_loss = MSELoss()
-        # prevpoints = F.get_nprevpoints_approximation(iterations, self.quanet.n_classes)
-        # iterations = F.num_prevalence_combinations(prevpoints, self.quanet.n_classes)
         self.quanet.train(mode=train)
         losses = []
@@ -176,11 +171,10 @@ class QuaNetTrainer(BaseQuantifier):
             with qp.util.temp_seed(0):
                 sampling_index_gen = data.artificial_sampling_index_generator(self.sample_size, prevpoints)
         else:
-            # sampling_index_gen = data.artificial_sampling_index_generator(self.sample_size, prevpoints)
-            sampling_index_gen = [data.sampling_index(self.sample_size, *prev) for prev in F.uniform_simplex_sampling(data.n_classes, iterations)]
+            sampling_index_gen = [data.sampling_index(self.sample_size, *prev) for prev in
+                                  F.uniform_simplex_sampling(data.n_classes, iterations)]
         pbar = tqdm(sampling_index_gen, total=iterations) if train else sampling_index_gen
-        rand_it_show = np.random.randint(iterations)
         for it, index in enumerate(pbar):
             sample_data = data.sampling_from_index(index)
             sample_posteriors = posteriors[index]
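When artificial-prevalence sampling is disabled, each iteration now draws its sample at a prevalence taken uniformly at random from the probability simplex via F.uniform_simplex_sampling. One standard way to obtain such draws, shown here only as a sketch and not necessarily QuaPy's own implementation, is the sorted-uniforms (Kraemer) method:

import numpy as np

def uniform_simplex_sampling_sketch(n_classes, size, seed=None):
    # draw `size` prevalence vectors uniformly from the (n_classes-1)-simplex:
    # sort n_classes-1 uniform cut points in [0, 1] and take the gaps between them
    rng = np.random.default_rng(seed)
    cuts = np.sort(rng.uniform(size=(size, n_classes - 1)), axis=1)
    bounds = np.hstack([np.zeros((size, 1)), cuts, np.ones((size, 1))])
    return np.diff(bounds, axis=1)   # each row is non-negative and sums to 1

prevs = uniform_simplex_sampling_sketch(n_classes=3, size=5)
assert np.allclose(prevs.sum(axis=1), 1)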
@@ -218,27 +212,14 @@ class QuaNetTrainer(BaseQuantifier):
                     f'val-mseloss={self.status["va-loss"]:.5f} val-maeloss={self.status["va-mae"]:.5f} '
                     f'patience={early_stop.patience}/{early_stop.PATIENCE_LIMIT}')
-            # if it==rand_it_show:
-            #     print()
-            #     print('='*100)
-            #     print('Training: ' if train else 'Validation:')
-            #     print('=' * 100)
-            #     print('True: ', ptrue.cpu().numpy().flatten())
-            #     print('Estim: ', phat.detach().cpu().numpy().flatten())
-            #     for pred, name in zip(np.asarray(quant_estims).reshape(-1,data.n_classes),
-            #                           ['cc', 'acc', 'pcc', 'pacc', 'emq', 'Pte[acc]','','','Pte[pacc]','','']):
-            #         print(name, pred)

     def get_params(self, deep=True):
         return {**self.learner.get_params(), **self.quanet_params}

     def set_params(self, **parameters):
-        learner_params={}
+        learner_params = {}
         for key, val in parameters.items():
             if key in self.quanet_params:
-                self.quanet_params[key]=val
+                self.quanet_params[key] = val
             else:
                 learner_params[key] = val
         self.learner.set_params(**learner_params)
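The routing above resolves every incoming key by a flat lookup against quanet_params, which is exactly the collision risk flagged in the TODO hunk of the first file: a learner hyperparameter that happens to share a name with a QuaNet one can never reach the learner this way. A toy illustration of that routing (the parameter names are made up for the example, not taken from QuaPy):

# Toy illustration (not QuaPy code) of the flat key routing performed by set_params above.
quanet_params = {'lstm_hidden_size': 64, 'lr': 1e-3}
incoming = {'lstm_hidden_size': 128, 'C': 1.0, 'lr': 1e-4}

learner_params = {}
for key, val in incoming.items():
    if key in quanet_params:
        quanet_params[key] = val      # captured by QuaNet...
    else:
        learner_params[key] = val     # ...everything else goes to the learner

print(quanet_params)    # {'lstm_hidden_size': 128, 'lr': 0.0001}
print(learner_params)   # {'C': 1.0}  <- a learner parameter also named 'lr' could never be set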