forked from moreo/QuaPy
fixing issue regarding fit_learner=False in QuaNetTrainer
This commit is contained in:
parent f33abb5319
commit b4aeaa97b7
@@ -130,11 +130,15 @@ class LabelledCollection:
             yield self.uniform_sampling_index(sample_size)
 
     def __add__(self, other):
-        if issparse(self.instances) and issparse(other.instances):
+        if other is None:
+            return self
+        elif issparse(self.instances) and issparse(other.instances):
             join_instances = vstack([self.instances, other.instances])
         elif isinstance(self.instances, list) and isinstance(other.instances, list):
             join_instances = self.instances + other.instances
         elif isinstance(self.instances, np.ndarray) and isinstance(other.instances, np.ndarray):
+            print(self.instances.shape)
+            print(other.instances.shape)
             join_instances = np.concatenate([self.instances, other.instances])
         else:
             raise NotImplementedError('unsupported operation for collection types')
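
Note: for context, a minimal self-contained sketch of the None-tolerant __add__ pattern this hunk introduces. Collection is an illustrative stand-in for LabelledCollection, not QuaPy's actual class, and only the dense/sparse cases are shown.

import numpy as np
from scipy.sparse import issparse, vstack

class Collection:
    # illustrative stand-in: instances plus a parallel array of labels
    def __init__(self, instances, labels):
        self.instances = instances
        self.labels = np.asarray(labels)

    def __add__(self, other):
        # returning self when other is None lets callers write `a + b`
        # without first checking whether `b` exists (the behaviour the diff adds)
        if other is None:
            return self
        if issparse(self.instances) and issparse(other.instances):
            join_instances = vstack([self.instances, other.instances])
        elif isinstance(self.instances, np.ndarray) and isinstance(other.instances, np.ndarray):
            join_instances = np.concatenate([self.instances, other.instances])
        else:
            raise NotImplementedError('unsupported operation for collection types')
        return Collection(join_instances, np.concatenate([self.labels, other.labels]))

a = Collection(np.zeros((3, 2)), [0, 1, 0])
b = Collection(np.ones((2, 2)), [1, 1])
print((a + None) is a)          # True: the None operand is simply ignored
print((a + b).instances.shape)  # (5, 2)
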
@@ -125,7 +125,7 @@ def training_helper(learner,
                 if not (0 < val_split < 1):
                     raise ValueError(f'train/val split {val_split} out of range, must be in (0,1)')
                 train, unused = data.split_stratified(train_prop=1 - val_split)
-            elif val_split.__class__.__name__ == LabelledCollection.__name__:  # isinstance(val_split, LabelledCollection):
+            elif isinstance(val_split, LabelledCollection):
                 train = data
                 unused = val_split
             else:
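
Note: a small illustration (not taken from QuaPy) of why the plain isinstance test is the safer check here: comparing __class__.__name__ only matches the exact class name, so it misses subclasses and can be fooled by two unrelated classes that happen to share a name.

class Base:
    pass

class Derived(Base):
    pass

obj = Derived()
print(obj.__class__.__name__ == Base.__name__)  # False: name comparison misses subclasses
print(isinstance(obj, Base))                    # True: isinstance follows the class hierarchy
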
@@ -144,10 +144,8 @@ def training_helper(learner,
             if not hasattr(learner, 'predict_proba'):
                 raise AssertionError('error: the learner cannot be calibrated since fit_learner is set to False')
         unused = None
-        if val_split.__class__.__name__ == LabelledCollection.__name__:
+        if isinstance(val_split, LabelledCollection):
             unused = val_split
-            if data is not None:
-                unused = unused+data
 
     return learner, unused
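
Note: a hedged usage sketch of the fit_learner=False path after this change. The import paths (quapy.data.LabelledCollection, quapy.method.aggregative.training_helper) and the synthetic data are assumptions made for illustration, not taken from the diff.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from quapy.data import LabelledCollection                # assumed import path
from quapy.method.aggregative import training_helper     # assumed import path

X, y = make_classification(n_samples=200, random_state=0)
learner = LogisticRegression().fit(X[:100], y[:100])     # pre-trained elsewhere
valid_data = LabelledCollection(X[100:], y[100:])

# with fit_learner=False nothing is trained: the helper only checks that the learner
# is probabilistic (predict_proba) when required and hands back the held-out collection
learner, unused = training_helper(learner, data=None, fit_learner=False,
                                  ensure_probabilistic=True, val_split=valid_data)
assert unused is valid_data
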
@@ -307,19 +305,22 @@ class PACC(AggregativeProbabilisticQuantifier):
             # fit the learner on all data
             self.learner, _ = training_helper(self.learner, data, fit_learner, ensure_probabilistic=True,
                                               val_split=None)
+            classes = data.classes_
 
         else:
             self.learner, val_data = training_helper(
                 self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
             y_ = self.learner.predict_proba(val_data.instances)
             y = val_data.labels
+            classes = val_data.classes_
 
         self.pcc = PCC(self.learner)
 
         # estimate the matrix with entry (i,j) being the estimate of P(yi|yj), that is, the probability that a
         # document that belongs to yj ends up being classified as belonging to yi
-        confusion = np.empty(shape=(data.n_classes, data.n_classes))
-        for i,class_ in enumerate(data.classes_):
+        n_classes = len(classes)
+        confusion = np.empty(shape=(n_classes, n_classes))
+        for i, class_ in enumerate(classes):
             confusion[i] = y_[y == class_].mean(axis=0)
 
         self.Pte_cond_estim_ = confusion.T
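
Note: for orientation, a minimal numpy sketch (illustrative names, not QuaPy's API) of what the matrix built above feeds into: PACC corrects the plain PCC estimate by solving a linear system whose columns are the class-conditional mean posteriors, i.e. the transpose of the `confusion` matrix in the diff.

import numpy as np

def pacc_adjust(posteriors_val, y_val, posteriors_test, classes):
    # M[:, j] = mean posterior vector over validation documents whose true class is classes[j]
    n = len(classes)
    M = np.empty((n, n))
    for j, c in enumerate(classes):
        M[:, j] = posteriors_val[y_val == c].mean(axis=0)
    p_pcc = posteriors_test.mean(axis=0)              # uncorrected PCC (expected-counts) estimate
    p, *_ = np.linalg.lstsq(M, p_pcc, rcond=None)     # solve M p = p_pcc for the corrected prevalence
    p = np.clip(p, 0, None)
    return p / p.sum()                                # renormalise to a proper distribution
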
@@ -91,12 +91,14 @@ class QuaNetTrainer(BaseQuantifier):
         train_data.instances = self.learner.transform(train_data.instances)
 
         self.quantifiers = {
-            'cc': CC(self.learner).fit(classifier_data, fit_learner=False),
-            'acc': ACC(self.learner).fit(classifier_data, fit_learner=False, val_split=valid_data),
-            'pcc': PCC(self.learner).fit(classifier_data, fit_learner=False),
-            'pacc': PACC(self.learner).fit(classifier_data, fit_learner=False, val_split=valid_data),
-            'emq': EMQ(self.learner).fit(classifier_data, fit_learner=False),
+            'cc': CC(self.learner).fit(None, fit_learner=False),
+            'acc': ACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
+            'pcc': PCC(self.learner).fit(None, fit_learner=False),
+            'pacc': PACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
+            # 'emq': EMQ(self.learner).fit(classifier_data, fit_learner=False),
         }
+        if classifier_data is not None:
+            self.quantifiers['emq'] = EMQ(self.learner).fit(classifier_data, fit_learner=False)
 
         self.status = {
             'tr-loss': -1,
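
Note: a hedged sketch of the intent behind this hunk. With an already-trained learner, CC/ACC/PCC/PACC need no training data (fit_learner=False, data=None), while EMQ still needs labelled data to estimate the training prevalence, so it is only built when classifier_data is available. Identifiers mirror the diff; the wrapper function and the import path are assumptions for illustration.

from quapy.method.aggregative import CC, ACC, PCC, PACC, EMQ   # assumed import path

def build_quantifiers(learner, classifier_data, valid_data):
    quantifiers = {
        'cc': CC(learner).fit(None, fit_learner=False),
        'acc': ACC(learner).fit(None, fit_learner=False, val_split=valid_data),
        'pcc': PCC(learner).fit(None, fit_learner=False),
        'pacc': PACC(learner).fit(None, fit_learner=False, val_split=valid_data),
    }
    if classifier_data is not None:
        # EMQ estimates the training prior, which requires the labelled collection
        quantifiers['emq'] = EMQ(learner).fit(classifier_data, fit_learner=False)
    return quantifiers
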