2020-12-29 20:33:59 +01:00
|
|
|
import os
|
|
|
|
from pathlib import Path
|
2021-01-18 19:14:04 +01:00
|
|
|
import random
|
2021-01-15 18:32:32 +01:00
|
|
|
|
2020-12-29 20:33:59 +01:00
|
|
|
import torch
|
|
|
|
from torch.nn import MSELoss
|
|
|
|
from torch.nn.functional import relu
|
2021-01-15 18:32:32 +01:00
|
|
|
|
|
|
|
from quapy.method.aggregative import *
|
|
|
|
from quapy.util import EarlyStop
|
2020-12-29 20:33:59 +01:00
|
|
|
|
|
|
|
|
|
|
|
class QuaNetTrainer(BaseQuantifier):
    """
    Implementation of `QuaNet <https://dl.acm.org/doi/abs/10.1145/3269206.3269287>`_, a neural network for
    quantification. This implementation uses `PyTorch <https://pytorch.org/>`_ and can take advantage of GPU
    for speeding-up the training phase.

    Example:

    >>> import quapy as qp
    >>> from quapy.method.meta import QuaNet
    >>> from quapy.classification.neural import NeuralClassifierTrainer, CNNnet
    >>>
    >>> # use samples of 100 elements
    >>> qp.environ['SAMPLE_SIZE'] = 100
    >>>
    >>> # load the kindle dataset as text, and convert words to numerical indexes
    >>> dataset = qp.datasets.fetch_reviews('kindle', pickle=True)
    >>> qp.data.preprocessing.index(dataset, min_df=5, inplace=True)
    >>>
    >>> # the text classifier is a CNN trained by NeuralClassifierTrainer
    >>> cnn = CNNnet(dataset.vocabulary_size, dataset.n_classes)
    >>> classifier = NeuralClassifierTrainer(cnn, device='cuda')
    >>>
    >>> # train QuaNet (QuaNet is an alias to QuaNetTrainer)
    >>> model = QuaNet(classifier, qp.environ['SAMPLE_SIZE'], device='cuda')
    >>> model.fit(dataset.training)
    >>> estim_prevalence = model.quantify(dataset.test.instances)

    :param classifier: an object implementing `fit` (i.e., that can be trained on labelled data),
        `predict_proba` (i.e., that can generate posterior probabilities of unlabelled examples) and
        `transform` (i.e., that can generate embedded representations of the unlabelled instances).
    :param sample_size: integer, the sample size
    :param n_epochs: integer, maximum number of training epochs
    :param tr_iter_per_poch: integer, number of training iterations before considering an epoch complete
    :param va_iter_per_poch: integer, number of validation iterations to perform after each epoch
    :param lr: float, the learning rate
    :param lstm_hidden_size: integer, hidden dimensionality of the LSTM cells
    :param lstm_nlayers: integer, number of LSTM layers
    :param ff_layers: list of integers, dimensions of the densely-connected FF layers on top of the
        quantification embedding
    :param bidirectional: boolean, indicates whether the LSTM is bidirectional or not
    :param qdrop_p: float, dropout probability
    :param patience: integer, number of epochs showing no improvement in the validation set before stopping the
        training phase (early stopping)
    :param checkpointdir: string, a path where to store models' checkpoints
    :param checkpointname: string (optional), the name of the model's checkpoint
    :param device: string, indicate "cpu" or "cuda"
    """

    def __init__(self,
                 classifier,
                 sample_size,
                 n_epochs=100,
                 tr_iter_per_poch=500,
                 va_iter_per_poch=100,
                 lr=1e-3,
                 lstm_hidden_size=64,
                 lstm_nlayers=1,
                 ff_layers=[1024, 512],
                 bidirectional=True,
                 qdrop_p=0.5,
                 patience=10,
                 checkpointdir='../checkpoint',
                 checkpointname=None,
                 device='cuda'):

        assert hasattr(classifier, 'transform'), \
            f'the classifier {classifier.__class__.__name__} does not seem to be able to produce document embeddings ' \
            f'since it does not implement the method "transform"'
        assert hasattr(classifier, 'predict_proba'), \
            f'the classifier {classifier.__class__.__name__} does not seem to be able to produce posterior probabilities ' \
            f'since it does not implement the method "predict_proba"'
        self.classifier = classifier
        self.sample_size = sample_size
        self.n_epochs = n_epochs
        self.tr_iter = tr_iter_per_poch
        self.va_iter = va_iter_per_poch
        self.lr = lr
        # hyperparameters forwarded verbatim to QuaNetModule (also exposed via get_params/set_params)
        self.quanet_params = {
            'lstm_hidden_size': lstm_hidden_size,
            'lstm_nlayers': lstm_nlayers,
            'ff_layers': ff_layers,
            'bidirectional': bidirectional,
            'qdrop_p': qdrop_p
        }

        self.patience = patience
        if checkpointname is None:
            # a private RNG so that naming the checkpoint does not consume the global random state
            local_random = random.Random()
            random_code = '-'.join(str(local_random.randint(0, 1000000)) for _ in range(5))
            checkpointname = 'QuaNet-'+random_code
        self.checkpointdir = checkpointdir
        self.checkpoint = os.path.join(checkpointdir, checkpointname)
        self.device = torch.device(device)

        self.__check_params_colision(self.quanet_params, self.classifier.get_params())
        self._classes_ = None

    def fit(self, data: LabelledCollection, fit_classifier=True):
        """
        Trains QuaNet.

        The network weights with the lowest validation loss are restored at the end of training, whether training
        ends because of early stopping or because `n_epochs` epochs have been completed.

        :param data: the training data on which to train QuaNet. If `fit_classifier=True`, the data will be split in
            40/40/20 for training the classifier, training QuaNet, and validating QuaNet, respectively. If
            `fit_classifier=False`, the data will be split in 66/34 for training QuaNet and validating it, respectively.
        :param fit_classifier: if True, trains the classifier on a split containing 40% of the data
        :return: self
        """
        self._classes_ = data.classes_
        os.makedirs(self.checkpointdir, exist_ok=True)

        if fit_classifier:
            classifier_data, unused_data = data.split_stratified(0.4)
            train_data, valid_data = unused_data.split_stratified(0.66)  # 0.66 split of 60% makes 40% and 20%
            self.classifier.fit(*classifier_data.Xy)
        else:
            classifier_data = None
            train_data, valid_data = data.split_stratified(0.66)

        # prevalence of the whole training collection (kept as an attribute; not used by the training loop below)
        self.tr_prev = data.prevalence()

        # compute the posterior probabilities of the instances
        valid_posteriors = self.classifier.predict_proba(valid_data.instances)
        train_posteriors = self.classifier.predict_proba(train_data.instances)

        # turn instances' original representations into embeddings
        valid_data_embed = LabelledCollection(self.classifier.transform(valid_data.instances), valid_data.labels, self._classes_)
        train_data_embed = LabelledCollection(self.classifier.transform(train_data.instances), train_data.labels, self._classes_)

        # simple aggregative quantifiers whose estimates are fed to QuaNet as additional statistics
        self.quantifiers = {
            'cc': CC(self.classifier).fit(None, fit_classifier=False),
            'acc': ACC(self.classifier).fit(None, fit_classifier=False, val_split=valid_data),
            'pcc': PCC(self.classifier).fit(None, fit_classifier=False),
            'pacc': PACC(self.classifier).fit(None, fit_classifier=False, val_split=valid_data),
        }
        if classifier_data is not None:
            self.quantifiers['emq'] = EMQ(self.classifier).fit(classifier_data, fit_classifier=False)

        self.status = {
            'tr-loss': -1,
            'va-loss': -1,
            'tr-mae': -1,
            'va-mae': -1,
        }

        nQ = len(self.quantifiers)
        nC = data.n_classes
        self.quanet = QuaNetModule(
            doc_embedding_size=train_data_embed.instances.shape[1],
            n_classes=nC,
            stats_size=nQ*nC,  # each quantifier contributes one prevalence value per class
            order_by=0 if data.binary else None,
            **self.quanet_params
        ).to(self.device)
        print(self.quanet)

        self.optim = torch.optim.Adam(self.quanet.parameters(), lr=self.lr)
        early_stop = EarlyStop(self.patience, lower_is_better=True)

        checkpoint = self.checkpoint

        # range upper bound is exclusive: +1 so that exactly n_epochs epochs are run at most
        for epoch_i in range(1, self.n_epochs + 1):
            self._epoch(train_data_embed, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
            self._epoch(valid_data_embed, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False)

            early_stop(self.status['va-loss'], epoch_i)
            if early_stop.IMPROVED:
                torch.save(self.quanet.state_dict(), checkpoint)
            elif early_stop.STOP:
                print(f'training ended by patience exhausted; loading best model parameters in {checkpoint} '
                      f'for epoch {early_stop.best_epoch}')
                self.quanet.load_state_dict(torch.load(checkpoint, map_location=self.device))
                break
        else:
            # all epochs completed without early stopping: the current weights are those of the last epoch,
            # which need not be the best ones; restore the best checkpoint if one was saved
            if os.path.exists(checkpoint):
                self.quanet.load_state_dict(torch.load(checkpoint, map_location=self.device))

        return self

    def _get_aggregative_estims(self, posteriors):
        """
        Computes the prevalence estimates of every auxiliary quantifier for one sample.

        :param posteriors: posterior probabilities of the sample's instances (one row per instance)
        :return: a flat list concatenating each quantifier's class-prevalence estimate
        """
        label_predictions = np.argmax(posteriors, axis=-1)
        prevs_estim = []
        for quantifier in self.quantifiers.values():
            # probabilistic quantifiers aggregate posteriors; the others aggregate crisp predictions
            predictions = posteriors if isinstance(quantifier, AggregativeProbabilisticQuantifier) else label_predictions
            prevs_estim.extend(quantifier.aggregate(predictions))

        # there is no real need for adding static estims like the TPR or FPR from training since those are constant

        return prevs_estim

    def quantify(self, instances):
        """
        Estimates the class prevalence of a sample.

        :param instances: the sample's instances, in the representation accepted by the classifier
        :return: a flat numpy array with one prevalence value per class
        """
        posteriors = self.classifier.predict_proba(instances)
        embeddings = self.classifier.transform(instances)
        quant_estims = self._get_aggregative_estims(posteriors)
        self.quanet.eval()
        with torch.no_grad():
            prevalence = self.quanet.forward(embeddings, posteriors, quant_estims)
            # .cpu() is a no-op for CPU tensors and, unlike comparing against torch.device('cuda'),
            # also covers indexed GPU devices such as 'cuda:1'
            prevalence = prevalence.cpu().numpy().flatten()
        return prevalence

    def _epoch(self, data: LabelledCollection, posteriors, iterations, epoch, early_stop, train):
        """
        Runs one training or validation pass over samples drawn from `data`, and records the mean MSE/MAE in
        `self.status`.

        :param data: collection of embedded instances
        :param posteriors: posterior probabilities aligned with `data`
        :param iterations: number of samples to draw (for validation, it is approximated to a grid of prevalences)
        :param epoch: epoch number (for logging only)
        :param early_stop: the EarlyStop object (for logging only)
        :param train: if True, updates the network weights; otherwise only evaluates
        """
        mse_loss = MSELoss()

        self.quanet.train(mode=train)
        losses = []
        mae_errors = []
        if not train:
            prevpoints = F.get_nprevpoints_approximation(iterations, self.quanet.n_classes)
            iterations = F.num_prevalence_combinations(prevpoints, self.quanet.n_classes)
            with qp.util.temp_seed(0):
                # materialize inside the seeded context: if the generator is lazy (as its name suggests), consuming
                # it after leaving the context would draw different validation samples at every epoch
                sampling_index_gen = list(data.artificial_sampling_index_generator(self.sample_size, prevpoints))
        else:
            sampling_index_gen = [data.sampling_index(self.sample_size, *prev) for prev in
                                  F.uniform_simplex_sampling(data.n_classes, iterations)]
        pbar = tqdm(sampling_index_gen, total=iterations) if train else sampling_index_gen

        for it, index in enumerate(pbar):
            sample_data = data.sampling_from_index(index)
            sample_posteriors = posteriors[index]
            quant_estims = self._get_aggregative_estims(sample_posteriors)
            ptrue = torch.as_tensor([sample_data.prevalence()], dtype=torch.float, device=self.device)
            if train:
                self.optim.zero_grad()
                phat = self.quanet.forward(sample_data.instances, sample_posteriors, quant_estims)
                loss = mse_loss(phat, ptrue)
                mae = mae_loss(phat, ptrue)
                loss.backward()
                self.optim.step()
            else:
                with torch.no_grad():
                    phat = self.quanet.forward(sample_data.instances, sample_posteriors, quant_estims)
                    loss = mse_loss(phat, ptrue)
                    mae = mae_loss(phat, ptrue)

            losses.append(loss.item())
            mae_errors.append(mae.item())

            mse = np.mean(losses)
            mae = np.mean(mae_errors)
            if train:
                self.status['tr-loss'] = mse
                self.status['tr-mae'] = mae
            else:
                self.status['va-loss'] = mse
                self.status['va-mae'] = mae

            if train:
                pbar.set_description(f'[QuaNet] '
                                     f'epoch={epoch} [it={it}/{iterations}]\t'
                                     f'tr-mseloss={self.status["tr-loss"]:.5f} tr-maeloss={self.status["tr-mae"]:.5f}\t'
                                     f'val-mseloss={self.status["va-loss"]:.5f} val-maeloss={self.status["va-mae"]:.5f} '
                                     f'patience={early_stop.patience}/{early_stop.PATIENCE_LIMIT}')

    def get_params(self, deep=True):
        """Returns the classifier's hyperparameters merged with QuaNet's own hyperparameters."""
        return {**self.classifier.get_params(), **self.quanet_params}

    def set_params(self, **parameters):
        """
        Routes each hyperparameter either to QuaNet (if it is one of its own) or to the classifier.
        """
        learner_params = {}
        for key, val in parameters.items():
            if key in self.quanet_params:
                self.quanet_params[key] = val
            else:
                learner_params[key] = val
        self.classifier.set_params(**learner_params)

    def __check_params_colision(self, quanet_params, learner_params):
        """
        Raises ValueError if QuaNet and the classifier share a hyperparameter name, since `set_params` could not
        tell which one is meant.
        """
        quanet_keys = set(quanet_params.keys())
        learner_keys = set(learner_params.keys())
        intersection = quanet_keys.intersection(learner_keys)
        if len(intersection) > 0:
            raise ValueError(f'the use of parameters {intersection} is ambiguous since those can refer to '
                             f'the parameters of QuaNet or the learner {self.classifier.__class__.__name__}')

    def clean_checkpoint(self):
        """
        Removes the checkpoint
        """
        os.remove(self.checkpoint)

    def clean_checkpoint_dir(self):
        """
        Removes anything contained in the checkpoint directory
        """
        import shutil
        shutil.rmtree(self.checkpointdir, ignore_errors=True)

    @property
    def classes_(self):
        return self._classes_
|
|
|
|
|
2021-01-18 19:14:04 +01:00
|
|
|
|
2021-01-22 09:58:12 +01:00
|
|
|
def mae_loss(output, target):
    """
    Mean Absolute Error between two tensors, written in the style of a torch loss function.

    :param output: predicted values (tensor)
    :param target: ground-truth values (tensor broadcastable to `output`)
    :return: a scalar tensor with the mean of the element-wise absolute differences
    """
    residuals = output - target
    return residuals.abs().mean()
|
|
|
|
|
2020-12-29 20:33:59 +01:00
|
|
|
|
|
|
|
class QuaNetModule(torch.nn.Module):
    """
    Implements the `QuaNet <https://dl.acm.org/doi/abs/10.1145/3269206.3269287>`_ forward pass.
    See :class:`QuaNetTrainer` for training QuaNet.

    :param doc_embedding_size: integer, the dimensionality of the document embeddings
    :param n_classes: integer, number of classes
    :param stats_size: integer, number of statistics estimated by simple quantification methods
    :param lstm_hidden_size: integer, hidden dimensionality of the LSTM cell
    :param lstm_nlayers: integer, number of LSTM layers
    :param ff_layers: sequence of integers, dimensions of the densely-connected FF layers on top of the
        quantification embedding
    :param bidirectional: boolean, whether or not to use bidirectional LSTM
    :param qdrop_p: float, dropout probability
    :param order_by: integer, class for which the document embeddings are to be sorted (None disables sorting)
    """

    def __init__(self,
                 doc_embedding_size,
                 n_classes,
                 stats_size,
                 lstm_hidden_size=64,
                 lstm_nlayers=1,
                 ff_layers=(1024, 512),
                 bidirectional=True,
                 qdrop_p=0.5,
                 order_by=0):

        super().__init__()

        self.n_classes = n_classes
        self.order_by = order_by
        self.hidden_size = lstm_hidden_size
        self.nlayers = lstm_nlayers
        self.bidirectional = bidirectional
        self.ndirections = 2 if self.bidirectional else 1
        self.qdrop_p = qdrop_p
        # torch applies LSTM dropout only between stacked layers: with a single layer a non-zero value has no
        # effect and only triggers a UserWarning
        lstm_dropout = qdrop_p if lstm_nlayers > 1 else 0.0
        self.lstm = torch.nn.LSTM(doc_embedding_size + n_classes,  # +n_classes stands for the posterior probs. (concatenated)
                                  lstm_hidden_size, lstm_nlayers, bidirectional=bidirectional,
                                  dropout=lstm_dropout, batch_first=True)
        self.dropout = torch.nn.Dropout(self.qdrop_p)

        lstm_output_size = self.hidden_size * self.ndirections
        ff_input_size = lstm_output_size + stats_size
        prev_size = ff_input_size
        self.ff_layers = torch.nn.ModuleList()
        for lin_size in ff_layers:
            self.ff_layers.append(torch.nn.Linear(prev_size, lin_size))
            prev_size = lin_size
        self.output = torch.nn.Linear(prev_size, n_classes)

    @property
    def device(self):
        return torch.device('cuda') if next(self.parameters()).is_cuda else torch.device('cpu')

    def _init_hidden(self):
        """Zero initial (hidden, cell) states for a batch of size 1, on the same device as the LSTM weights."""
        param_device = next(self.lstm.parameters()).device
        shape = (self.nlayers * self.ndirections, 1, self.hidden_size)
        return torch.zeros(shape, device=param_device), torch.zeros(shape, device=param_device)

    def forward(self, doc_embeddings, doc_posteriors, statistics):
        """
        Computes the class prevalence of one sample.

        :param doc_embeddings: per-instance embeddings, one row per instance (array-like or tensor)
        :param doc_posteriors: per-instance posterior probabilities, one row per instance
        :param statistics: flat vector of prevalence estimates from the auxiliary quantifiers (length `stats_size`)
        :return: tensor of shape (1, n_classes) with the softmax-normalized prevalence
        """
        # use the exact device of the parameters (e.g. 'cuda:1'), not just the generic 'cuda'
        device = next(self.parameters()).device
        doc_embeddings = torch.as_tensor(doc_embeddings, dtype=torch.float, device=device)
        doc_posteriors = torch.as_tensor(doc_posteriors, dtype=torch.float, device=device)
        statistics = torch.as_tensor(statistics, dtype=torch.float, device=device)

        if self.order_by is not None:
            order = torch.argsort(doc_posteriors[:, self.order_by])
            doc_embeddings = doc_embeddings[order]
            doc_posteriors = doc_posteriors[order]

        embeded_posteriors = torch.cat((doc_embeddings, doc_posteriors), dim=-1)

        # the entire set represents only one instance in quapy contexts, and so the batch_size=1
        # the shape should be (1, number-of-instances, embedding-size + n_classes)
        embeded_posteriors = embeded_posteriors.unsqueeze(0)

        self.lstm.flatten_parameters()
        _, (rnn_hidden, _) = self.lstm(embeded_posteriors, self._init_hidden())
        rnn_hidden = rnn_hidden.view(self.nlayers, self.ndirections, 1, self.hidden_size)
        # layers are ordered from first to last: the quantification embedding is the final state of the LAST layer
        # (index 0 would ignore, and never train, every upper layer when lstm_nlayers > 1)
        quant_embedding = rnn_hidden[-1].view(-1)
        quant_embedding = torch.cat((quant_embedding, statistics))

        abstracted = quant_embedding.unsqueeze(0)
        for linear in self.ff_layers:
            abstracted = self.dropout(relu(linear(abstracted)))

        logits = self.output(abstracted).view(1, -1)
        prevalence = torch.softmax(logits, -1)

        return prevalence
|
|
|
|
|
|
|
|
|
|
|
|
|
2021-01-18 19:14:04 +01:00
|
|
|
|
|
|
|
|