1
0
Fork 0

Bugfix in NeuralClassifierTrainer: it previously worked correctly only on binary problems, because classes_ was hard-coded to [0, 1]

This commit is contained in:
Alejandro Moreo Fernandez 2022-10-04 11:03:08 +02:00
parent 8e14bbc527
commit e40c409609
4 changed files with 8 additions and 4 deletions

View File

@@ -42,7 +42,7 @@ class NeuralClassifierTrainer:
batch_size=64, batch_size=64,
batch_size_test=512, batch_size_test=512,
padding_length=300, padding_length=300,
device='cpu', device='cuda',
checkpointpath='../checkpoint/classifier_net.dat'): checkpointpath='../checkpoint/classifier_net.dat'):
super().__init__() super().__init__()
@@ -62,7 +62,6 @@ class NeuralClassifierTrainer:
} }
self.learner_hyperparams = self.net.get_params() self.learner_hyperparams = self.net.get_params()
self.checkpointpath = checkpointpath self.checkpointpath = checkpointpath
self.classes_ = np.asarray([0, 1])
print(f'[NeuralNetwork running on {device}]') print(f'[NeuralNetwork running on {device}]')
os.makedirs(Path(checkpointpath).parent, exist_ok=True) os.makedirs(Path(checkpointpath).parent, exist_ok=True)
@@ -174,6 +173,7 @@ class NeuralClassifierTrainer:
:return: :return:
""" """
train, val = LabelledCollection(instances, labels).split_stratified(1-val_split) train, val = LabelledCollection(instances, labels).split_stratified(1-val_split)
self.classes_ = train.classes_
opt = self.trainer_hyperparams opt = self.trainer_hyperparams
checkpoint = self.checkpointpath checkpoint = self.checkpointpath
self.reset_net_params(self.vocab_size, train.n_classes) self.reset_net_params(self.vocab_size, train.n_classes)

View File

@@ -184,7 +184,7 @@ class IndexTransformer:
def _index(self, documents): def _index(self, documents):
vocab = self.vocabulary_.copy() vocab = self.vocabulary_.copy()
return [[vocab.prevalence(word, self.unk) for word in self.analyzer(doc)] for doc in tqdm(documents, 'indexing')] return [[vocab.get(word, self.unk) for word in self.analyzer(doc)] for doc in tqdm(documents, 'indexing')]
def fit_transform(self, X, n_jobs=-1): def fit_transform(self, X, n_jobs=-1):
""" """

View File

@@ -282,6 +282,7 @@ class ACC(AggregativeQuantifier):
""" """
if val_split is None: if val_split is None:
val_split = self.val_split val_split = self.val_split
classes = data.classes_
if isinstance(val_split, int): if isinstance(val_split, int):
assert fit_learner == True, \ assert fit_learner == True, \
'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False' 'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
@@ -300,6 +301,7 @@ class ACC(AggregativeQuantifier):
y = np.concatenate(y) y = np.concatenate(y)
y_ = np.concatenate(y_) y_ = np.concatenate(y_)
class_count = data.counts() class_count = data.counts()
classes = data.classes_
# fit the learner on all data # fit the learner on all data
self.learner, _ = _training_helper(self.learner, data, fit_learner, val_split=None) self.learner, _ = _training_helper(self.learner, data, fit_learner, val_split=None)
@@ -308,10 +310,11 @@ class ACC(AggregativeQuantifier):
self.learner, val_data = _training_helper(self.learner, data, fit_learner, val_split=val_split) self.learner, val_data = _training_helper(self.learner, data, fit_learner, val_split=val_split)
y_ = self.learner.predict(val_data.instances) y_ = self.learner.predict(val_data.instances)
y = val_data.labels y = val_data.labels
classes = val_data.classes_
self.cc = CC(self.learner) self.cc = CC(self.learner)
self.Pte_cond_estim_ = self.getPteCondEstim(data.classes_, y, y_) self.Pte_cond_estim_ = self.getPteCondEstim(classes, y, y_)
return self return self

View File

@@ -82,6 +82,7 @@ class QuaNetTrainer(BaseQuantifier):
assert hasattr(learner, 'predict_proba'), \ assert hasattr(learner, 'predict_proba'), \
f'the learner {learner.__class__.__name__} does not seem to be able to produce posterior probabilities ' \ f'the learner {learner.__class__.__name__} does not seem to be able to produce posterior probabilities ' \
f'since it does not implement the method "predict_proba"' f'since it does not implement the method "predict_proba"'
assert sample_size is not None, 'sample_size cannot be None'
self.learner = learner self.learner = learner
self.sample_size = sample_size self.sample_size = sample_size
self.n_epochs = n_epochs self.n_epochs = n_epochs