forked from moreo/QuaPy
bugfix in NeuralClassifierTrainer; it was only configured to work well in binary problems
This commit is contained in:
parent
8e14bbc527
commit
e40c409609
|
@ -42,7 +42,7 @@ class NeuralClassifierTrainer:
|
||||||
batch_size=64,
|
batch_size=64,
|
||||||
batch_size_test=512,
|
batch_size_test=512,
|
||||||
padding_length=300,
|
padding_length=300,
|
||||||
device='cpu',
|
device='cuda',
|
||||||
checkpointpath='../checkpoint/classifier_net.dat'):
|
checkpointpath='../checkpoint/classifier_net.dat'):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -62,7 +62,6 @@ class NeuralClassifierTrainer:
|
||||||
}
|
}
|
||||||
self.learner_hyperparams = self.net.get_params()
|
self.learner_hyperparams = self.net.get_params()
|
||||||
self.checkpointpath = checkpointpath
|
self.checkpointpath = checkpointpath
|
||||||
self.classes_ = np.asarray([0, 1])
|
|
||||||
|
|
||||||
print(f'[NeuralNetwork running on {device}]')
|
print(f'[NeuralNetwork running on {device}]')
|
||||||
os.makedirs(Path(checkpointpath).parent, exist_ok=True)
|
os.makedirs(Path(checkpointpath).parent, exist_ok=True)
|
||||||
|
@ -174,6 +173,7 @@ class NeuralClassifierTrainer:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
train, val = LabelledCollection(instances, labels).split_stratified(1-val_split)
|
train, val = LabelledCollection(instances, labels).split_stratified(1-val_split)
|
||||||
|
self.classes_ = train.classes_
|
||||||
opt = self.trainer_hyperparams
|
opt = self.trainer_hyperparams
|
||||||
checkpoint = self.checkpointpath
|
checkpoint = self.checkpointpath
|
||||||
self.reset_net_params(self.vocab_size, train.n_classes)
|
self.reset_net_params(self.vocab_size, train.n_classes)
|
||||||
|
|
|
@ -184,7 +184,7 @@ class IndexTransformer:
|
||||||
|
|
||||||
def _index(self, documents):
|
def _index(self, documents):
|
||||||
vocab = self.vocabulary_.copy()
|
vocab = self.vocabulary_.copy()
|
||||||
return [[vocab.prevalence(word, self.unk) for word in self.analyzer(doc)] for doc in tqdm(documents, 'indexing')]
|
return [[vocab.get(word, self.unk) for word in self.analyzer(doc)] for doc in tqdm(documents, 'indexing')]
|
||||||
|
|
||||||
def fit_transform(self, X, n_jobs=-1):
|
def fit_transform(self, X, n_jobs=-1):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -282,6 +282,7 @@ class ACC(AggregativeQuantifier):
|
||||||
"""
|
"""
|
||||||
if val_split is None:
|
if val_split is None:
|
||||||
val_split = self.val_split
|
val_split = self.val_split
|
||||||
|
classes = data.classes_
|
||||||
if isinstance(val_split, int):
|
if isinstance(val_split, int):
|
||||||
assert fit_learner == True, \
|
assert fit_learner == True, \
|
||||||
'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
|
'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
|
||||||
|
@ -300,6 +301,7 @@ class ACC(AggregativeQuantifier):
|
||||||
y = np.concatenate(y)
|
y = np.concatenate(y)
|
||||||
y_ = np.concatenate(y_)
|
y_ = np.concatenate(y_)
|
||||||
class_count = data.counts()
|
class_count = data.counts()
|
||||||
|
classes = data.classes_
|
||||||
|
|
||||||
# fit the learner on all data
|
# fit the learner on all data
|
||||||
self.learner, _ = _training_helper(self.learner, data, fit_learner, val_split=None)
|
self.learner, _ = _training_helper(self.learner, data, fit_learner, val_split=None)
|
||||||
|
@ -308,10 +310,11 @@ class ACC(AggregativeQuantifier):
|
||||||
self.learner, val_data = _training_helper(self.learner, data, fit_learner, val_split=val_split)
|
self.learner, val_data = _training_helper(self.learner, data, fit_learner, val_split=val_split)
|
||||||
y_ = self.learner.predict(val_data.instances)
|
y_ = self.learner.predict(val_data.instances)
|
||||||
y = val_data.labels
|
y = val_data.labels
|
||||||
|
classes = val_data.classes_
|
||||||
|
|
||||||
self.cc = CC(self.learner)
|
self.cc = CC(self.learner)
|
||||||
|
|
||||||
self.Pte_cond_estim_ = self.getPteCondEstim(data.classes_, y, y_)
|
self.Pte_cond_estim_ = self.getPteCondEstim(classes, y, y_)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -82,6 +82,7 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
assert hasattr(learner, 'predict_proba'), \
|
assert hasattr(learner, 'predict_proba'), \
|
||||||
f'the learner {learner.__class__.__name__} does not seem to be able to produce posterior probabilities ' \
|
f'the learner {learner.__class__.__name__} does not seem to be able to produce posterior probabilities ' \
|
||||||
f'since it does not implement the method "predict_proba"'
|
f'since it does not implement the method "predict_proba"'
|
||||||
|
assert sample_size is not None, 'sample_size cannot be None'
|
||||||
self.learner = learner
|
self.learner = learner
|
||||||
self.sample_size = sample_size
|
self.sample_size = sample_size
|
||||||
self.n_epochs = n_epochs
|
self.n_epochs = n_epochs
|
||||||
|
|
Loading…
Reference in New Issue