bugfix in quanet
This commit is contained in:
parent
986e61620c
commit
f9b80ae437
|
@ -87,15 +87,14 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
train_posteriors = self.learner.predict_proba(train_data.instances)
|
train_posteriors = self.learner.predict_proba(train_data.instances)
|
||||||
|
|
||||||
# turn instances' original representations into embeddings
|
# turn instances' original representations into embeddings
|
||||||
valid_data.instances = self.learner.transform(valid_data.instances)
|
valid_data_embed = LabelledCollection(self.learner.transform(valid_data.instances), valid_data.labels, self._classes_)
|
||||||
train_data.instances = self.learner.transform(train_data.instances)
|
train_data_embed = LabelledCollection(self.learner.transform(train_data.instances), train_data.labels, self._classes_)
|
||||||
|
|
||||||
self.quantifiers = {
|
self.quantifiers = {
|
||||||
'cc': CC(self.learner).fit(None, fit_learner=False),
|
'cc': CC(self.learner).fit(None, fit_learner=False),
|
||||||
'acc': ACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
|
'acc': ACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
|
||||||
'pcc': PCC(self.learner).fit(None, fit_learner=False),
|
'pcc': PCC(self.learner).fit(None, fit_learner=False),
|
||||||
'pacc': PACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
|
'pacc': PACC(self.learner).fit(None, fit_learner=False, val_split=valid_data),
|
||||||
# 'emq': EMQ(self.learner).fit(classifier_data, fit_learner=False),
|
|
||||||
}
|
}
|
||||||
if classifier_data is not None:
|
if classifier_data is not None:
|
||||||
self.quantifiers['emq'] = EMQ(self.learner).fit(classifier_data, fit_learner=False)
|
self.quantifiers['emq'] = EMQ(self.learner).fit(classifier_data, fit_learner=False)
|
||||||
|
@ -110,9 +109,9 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
nQ = len(self.quantifiers)
|
nQ = len(self.quantifiers)
|
||||||
nC = data.n_classes
|
nC = data.n_classes
|
||||||
self.quanet = QuaNetModule(
|
self.quanet = QuaNetModule(
|
||||||
doc_embedding_size=train_data.instances.shape[1],
|
doc_embedding_size=train_data_embed.instances.shape[1],
|
||||||
n_classes=data.n_classes,
|
n_classes=data.n_classes,
|
||||||
stats_size=nQ*nC, #+ 2*nC*nC,
|
stats_size=nQ*nC,
|
||||||
order_by=0 if data.binary else None,
|
order_by=0 if data.binary else None,
|
||||||
**self.quanet_params
|
**self.quanet_params
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
@ -124,8 +123,8 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
|
|
||||||
for epoch_i in range(1, self.n_epochs):
|
for epoch_i in range(1, self.n_epochs):
|
||||||
self.epoch(train_data, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
|
self.epoch(train_data_embed, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
|
||||||
self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False)
|
self.epoch(valid_data_embed, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False)
|
||||||
|
|
||||||
early_stop(self.status['va-loss'], epoch_i)
|
early_stop(self.status['va-loss'], epoch_i)
|
||||||
if early_stop.IMPROVED:
|
if early_stop.IMPROVED:
|
||||||
|
|
Loading…
Reference in New Issue