1
0
Fork 0
QuaPy/quapy/method/neural.py

345 lines
15 KiB
Python
Raw Normal View History

import os
from pathlib import Path
import random
2021-01-15 18:32:32 +01:00
import torch
from torch.nn import MSELoss
from torch.nn.functional import relu
2021-01-15 18:32:32 +01:00
from quapy.method.aggregative import *
from quapy.util import EarlyStop
class QuaNetTrainer(BaseQuantifier):
    """
    Trainer for a :class:`QuaNetModule` built on top of a classifier (the `learner`).

    The learner must expose `transform` (used to obtain document embeddings) and `predict_proba` (used to obtain
    posterior probabilities). The prevalence estimates of CC, ACC, PCC, PACC and EMQ, all built on the same
    learner, are fed to the network as additional statistics.
    """

    def __init__(self,
                 learner,
                 sample_size,
                 n_epochs=100,
                 tr_iter_per_poch=500,
                 va_iter_per_poch=100,
                 lr=1e-3,
                 lstm_hidden_size=64,
                 lstm_nlayers=1,
                 ff_layers=(1024, 512),
                 bidirectional=True,
                 qdrop_p=0.5,
                 patience=10,
                 checkpointdir='../checkpoint',
                 checkpointname=None,
                 device='cuda'):
        """
        :param learner: object implementing `fit`, `transform`, `predict_proba`, `get_params` and `set_params`
        :param sample_size: number of instances in each sample drawn for training and validating QuaNet
        :param n_epochs: maximum number of training epochs
        :param tr_iter_per_poch: number of training samples (iterations) per epoch
        :param va_iter_per_poch: approximate number of validation samples per epoch
        :param lr: learning rate of the Adam optimizer
        :param lstm_hidden_size: forwarded to QuaNetModule
        :param lstm_nlayers: forwarded to QuaNetModule
        :param ff_layers: sizes of the feed-forward layers, forwarded to QuaNetModule
        :param bidirectional: forwarded to QuaNetModule
        :param qdrop_p: dropout probability, forwarded to QuaNetModule
        :param patience: patience passed to EarlyStop
        :param checkpointdir: directory where the best model parameters are stored (created if missing)
        :param checkpointname: file name of the checkpoint; if None, a random name is generated
        :param device: anything accepted by `torch.device`
        """
        assert hasattr(learner, 'transform'), \
            f'the learner {learner.__class__.__name__} does not seem to be able to produce document embeddings ' \
            f'since it does not implement the method "transform"'
        assert hasattr(learner, 'predict_proba'), \
            f'the learner {learner.__class__.__name__} does not seem to be able to produce posterior probabilities ' \
            f'since it does not implement the method "predict_proba"'
        self.learner = learner
        self.sample_size = sample_size
        self.n_epochs = n_epochs
        self.tr_iter = tr_iter_per_poch
        self.va_iter = va_iter_per_poch
        self.lr = lr
        # hyper-parameters that are forwarded verbatim to QuaNetModule in fit()
        self.quanet_params = {
            'lstm_hidden_size': lstm_hidden_size,
            'lstm_nlayers': lstm_nlayers,
            'ff_layers': ff_layers,
            'bidirectional': bidirectional,
            'qdrop_p': qdrop_p
        }
        self.patience = patience

        os.makedirs(checkpointdir, exist_ok=True)
        if checkpointname is None:
            # a dedicated Random instance is used so that the global `random` state is left untouched
            local_random = random.Random()
            random_code = '-'.join(str(local_random.randint(0, 1000000)) for _ in range(5))
            checkpointname = 'QuaNet-'+random_code
        self.checkpointdir = checkpointdir
        self.checkpoint = os.path.join(checkpointdir, checkpointname)
        self.device = torch.device(device)

        # get_params()/set_params() merge both namespaces, so the names must not overlap
        self.__check_params_colision(self.quanet_params, self.learner.get_params())

    def fit(self, data: LabelledCollection, fit_learner=True):
        """
        Trains QuaNet. The data is split (stratified) in 40/40/20: the first part is the classifier split (on
        which the learner is trained when `fit_learner=True`, and which is passed as the training data of the
        auxiliary quantifiers), the second part is used for training QuaNet, and the third for validating it.

        :param data: the training data on which to train QuaNet
        :param fit_learner: if True, trains the classifier on the split containing 40% of the data; if False, the
            learner is assumed to be already trained and is left untouched
        :return: self
        """
        classifier_data, unused_data = data.split_stratified(0.4)
        train_data, valid_data = unused_data.split_stratified(0.66)  # 0.66 split of 60% makes 40% and 20%

        print('Classifier data: ', len(classifier_data))
        print('Q-Training data: ', len(train_data))
        print('Q-Valid data: ', len(valid_data))

        self.tr_prev = data.prevalence()

        # honour fit_learner: an already trained learner must not be overwritten
        if fit_learner:
            self.learner.fit(*classifier_data.Xy)

        # the auxiliary quantifiers share the learner, which is never refit by them
        self.quantifiers = {
            'cc': CC(self.learner).fit(classifier_data, fit_learner=False),
            'acc': ACC(self.learner).fit(classifier_data, fit_learner=False, val_split=valid_data),
            'pcc': PCC(self.learner).fit(classifier_data, fit_learner=False),
            'pacc': PACC(self.learner).fit(classifier_data, fit_learner=False, val_split=valid_data),
            'emq': EMQ(self.learner).fit(classifier_data, fit_learner=False),
        }

        # compute the posterior probabilities of the instances
        valid_posteriors = self.learner.predict_proba(valid_data.instances)
        train_posteriors = self.learner.predict_proba(train_data.instances)

        # turn instances' original representations into embeddings (the splits are modified in place)
        valid_data.instances = self.learner.transform(valid_data.instances)
        train_data.instances = self.learner.transform(train_data.instances)

        # -1 marks "not computed yet"; the values are refreshed by epoch()
        self.status = {
            'tr-loss': -1,
            'va-loss': -1,
            'tr-mae': -1,
            'va-mae': -1,
        }

        nQ = len(self.quantifiers)
        nC = data.n_classes
        self.quanet = QuaNetModule(
            doc_embedding_size=train_data.instances.shape[1],
            n_classes=data.n_classes,
            stats_size=nQ*nC,  # one prevalence vector per auxiliary quantifier
            order_by=0 if data.binary else None,
            **self.quanet_params
        ).to(self.device)
        print(self.quanet)

        self.optim = torch.optim.Adam(self.quanet.parameters(), lr=self.lr)
        early_stop = EarlyStop(self.patience, lower_is_better=True)

        checkpoint = self.checkpoint

        # epochs are numbered 1..n_epochs (inclusive), so that exactly n_epochs epochs are run at most
        for epoch_i in range(1, self.n_epochs + 1):
            self.epoch(train_data, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
            self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False)

            early_stop(self.status['va-loss'], epoch_i)
            if early_stop.IMPROVED:
                torch.save(self.quanet.state_dict(), checkpoint)
            elif early_stop.STOP:
                print(f'training ended by patience exhausted; loading best model parameters in {checkpoint} '
                      f'for epoch {early_stop.best_epoch}')
                self.quanet.load_state_dict(torch.load(checkpoint))
                break

        return self

    def get_aggregative_estims(self, posteriors):
        """
        Computes the prevalence estimates of all the auxiliary quantifiers for one sample.

        :param posteriors: posterior probabilities of the instances in the sample (one row per instance)
        :return: a flat list with the concatenation of the estimates returned by each quantifier
        """
        label_predictions = np.argmax(posteriors, axis=-1)
        prevs_estim = []
        for quantifier in self.quantifiers.values():
            # probabilistic quantifiers aggregate posteriors; the rest aggregate hard label predictions
            predictions = posteriors if quantifier.probabilistic else label_predictions
            prevs_estim.extend(quantifier.aggregate(predictions))
        return prevs_estim

    def quantify(self, instances, *args):
        """
        Estimates the class prevalences of a set of instances.

        :param instances: instances in the representation expected by the learner
        :return: a flat numpy array with the prevalence estimate output by QuaNet
        """
        posteriors = self.learner.predict_proba(instances)
        embeddings = self.learner.transform(instances)
        quant_estims = self.get_aggregative_estims(posteriors)
        self.quanet.eval()
        with torch.no_grad():
            prevalence = self.quanet.forward(embeddings, posteriors, quant_estims)
            # .cpu() is a no-op for CPU tensors and is required for any CUDA device (e.g., 'cuda:0'), not only
            # for the one that compares equal to torch.device('cuda')
            prevalence = prevalence.cpu().numpy().flatten()
        return prevalence

    def epoch(self, data: LabelledCollection, posteriors, iterations, epoch, early_stop, train):
        """
        Runs one pass of training (`train=True`) or validation (`train=False`) and updates `self.status` with the
        running mean of the MSE and MAE observed so far in this pass.

        :param data: collection whose instances are already document embeddings
        :param posteriors: posterior probabilities aligned with the instances in `data`
        :param iterations: number of samples to draw (approximate in validation mode)
        :param epoch: epoch number (only used in the progress bar)
        :param early_stop: the EarlyStop instance (only used in the progress bar)
        :param train: whether to update the parameters of the network
        """
        mse_loss = MSELoss()
        self.quanet.train(mode=train)
        losses = []
        mae_errors = []
        if train:
            sampling_index_gen = [
                data.sampling_index(self.sample_size, *prev)
                for prev in F.uniform_simplex_sampling(data.n_classes, iterations)
            ]
            pbar = tqdm(sampling_index_gen, total=iterations)
            loss_key, mae_key = 'tr-loss', 'tr-mae'
        else:
            prevpoints = F.get_nprevpoints_approximation(iterations, self.quanet.n_classes)
            iterations = F.num_prevalence_combinations(prevpoints, self.quanet.n_classes)
            with qp.util.temp_seed(0):
                # the indexes are materialized inside the seeded context: if the generator is lazy (as its name
                # suggests), iterating it after leaving the context would not be affected by the seed
                sampling_index_gen = list(data.artificial_sampling_index_generator(self.sample_size, prevpoints))
            pbar = sampling_index_gen
            loss_key, mae_key = 'va-loss', 'va-mae'

        for it, index in enumerate(pbar):
            sample_data = data.sampling_from_index(index)
            sample_posteriors = posteriors[index]
            quant_estims = self.get_aggregative_estims(sample_posteriors)
            # wrapped in a list so that the target has shape (1, n_classes), as the output of the network
            ptrue = torch.as_tensor([sample_data.prevalence()], dtype=torch.float, device=self.device)
            if train:
                self.optim.zero_grad()
                phat = self.quanet.forward(sample_data.instances, sample_posteriors, quant_estims)
                loss = mse_loss(phat, ptrue)
                mae = mae_loss(phat, ptrue)
                loss.backward()
                self.optim.step()
            else:
                with torch.no_grad():
                    phat = self.quanet.forward(sample_data.instances, sample_posteriors, quant_estims)
                    loss = mse_loss(phat, ptrue)
                    mae = mae_loss(phat, ptrue)

            losses.append(loss.item())
            mae_errors.append(mae.item())

            self.status[loss_key] = np.mean(losses)
            self.status[mae_key] = np.mean(mae_errors)

            if train:
                pbar.set_description(f'[QuaNet] '
                                     f'epoch={epoch} [it={it}/{iterations}]\t'
                                     f'tr-mseloss={self.status["tr-loss"]:.5f} tr-maeloss={self.status["tr-mae"]:.5f}\t'
                                     f'val-mseloss={self.status["va-loss"]:.5f} val-maeloss={self.status["va-mae"]:.5f} '
                                     f'patience={early_stop.patience}/{early_stop.PATIENCE_LIMIT}')

    def get_params(self, deep=True):
        """
        :param deep: unused
        :return: a dict merging the parameters of the learner with the QuaNet hyper-parameters
        """
        return {**self.learner.get_params(), **self.quanet_params}

    def set_params(self, **parameters):
        """
        Sets hyper-parameters: the keys already present in `self.quanet_params` update QuaNet's hyper-parameters,
        and every other key is forwarded to the learner's `set_params`.
        """
        learner_params = {}
        for key, val in parameters.items():
            if key in self.quanet_params:
                self.quanet_params[key] = val
            else:
                learner_params[key] = val
        self.learner.set_params(**learner_params)

    def __check_params_colision(self, quanet_params, learner_params):
        """
        Raises ValueError if QuaNet and the learner share any parameter name, since in that case the keys
        handled by get_params/set_params would be ambiguous.
        """
        quanet_keys = set(quanet_params.keys())
        learner_keys = set(learner_params.keys())
        intersection = quanet_keys.intersection(learner_keys)
        if len(intersection) > 0:
            raise ValueError(f'the use of parameters {intersection} is ambiguous sine those can refer to '
                             f'the parameters of QuaNet or the learner {self.learner.__class__.__name__}')

    def clean_checkpoint(self):
        """Removes the checkpoint file (raises the error of `os.remove` if it does not exist)."""
        os.remove(self.checkpoint)

    def clean_checkpoint_dir(self):
        """Removes the whole checkpoint directory, ignoring errors."""
        import shutil
        shutil.rmtree(self.checkpointdir, ignore_errors=True)
2021-01-22 09:58:12 +01:00
def mae_loss(output, target):
    """
    Mean absolute error between two tensors.

    :param output: tensor with the predicted values
    :param target: tensor with the true values
    :return: a scalar tensor with the mean of the element-wise absolute differences
    """
    abs_diff = (output - target).abs()
    return abs_diff.mean()
class QuaNetModule(torch.nn.Module):
    """
    The QuaNet network: an LSTM reads the sequence of documents of a sample (each represented by its embedding
    concatenated with its posterior probabilities); the resulting hidden state is concatenated with a vector of
    statistics and passed through feed-forward layers ending in a softmax over the classes.
    """

    def __init__(self,
                 doc_embedding_size,
                 n_classes,
                 stats_size,
                 lstm_hidden_size=64,
                 lstm_nlayers=1,
                 ff_layers=(1024, 512),
                 bidirectional=True,
                 qdrop_p=0.5,
                 order_by=0):
        """
        :param doc_embedding_size: dimensionality of the document embeddings
        :param n_classes: number of classes (size of the posteriors and of the output)
        :param stats_size: length of the statistics vector passed to `forward`
        :param lstm_hidden_size: hidden size of the LSTM
        :param lstm_nlayers: number of stacked LSTM layers
        :param ff_layers: sizes of the feed-forward layers that follow the LSTM
        :param bidirectional: whether the LSTM is bidirectional
        :param qdrop_p: dropout probability (feed-forward layers, and between LSTM layers if more than one)
        :param order_by: index of the class whose posterior is used to sort the documents before feeding them
            to the LSTM, or None for keeping the original order
        """
        super().__init__()

        self.n_classes = n_classes
        self.order_by = order_by
        self.hidden_size = lstm_hidden_size
        self.nlayers = lstm_nlayers
        self.bidirectional = bidirectional
        self.ndirections = 2 if self.bidirectional else 1
        self.qdrop_p = qdrop_p
        self.lstm = torch.nn.LSTM(
            doc_embedding_size + n_classes,  # +n_classes stands for the posterior probs. (concatenated)
            lstm_hidden_size, lstm_nlayers, bidirectional=bidirectional,
            # the LSTM only applies dropout between stacked layers; with a single layer it has no effect and
            # PyTorch emits a warning, so it is disabled in that case
            dropout=qdrop_p if lstm_nlayers > 1 else 0.,
            batch_first=True)
        self.dropout = torch.nn.Dropout(self.qdrop_p)

        lstm_output_size = self.hidden_size * self.ndirections
        ff_input_size = lstm_output_size + stats_size
        prev_size = ff_input_size
        self.ff_layers = torch.nn.ModuleList()
        for lin_size in ff_layers:
            self.ff_layers.append(torch.nn.Linear(prev_size, lin_size))
            prev_size = lin_size
        self.output = torch.nn.Linear(prev_size, n_classes)

    @property
    def device(self):
        """The device holding the parameters of the module (including the device index, e.g., cuda:1)."""
        return next(self.parameters()).device

    def init_hidden(self):
        """
        :return: a tuple (hidden, cell) of zero tensors of shape (nlayers * ndirections, 1, hidden_size), placed
            in the same device as the LSTM parameters
        """
        shape = (self.nlayers * self.ndirections, 1, self.hidden_size)
        device = next(self.lstm.parameters()).device
        var_hidden = torch.zeros(shape, device=device)
        var_cell = torch.zeros(shape, device=device)
        return var_hidden, var_cell

    def forward(self, doc_embeddings, doc_posteriors, statistics):
        """
        :param doc_embeddings: array-like of shape (n_docs, doc_embedding_size)
        :param doc_posteriors: array-like of shape (n_docs, n_classes)
        :param statistics: 1-dimensional array-like of length stats_size
        :return: a tensor of shape (1, n_classes) with the estimated prevalences (rows sum up to 1)
        """
        device = self.device
        doc_embeddings = torch.as_tensor(doc_embeddings, dtype=torch.float, device=device)
        doc_posteriors = torch.as_tensor(doc_posteriors, dtype=torch.float, device=device)
        statistics = torch.as_tensor(statistics, dtype=torch.float, device=device)

        if self.order_by is not None:
            order = torch.argsort(doc_posteriors[:, self.order_by])
            doc_embeddings = doc_embeddings[order]
            doc_posteriors = doc_posteriors[order]

        embeded_posteriors = torch.cat((doc_embeddings, doc_posteriors), dim=-1)

        # the entire set represents only one instance in quapy contexts, and so the batch_size=1
        # the shape should be (1, number-of-instances, embedding-size + n_classes)
        embeded_posteriors = embeded_posteriors.unsqueeze(0)

        self.lstm.flatten_parameters()
        _, (rnn_hidden, _) = self.lstm(embeded_posteriors, self.init_hidden())
        rnn_hidden = rnn_hidden.view(self.nlayers, self.ndirections, 1, self.hidden_size)
        # the quantification embedding is the final hidden state of the LAST layer: taking the first layer
        # would leave any additional stacked layer disconnected from the output (and thus never trained)
        quant_embedding = rnn_hidden[-1].view(-1)
        quant_embedding = torch.cat((quant_embedding, statistics))

        abstracted = quant_embedding.unsqueeze(0)

        for linear in self.ff_layers:
            abstracted = self.dropout(relu(linear(abstracted)))

        logits = self.output(abstracted).view(1, -1)
        prevalence = torch.softmax(logits, -1)

        return prevalence