diff --git a/quapy/method/neural.py b/quapy/method/neural.py index 4decc74..bb59f97 100644 --- a/quapy/method/neural.py +++ b/quapy/method/neural.py @@ -87,15 +87,14 @@ class QuaNetTrainer(BaseQuantifier): train_posteriors = self.learner.predict_proba(train_data.instances) # turn instances' original representations into embeddings - valid_data.instances = self.learner.transform(valid_data.instances) - train_data.instances = self.learner.transform(train_data.instances) + valid_data_embed = LabelledCollection(self.learner.transform(valid_data.instances), valid_data.labels, self._classes_) + train_data_embed = LabelledCollection(self.learner.transform(train_data.instances), train_data.labels, self._classes_) self.quantifiers = { 'cc': CC(self.learner).fit(None, fit_learner=False), 'acc': ACC(self.learner).fit(None, fit_learner=False, val_split=valid_data), 'pcc': PCC(self.learner).fit(None, fit_learner=False), 'pacc': PACC(self.learner).fit(None, fit_learner=False, val_split=valid_data), - # 'emq': EMQ(self.learner).fit(classifier_data, fit_learner=False), } if classifier_data is not None: self.quantifiers['emq'] = EMQ(self.learner).fit(classifier_data, fit_learner=False) @@ -110,9 +109,9 @@ class QuaNetTrainer(BaseQuantifier): nQ = len(self.quantifiers) nC = data.n_classes self.quanet = QuaNetModule( - doc_embedding_size=train_data.instances.shape[1], + doc_embedding_size=train_data_embed.instances.shape[1], n_classes=data.n_classes, - stats_size=nQ*nC, #+ 2*nC*nC, + stats_size=nQ*nC, order_by=0 if data.binary else None, **self.quanet_params ).to(self.device) @@ -124,8 +123,8 @@ class QuaNetTrainer(BaseQuantifier): checkpoint = self.checkpoint for epoch_i in range(1, self.n_epochs): - self.epoch(train_data, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True) - self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False) + self.epoch(train_data_embed, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True) + self.epoch(valid_data_embed, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False) early_stop(self.status['va-loss'], epoch_i) if early_stop.IMPROVED: