fixing fit_learner=False case in QuaNet
commit a1cdc9ef43 (parent 8239947746)

TODO.txt: 4 lines changed
--- a/TODO.txt
+++ b/TODO.txt
@@ -3,6 +3,7 @@ Packaging:
 Documentation with sphinx
 Document methods with paper references
 unit-tests
+clean wiki_examples!

 Refactor:
 ==========================================
@@ -31,6 +32,9 @@ OVR I believe is currently tied to aggregative methods. We should provide a gene
 Currently, being "binary" only adds one checker; we should figure out how to impose the check to be automatically performed
 Add random seed management to support replicability (see temp_seed in util.py).
 GridSearchQ is not truly parallelized. It only parallelizes on the predictions.
+In the context of a quantifier (e.g., QuaNet or CC), the parameters of the learner should be prefixed with "estimator__";
+in QuaNet this is resolved with a __check_params_colision, but this should be improved. It might be cumbersome to
+impose the "estimator__" prefix for, e.g., quantifiers like CC though... This should be changed everywhere...

 Improvements:
 ==========================================
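The "estimator__" prefix proposed in this TODO entry mirrors scikit-learn's convention for nested estimators (e.g., Pipeline's stepname__param). A minimal sketch of prefix-based routing, assuming a hypothetical wrapper class (this is not QuaPy's current API, which instead relies on the __check_params_colision check mentioned above):

    from sklearn.linear_model import LogisticRegression

    class PrefixRoutingQuantifier:
        """Hypothetical quantifier that routes 'estimator__*' params to its learner."""
        def __init__(self, estimator, sample_size=100):
            self.estimator = estimator
            self.sample_size = sample_size

        def set_params(self, **params):
            # split on the prefix: 'estimator__C' -> learner's C; 'sample_size' stays here
            nested = {k.split('__', 1)[1]: v for k, v in params.items()
                      if k.startswith('estimator__')}
            own = {k: v for k, v in params.items() if not k.startswith('estimator__')}
            self.estimator.set_params(**nested)
            for k, v in own.items():
                setattr(self, k, v)
            return self

    q = PrefixRoutingQuantifier(LogisticRegression())
    q.set_params(sample_size=500, estimator__C=10)

With an explicit prefix there is no risk of a learner parameter colliding with a quantifier parameter of the same name.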
--- a/quapy/classification/neural.py
+++ b/quapy/classification/neural.py
@@ -18,7 +18,7 @@ from quapy.util import EarlyStop
 class NeuralClassifierTrainer:

     def __init__(self,
-                 net, # TextClassifierNet
+                 net: 'TextClassifierNet',
                  lr=1e-3,
                  weight_decay=0,
                  patience=10,
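This turns the informal comment into a real type hint. Quoting the annotation ('TextClassifierNet' rather than TextClassifierNet) makes it a forward reference: the name is stored as a string and need not be importable or defined at the point where __init__ is parsed. A minimal illustration:

    class Trainer:
        # 'Net' is only defined further down the module; the string defers the lookup
        def __init__(self, net: 'Net'):
            self.net = net

    class Net:
        pass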
--- a/quapy/method/aggregative.py
+++ b/quapy/method/aggregative.py
@@ -138,12 +138,16 @@ def training_helper(learner,
         if isinstance(learner, BaseQuantifier):
             learner.fit(train)
         else:
-            learner.fit(train.instances, train.labels)
+            learner.fit(*train.Xy)
     else:
         if ensure_probabilistic:
             if not hasattr(learner, 'predict_proba'):
                 raise AssertionError('error: the learner cannot be calibrated since fit_learner is set to False')
-        unused = data
+        unused = None
+        if val_split.__class__.__name__ == LabelledCollection.__name__:
+            unused = val_split
+            if data is not None:
+                unused = unused + data

     return learner, unused
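The learner.fit(*train.Xy) form unpacks an (instances, labels) tuple, and the new tail of training_helper hands back whatever labelled data was not consumed when fit_learner=False. A minimal sketch of the Xy pattern (MiniCollection is an illustrative stand-in, not QuaPy's LabelledCollection):

    import numpy as np

    class MiniCollection:
        def __init__(self, instances, labels):
            self.instances, self.labels = instances, labels

        @property
        def Xy(self):
            # expose (X, y) so callers can write learner.fit(*data.Xy)
            return self.instances, self.labels

        def __add__(self, other):
            # mimics `unused = unused + data` above: concatenate two collections
            return MiniCollection(np.vstack([self.instances, other.instances]),
                                  np.concatenate([self.labels, other.labels]))

    train = MiniCollection(np.random.rand(10, 3), np.repeat([0, 1], 5))
    X, y = train.Xy   # equivalent to (train.instances, train.labels)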
@@ -193,6 +197,8 @@ class ACC(AggregativeQuantifier):
         if val_split is None:
             val_split = self.val_split
         if isinstance(val_split, int):
+            assert fit_learner == True, \
+                'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
             # kFCV estimation of parameters
             y, y_ = [], []
             kfcv = StratifiedKFold(n_splits=val_split)
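The new assertion guards the k-fold cross-validation (kFCV) branch: estimating the adjustment parameters requires refitting the learner on every training fold, which is impossible when fit_learner=False. The same guard is added to PACC and ThresholdOptimization in the hunks below. A compact sketch of the kind of out-of-fold estimation this protects (clf, X, y are placeholders; the real code also tracks class order and, for PACC, posterior probabilities):

    from sklearn.base import clone
    from sklearn.model_selection import StratifiedKFold

    def out_of_fold_predictions(clf, X, y, n_splits=5):
        """Collect out-of-fold true labels and predictions, as ACC/PACC need."""
        y_true, y_pred = [], []
        for tr, te in StratifiedKFold(n_splits=n_splits).split(X, y):
            fold_clf = clone(clf).fit(X[tr], y[tr])  # refit per fold: needs fit_learner=True
            y_true.extend(y[te])
            y_pred.extend(fold_clf.predict(X[te]))
        return y_true, y_pred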
@@ -280,6 +286,8 @@ class PACC(AggregativeProbabilisticQuantifier):
             val_split = self.val_split

         if isinstance(val_split, int):
+            assert fit_learner == True, \
+                'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
             # kFCV estimation of parameters
             y, y_ = [], []
             kfcv = StratifiedKFold(n_splits=val_split)
@@ -529,6 +537,8 @@ class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier):
         if val_split is None:
             val_split = self.val_split
         if isinstance(val_split, int):
+            assert fit_learner == True, \
+                'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
             # kFCV estimation of parameters
             y, probabilities = [], []
             kfcv = StratifiedKFold(n_splits=val_split)
--- a/quapy/method/neural.py
+++ b/quapy/method/neural.py
@@ -69,19 +69,19 @@ class QuaNetTrainer(BaseQuantifier):
         :return: self
         """
         self._classes_ = data.classes_
-        classifier_data, unused_data = data.split_stratified(0.4)
-        train_data, valid_data = unused_data.split_stratified(0.66)  # 0.66 split of 60% makes 40% and 20%
-
-        print('Classifier data: ', len(classifier_data))
-        print('Q-Training data: ', len(train_data))
-        print('Q-Valid data: ', len(valid_data))
         os.makedirs(self.checkpointdir, exist_ok=True)

+        if fit_learner:
+            classifier_data, unused_data = data.split_stratified(0.4)
+            train_data, valid_data = unused_data.split_stratified(0.66)  # 0.66 split of 60% makes 40% and 20%
+            self.learner.fit(*classifier_data.Xy)
+        else:
+            classifier_data = None
+            train_data, valid_data = data.split_stratified(0.66)
+
         # estimate the hard and soft stats tpr and fpr of the classifier
         self.tr_prev = data.prevalence()

-        self.learner.fit(*classifier_data.Xy)
-
         # compute the posterior probabilities of the instances
         valid_posteriors = self.learner.predict_proba(valid_data.instances)
         train_posteriors = self.learner.predict_proba(train_data.instances)
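This hunk is the fix the commit title refers to: when fit_learner=False the learner is assumed to be already trained, so no 40% split is carved off for classifier training, the unconditional fit call is dropped, and all of the data goes to QuaNet's own training/validation split. A hedged usage sketch (clf, data and test_instances are placeholders, and the constructor arguments shown are illustrative; consult QuaPy's docs for exact signatures):

    # `clf` is a probabilistic classifier already fitted elsewhere
    quanet = QuaNetTrainer(clf, sample_size=100)
    quanet.fit(data, fit_learner=False)     # uses all of `data` for QuaNet itself
    estim_prevalence = quanet.quantify(test_instances)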
@@ -132,7 +132,6 @@ class QuaNetTrainer(BaseQuantifier):
                 print(f'training ended by patience exhausted; loading best model parameters in {checkpoint} '
                       f'for epoch {early_stop.best_epoch}')
                 self.quanet.load_state_dict(torch.load(checkpoint))
-                #self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=True)
                 break

         return self
@@ -144,9 +143,7 @@ class QuaNetTrainer(BaseQuantifier):
             predictions = posteriors if quantifier.probabilistic else label_predictions
             prevs_estim.extend(quantifier.aggregate(predictions))

-        # add the class-conditional predictions P(y'i|yj) from ACC and PACC
-        # prevs_estim.extend(self.quantifiers['acc'].Pte_cond_estim_.flatten())
-        # prevs_estim.extend(self.quantifiers['pacc'].Pte_cond_estim_.flatten())
+        # there is no real need for adding static estims like the TPR or FPR from training since those are constant

         return prevs_estim
@@ -164,8 +161,6 @@ class QuaNetTrainer(BaseQuantifier):

     def epoch(self, data: LabelledCollection, posteriors, iterations, epoch, early_stop, train):
         mse_loss = MSELoss()
-        # prevpoints = F.get_nprevpoints_approximation(iterations, self.quanet.n_classes)
-        # iterations = F.num_prevalence_combinations(prevpoints, self.quanet.n_classes)

         self.quanet.train(mode=train)
         losses = []
@@ -176,11 +171,10 @@ class QuaNetTrainer(BaseQuantifier):
             with qp.util.temp_seed(0):
                 sampling_index_gen = data.artificial_sampling_index_generator(self.sample_size, prevpoints)
         else:
-            # sampling_index_gen = data.artificial_sampling_index_generator(self.sample_size, prevpoints)
-            sampling_index_gen = [data.sampling_index(self.sample_size, *prev) for prev in F.uniform_simplex_sampling(data.n_classes, iterations)]
+            sampling_index_gen = [data.sampling_index(self.sample_size, *prev) for prev in
+                                  F.uniform_simplex_sampling(data.n_classes, iterations)]
         pbar = tqdm(sampling_index_gen, total=iterations) if train else sampling_index_gen

-        rand_it_show = np.random.randint(iterations)
         for it, index in enumerate(pbar):
             sample_data = data.sampling_from_index(index)
             sample_posteriors = posteriors[index]
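F.uniform_simplex_sampling draws prevalence vectors uniformly at random from the probability simplex, so the generated samples cover the space of class distributions rather than a fixed prevalence grid. A self-contained sketch of the standard uniform-spacings construction (illustrative; not necessarily QuaPy's exact implementation):

    import numpy as np

    def uniform_simplex_sampling(n_classes, size):
        """Draw `size` points uniformly from the (n_classes-1)-dimensional simplex."""
        # the gaps between sorted uniform variates (padded with 0 and 1) sum to 1
        u = np.sort(np.random.rand(size, n_classes - 1), axis=-1)
        bounded = np.hstack([np.zeros((size, 1)), u, np.ones((size, 1))])
        return np.diff(bounded, axis=-1)

    prevs = uniform_simplex_sampling(n_classes=3, size=5)
    assert np.allclose(prevs.sum(axis=1), 1.0)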
@@ -218,27 +212,14 @@ class QuaNetTrainer(BaseQuantifier):
                     f'val-mseloss={self.status["va-loss"]:.5f} val-maeloss={self.status["va-mae"]:.5f} '
                     f'patience={early_stop.patience}/{early_stop.PATIENCE_LIMIT}')

-                # if it==rand_it_show:
-                #     print()
-                #     print('='*100)
-                #     print('Training: ' if train else 'Validation:')
-                #     print('=' * 100)
-                #     print('True: ', ptrue.cpu().numpy().flatten())
-                #     print('Estim: ', phat.detach().cpu().numpy().flatten())
-                #     for pred, name in zip(np.asarray(quant_estims).reshape(-1,data.n_classes),
-                #                           ['cc', 'acc', 'pcc', 'pacc', 'emq', 'Pte[acc]','','','Pte[pacc]','','']):
-                #         print(name, pred)
-
-
     def get_params(self, deep=True):
         return {**self.learner.get_params(), **self.quanet_params}

     def set_params(self, **parameters):
-        learner_params={}
+        learner_params = {}
         for key, val in parameters.items():
             if key in self.quanet_params:
-                self.quanet_params[key]=val
+                self.quanet_params[key] = val
             else:
                 learner_params[key] = val
         self.learner.set_params(**learner_params)
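set_params splits the incoming parameters between QuaNet's own hyperparameters and the wrapped learner's, purely by name membership in self.quanet_params; this name-based routing is what the TODO entry about the "estimator__" prefix proposes to make explicit. The routing logic in isolation (parameter names are illustrative):

    quanet_params = {'lr': 1e-3, 'patience': 10}   # QuaNet's own hyperparameters
    incoming = {'lr': 1e-4, 'C': 10}               # e.g. from a grid search

    for_quanet = {k: v for k, v in incoming.items() if k in quanet_params}
    for_learner = {k: v for k, v in incoming.items() if k not in quanet_params}
    assert for_quanet == {'lr': 1e-4} and for_learner == {'C': 10}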