forked from moreo/QuaPy
bugfix in quanet
This commit is contained in:
parent
986e61620c
commit
f9b80ae437
|
@ -87,15 +87,14 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
train_posteriors = self.learner.predict_proba(train_data.instances)
|
||||
|
||||
# turn instances' original representations into embeddings
|
||||
valid_data.instances = self.learner.transform(valid_data.instances)
|
||||
train_data.instances = self.learner.transform(train_data.instances)
|
||||
valid_data_embed = LabelledCollection(self.learner.transform(valid_data.instances), valid_data.labels, self._classes_)
|
||||
train_data_embed = LabelledCollection(self.learner.transform(train_data.instances), train_data.labels, self._classes_)
|
||||
|
||||
self.quantifiers = {
|
||||
'cc': CC(self.learner).fit(None, fit_learner=False),
|
||||
'acc': ACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
|
||||
'pcc': PCC(self.learner).fit(None, fit_learner=False),
|
||||
'pacc': PACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
|
||||
# 'emq': EMQ(self.learner).fit(classifier_data, fit_learner=False),
|
||||
}
|
||||
if classifier_data is not None:
|
||||
self.quantifiers['emq'] = EMQ(self.learner).fit(classifier_data, fit_learner=False)
|
||||
|
@ -110,9 +109,9 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
nQ = len(self.quantifiers)
|
||||
nC = data.n_classes
|
||||
self.quanet = QuaNetModule(
|
||||
doc_embedding_size=train_data.instances.shape[1],
|
||||
doc_embedding_size=train_data_embed.instances.shape[1],
|
||||
n_classes=data.n_classes,
|
||||
stats_size=nQ*nC, #+ 2*nC*nC,
|
||||
stats_size=nQ*nC,
|
||||
order_by=0 if data.binary else None,
|
||||
**self.quanet_params
|
||||
).to(self.device)
|
||||
|
@ -124,8 +123,8 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
checkpoint = self.checkpoint
|
||||
|
||||
for epoch_i in range(1, self.n_epochs):
|
||||
self.epoch(train_data, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
|
||||
self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False)
|
||||
self.epoch(train_data_embed, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
|
||||
self.epoch(valid_data_embed, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False)
|
||||
|
||||
early_stop(self.status['va-loss'], epoch_i)
|
||||
if early_stop.IMPROVED:
|
||||
|
|
Loading…
Reference in New Issue