forked from moreo/QuaPy
fixing fit_learner=False case in QuaNet
parent 8239947746
commit a1cdc9ef43
TODO.txt (4 changes)
@@ -3,6 +3,7 @@ Packaging:
 Documentation with sphinx
 Document methods with paper references
 unit-tests
+clean wiki_examples!
 
 Refactor:
 ==========================================
@@ -31,6 +32,9 @@ OVR I believe is currently tied to aggregative methods. We should provide a gene
 Currently, being "binary" only adds one checker; we should figure out how to impose the check to be automatically performed
 Add random seed management to support replicability (see temp_seed in util.py).
 GridSearchQ is not trully parallelized. It only parallelizes on the predictions.
+In the context of a quantifier (e.g., QuaNet or CC), the parameters of the learner should be prefixed with "estimator__",
+in QuaNet this is resolved with a __check_params_colision, but this should be improved. It might be cumbersome to
+impose the "estimator__" prefix for, e.g., quantifiers like CC though... This should be changed everywhere...
 
 Improvements:
 ==========================================
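The three new TODO lines refer to scikit-learn's convention for nested estimators, shown here with plain scikit-learn (no QuaPy involved): sub-estimator parameters are addressed through a double-underscore prefix, so they can never collide with the wrapper's own parameter names, which is the situation that __check_params_colision currently has to detect by hand:

from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

ovr = OneVsRestClassifier(LogisticRegression())
ovr.set_params(estimator__C=10)          # routed to the nested learner
print(ovr.get_params()['estimator__C'])  # -> 10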
quapy/classification/neural.py
@@ -18,7 +18,7 @@ from quapy.util import EarlyStop
 class NeuralClassifierTrainer:
 
     def __init__(self,
-                 net, # TextClassifierNet
+                 net: 'TextClassifierNet',
                  lr=1e-3,
                  weight_decay=0,
                  patience=10,
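A small aside on the signature change, which swaps a comment for a string annotation: the quotes make it a forward reference, legal even if TextClassifierNet is defined later (or would create an import cycle), and resolved lazily. The names below are illustrative:

from typing import get_type_hints

class Trainer:
    def __init__(self, net: 'TextClassifierNet'):  # forward reference
        self.net = net

class TextClassifierNet:  # defined afterwards; the annotation still resolves
    pass

print(get_type_hints(Trainer.__init__))  # {'net': <class '...TextClassifierNet'>}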
quapy/method/aggregative.py
@@ -138,12 +138,16 @@ def training_helper(learner,
         if isinstance(learner, BaseQuantifier):
             learner.fit(train)
         else:
-            learner.fit(train.instances, train.labels)
+            learner.fit(*train.Xy)
     else:
         if ensure_probabilistic:
             if not hasattr(learner, 'predict_proba'):
                 raise AssertionError('error: the learner cannot be calibrated since fit_learner is set to False')
 
-        unused = data
+        unused = None
+        if val_split.__class__.__name__ == LabelledCollection.__name__:
+            unused = val_split
+            if data is not None:
+                unused = unused+data
 
     return learner, unused
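What the fixed training_helper does when fit_learner=False: the learner is left untouched, and the held-out data returned as unused now comes from the explicitly provided validation split (previously unused was set to data unconditionally, ignoring val_split). A minimal sketch of the call pattern, assuming a pre-trained scikit-learn classifier; the toy data is illustrative, not from the commit:

from sklearn.linear_model import LogisticRegression
from quapy.data import LabelledCollection
from quapy.method.aggregative import training_helper

train = LabelledCollection([[0.0], [0.1], [0.9], [1.0]], [0, 0, 1, 1])
val = LabelledCollection([[0.2], [0.8]], [0, 1])

# trained outside QuaPy; must not be refit
learner = LogisticRegression().fit(*train.Xy)

# the learner comes back unmodified; `unused` is the validation collection,
# joined with the training data when the latter is available
learner, unused = training_helper(learner, train, fit_learner=False, val_split=val)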
@@ -193,6 +197,8 @@ class ACC(AggregativeQuantifier):
         if val_split is None:
             val_split = self.val_split
         if isinstance(val_split, int):
+            assert fit_learner == True, \
+                'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
             # kFCV estimation of parameters
             y, y_ = [], []
             kfcv = StratifiedKFold(n_splits=val_split)
@@ -280,6 +286,8 @@ class PACC(AggregativeProbabilisticQuantifier):
             val_split = self.val_split
 
         if isinstance(val_split, int):
+            assert fit_learner == True, \
+                'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
             # kFCV estimation of parameters
             y, y_ = [], []
             kfcv = StratifiedKFold(n_splits=val_split)
@@ -529,6 +537,8 @@ class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier):
         if val_split is None:
             val_split = self.val_split
         if isinstance(val_split, int):
+            assert fit_learner == True, \
+                'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
             # kFCV estimation of parameters
             y, probabilities = [], []
             kfcv = StratifiedKFold(n_splits=val_split)
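All three asserts encode the same constraint: estimating the adjustment parameters with k-fold cross-validation (an integer val_split) would re-fit the learner on each fold, which contradicts fit_learner=False. A hedged usage sketch with ACC; classifier and data are illustrative:

from sklearn.linear_model import LogisticRegression
from quapy.data import LabelledCollection
from quapy.method.aggregative import ACC

train = LabelledCollection([[0.0], [0.1], [0.9], [1.0]], [0, 0, 1, 1])
val = LabelledCollection([[0.2], [0.8]], [0, 1])
learner = LogisticRegression().fit(*train.Xy)  # pre-trained, kept frozen

acc = ACC(learner)
# OK: adjustment estimated on an explicit held-out collection
acc.fit(train, fit_learner=False, val_split=val)
# not OK: an int would trigger kFCV, i.e., refitting the frozen learner
# acc.fit(train, fit_learner=False, val_split=5)   # raises AssertionError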
quapy/method/neural.py
@@ -69,19 +69,19 @@ class QuaNetTrainer(BaseQuantifier):
         :return: self
         """
         self._classes_ = data.classes_
-        classifier_data, unused_data = data.split_stratified(0.4)
-        train_data, valid_data = unused_data.split_stratified(0.66)  # 0.66 split of 60% makes 40% and 20%
-        print('Classifier data: ', len(classifier_data))
-        print('Q-Training data: ', len(train_data))
-        print('Q-Valid data: ', len(valid_data))
         os.makedirs(self.checkpointdir, exist_ok=True)
 
+        if fit_learner:
+            classifier_data, unused_data = data.split_stratified(0.4)
+            train_data, valid_data = unused_data.split_stratified(0.66)  # 0.66 split of 60% makes 40% and 20%
+            self.learner.fit(*classifier_data.Xy)
+        else:
+            classifier_data = None
+            train_data, valid_data = data.split_stratified(0.66)
 
         # estimate the hard and soft stats tpr and fpr of the classifier
         self.tr_prev = data.prevalence()
 
-        self.learner.fit(*classifier_data.Xy)
-
         # compute the posterior probabilities of the instances
         valid_posteriors = self.learner.predict_proba(valid_data.instances)
         train_posteriors = self.learner.predict_proba(train_data.instances)
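This is the heart of the commit: when fit_learner=False, QuaNet no longer carves out 40% of the data to train the classifier (and no longer calls learner.fit on a split that is now None); the whole collection goes to QuaNet training and validation. A sketch of the two split regimes, with illustrative data:

import numpy as np
from quapy.data import LabelledCollection

# 1000 balanced binary instances; the features are dummies
data = LabelledCollection(np.random.rand(1000, 2), np.repeat([0, 1], 500))

# fit_learner=True: 40% classifier / ~40% Q-train / ~20% Q-valid
classifier_data, unused_data = data.split_stratified(0.4)
train_data, valid_data = unused_data.split_stratified(0.66)
print(len(classifier_data), len(train_data), len(valid_data))  # 400 396 204

# fit_learner=False (this commit): no classifier carve-out
train_data, valid_data = data.split_stratified(0.66)
print(len(train_data), len(valid_data))  # 660 340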
@@ -132,7 +132,6 @@ class QuaNetTrainer(BaseQuantifier):
                 print(f'training ended by patience exhausted; loading best model parameters in {checkpoint} '
                       f'for epoch {early_stop.best_epoch}')
                 self.quanet.load_state_dict(torch.load(checkpoint))
-                #self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=True)
                 break
 
         return self
@@ -144,9 +143,7 @@ class QuaNetTrainer(BaseQuantifier):
             predictions = posteriors if quantifier.probabilistic else label_predictions
             prevs_estim.extend(quantifier.aggregate(predictions))
 
         # add the class-conditional predictions P(y'i|yj) from ACC and PACC
-        # prevs_estim.extend(self.quantifiers['acc'].Pte_cond_estim_.flatten())
-        # prevs_estim.extend(self.quantifiers['pacc'].Pte_cond_estim_.flatten())
         # there is no real need for adding static estims like the TPR or FPR from training since those are constant
 
         return prevs_estim
@@ -164,8 +161,6 @@ class QuaNetTrainer(BaseQuantifier):
 
     def epoch(self, data: LabelledCollection, posteriors, iterations, epoch, early_stop, train):
         mse_loss = MSELoss()
-        # prevpoints = F.get_nprevpoints_approximation(iterations, self.quanet.n_classes)
-        # iterations = F.num_prevalence_combinations(prevpoints, self.quanet.n_classes)
 
         self.quanet.train(mode=train)
         losses = []
@@ -176,11 +171,10 @@ class QuaNetTrainer(BaseQuantifier):
             with qp.util.temp_seed(0):
                 sampling_index_gen = data.artificial_sampling_index_generator(self.sample_size, prevpoints)
         else:
-            # sampling_index_gen = data.artificial_sampling_index_generator(self.sample_size, prevpoints)
-            sampling_index_gen = [data.sampling_index(self.sample_size, *prev) for prev in F.uniform_simplex_sampling(data.n_classes, iterations)]
+            sampling_index_gen = [data.sampling_index(self.sample_size, *prev) for prev in
+                                  F.uniform_simplex_sampling(data.n_classes, iterations)]
         pbar = tqdm(sampling_index_gen, total=iterations) if train else sampling_index_gen
 
         rand_it_show = np.random.randint(iterations)
         for it, index in enumerate(pbar):
             sample_data = data.sampling_from_index(index)
             sample_posteriors = posteriors[index]
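The replaced generator draws training prevalences uniformly at random from the probability simplex instead of walking a fixed artificial grid. Below is a self-contained sketch of the standard 'sorted uniforms' (Kraemer) construction that a function like F.uniform_simplex_sampling can implement; QuaPy's actual implementation may differ:

import numpy as np

def uniform_simplex_sampling_sketch(n_classes: int, size: int) -> np.ndarray:
    # draw n-1 uniforms, sort them, and take the consecutive gaps between
    # 0 and 1: the gaps are uniformly distributed over the simplex
    u = np.sort(np.random.rand(size, n_classes - 1), axis=-1)
    bounds = np.hstack([np.zeros((size, 1)), u, np.ones((size, 1))])
    return np.diff(bounds, axis=-1)

prevs = uniform_simplex_sampling_sketch(3, 5)
print(prevs.sum(axis=1))  # every row is a valid prevalence vector: all 1.0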
@@ -218,27 +212,14 @@ class QuaNetTrainer(BaseQuantifier):
                           f'val-mseloss={self.status["va-loss"]:.5f} val-maeloss={self.status["va-mae"]:.5f} '
                           f'patience={early_stop.patience}/{early_stop.PATIENCE_LIMIT}')
 
-            # if it==rand_it_show:
-            #     print()
-            #     print('='*100)
-            #     print('Training: ' if train else 'Validation:')
-            #     print('=' * 100)
-            #     print('True: ', ptrue.cpu().numpy().flatten())
-            #     print('Estim: ', phat.detach().cpu().numpy().flatten())
-            #     for pred, name in zip(np.asarray(quant_estims).reshape(-1,data.n_classes),
-            #                           ['cc', 'acc', 'pcc', 'pacc', 'emq', 'Pte[acc]','','','Pte[pacc]','','']):
-            #         print(name, pred)
-
-
-
     def get_params(self, deep=True):
         return {**self.learner.get_params(), **self.quanet_params}
 
     def set_params(self, **parameters):
-        learner_params={}
+        learner_params = {}
         for key, val in parameters.items():
             if key in self.quanet_params:
-                self.quanet_params[key]=val
+                self.quanet_params[key] = val
             else:
                 learner_params[key] = val
         self.learner.set_params(**learner_params)
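For context, set_params routes each parameter by key membership: names present in quanet_params stay with QuaNet, everything else is forwarded to the wrapped learner. This is precisely the scheme the new TODO entry wants to replace with an 'estimator__' prefix, since a learner hyperparameter that shares a name with a QuaNet one can never be reached. A standalone sketch of the routing; all names are illustrative:

quanet_params = {'lstm_hidden_size': 64, 'lr': 1e-3}  # QuaNet's own knobs
learner_received = {}                                  # stands in for the learner

def set_params(**parameters):
    learner_params = {}
    for key, val in parameters.items():
        if key in quanet_params:
            quanet_params[key] = val    # consumed by QuaNet
        else:
            learner_params[key] = val   # forwarded to the learner
    learner_received.update(learner_params)

set_params(lstm_hidden_size=128, C=10)
print(quanet_params['lstm_hidden_size'], learner_received)  # 128 {'C': 10}
# hazard: a learner parameter also named 'lr' could never be set this way,
# which is what __check_params_colision has to guard against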