bugfix in quanet

Alejandro Moreo Fernandez 2021-07-02 10:20:42 +02:00
parent 986e61620c
commit f9b80ae437
1 changed file with 6 additions and 7 deletions

@@ -87,15 +87,14 @@ class QuaNetTrainer(BaseQuantifier):
         train_posteriors = self.learner.predict_proba(train_data.instances)
         # turn instances' original representations into embeddings
-        valid_data.instances = self.learner.transform(valid_data.instances)
-        train_data.instances = self.learner.transform(train_data.instances)
+        valid_data_embed = LabelledCollection(self.learner.transform(valid_data.instances), valid_data.labels, self._classes_)
+        train_data_embed = LabelledCollection(self.learner.transform(train_data.instances), train_data.labels, self._classes_)
         self.quantifiers = {
             'cc': CC(self.learner).fit(None, fit_learner=False),
             'acc': ACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
             'pcc': PCC(self.learner).fit(None, fit_learner=False),
             'pacc': PACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
             # 'emq': EMQ(self.learner).fit(classifier_data, fit_learner=False),
         }
         if classifier_data is not None:
             self.quantifiers['emq'] = EMQ(self.learner).fit(classifier_data, fit_learner=False)
@@ -110,9 +109,9 @@ class QuaNetTrainer(BaseQuantifier):
         nQ = len(self.quantifiers)
         nC = data.n_classes
         self.quanet = QuaNetModule(
-            doc_embedding_size=train_data.instances.shape[1],
+            doc_embedding_size=train_data_embed.instances.shape[1],
             n_classes=data.n_classes,
-            stats_size=nQ*nC, #+ 2*nC*nC,
+            stats_size=nQ*nC,
             order_by=0 if data.binary else None,
             **self.quanet_params
         ).to(self.device)
@@ -124,8 +123,8 @@ class QuaNetTrainer(BaseQuantifier):
         checkpoint = self.checkpoint
         for epoch_i in range(1, self.n_epochs):
-            self.epoch(train_data, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
-            self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False)
+            self.epoch(train_data_embed, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
+            self.epoch(valid_data_embed, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False)
             early_stop(self.status['va-loss'], epoch_i)
             if early_stop.IMPROVED:
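
For reference, a minimal sketch of the pattern this fix adopts: building a separate embedded LabelledCollection (train_data_embed / valid_data_embed) instead of overwriting the original collection's instances in place, so the raw data stays usable by the quantifiers fitted with val_split=valid_data. The helper name embed_collection and the quapy.data import path are assumptions for illustration; only the LabelledCollection(instances, labels, classes_) call mirrors the diff above.

from quapy.data import LabelledCollection  # assumed import path

def embed_collection(collection, transform, classes_):
    # Leave `collection` untouched (its raw instances may still be needed elsewhere)
    # and return a new LabelledCollection holding the transformed embeddings,
    # mirroring the train_data_embed / valid_data_embed lines added in the diff.
    return LabelledCollection(transform(collection.instances), collection.labels, classes_)

# Hypothetical usage mirroring the fix:
# train_data_embed = embed_collection(train_data, self.learner.transform, self._classes_)
# valid_data_embed = embed_collection(valid_data, self.learner.transform, self._classes_)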