Implemented custom micro F1 metric in PyTorch Lightning (CPU and GPU)

This commit is contained in:
andrea 2021-01-20 11:47:51 +01:00
parent 294d7c3be7
commit 8dbe48ff7a
5 changed files with 83 additions and 108 deletions

View File

@ -88,7 +88,7 @@ class RecurrentDataset(Dataset):
return index_list return index_list
class GfunDataModule(pl.LightningDataModule): class RecurrentDataModule(pl.LightningDataModule):
def __init__(self, multilingualIndex, batchsize=64): def __init__(self, multilingualIndex, batchsize=64):
""" """
Pytorch-lightning DataModule: https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html Pytorch-lightning DataModule: https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
@ -105,9 +105,18 @@ class GfunDataModule(pl.LightningDataModule):
def setup(self, stage=None): def setup(self, stage=None):
if stage == 'fit' or stage is None: if stage == 'fit' or stage is None:
l_train_index, l_train_target = self.multilingualIndex.l_train() l_train_index, l_train_target = self.multilingualIndex.l_train()
# l_train_index = {l: train[:50] for l, train in l_train_index.items()}
# l_train_target = {l: target[:50] for l, target in l_train_target.items()}
self.training_dataset = RecurrentDataset(l_train_index, l_train_target, self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
lPad_index=self.multilingualIndex.l_pad()) lPad_index=self.multilingualIndex.l_pad())
l_val_index, l_val_target = self.multilingualIndex.l_val() l_val_index, l_val_target = self.multilingualIndex.l_val()
# l_val_index = {l: train[:50] for l, train in l_val_index.items()}
# l_val_target = {l: target[:50] for l, target in l_val_target.items()}
self.val_dataset = RecurrentDataset(l_val_index, l_val_target, self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
lPad_index=self.multilingualIndex.l_pad()) lPad_index=self.multilingualIndex.l_pad())
if stage == 'test' or stage is None: if stage == 'test' or stage is None:
@ -128,7 +137,7 @@ class GfunDataModule(pl.LightningDataModule):
collate_fn=self.test_dataset.collate_fn) collate_fn=self.test_dataset.collate_fn)
class BertDataModule(GfunDataModule): class BertDataModule(RecurrentDataModule):
def __init__(self, multilingualIndex, batchsize=64, max_len=512): def __init__(self, multilingualIndex, batchsize=64, max_len=512):
super().__init__(multilingualIndex, batchsize) super().__init__(multilingualIndex, batchsize)
self.max_len = max_len self.max_len = max_len

View File

@ -15,7 +15,7 @@ def main(args):
_DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle' _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings' EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
data = MultilingualDataset.load(_DATASET) data = MultilingualDataset.load(_DATASET)
data.set_view(languages=['it'], categories=[0,1]) data.set_view(languages=['it'], categories=[0, 1])
lX, ly = data.training() lX, ly = data.training()
lXte, lyte = data.test() lXte, lyte = data.test()
@ -28,7 +28,8 @@ def main(args):
# gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
# gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS) # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
# gFun = WordClassGen(n_jobs=N_JOBS) # gFun = WordClassGen(n_jobs=N_JOBS)
gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=512, gpus=args.gpus, n_jobs=N_JOBS) gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=5, nepochs=100,
gpus=args.gpus, n_jobs=N_JOBS)
# gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS) # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
gFun.fit(lX, ly) gFun.fit(lX, ly)

View File

@ -5,12 +5,10 @@ from transformers import AdamW
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Variable from torch.autograd import Variable
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.metrics import F1, Accuracy, Metric from pytorch_lightning.metrics import Metric, F1, Accuracy
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
from typing import Any, Optional, Tuple
from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce
from models.helpers import init_embeddings from models.helpers import init_embeddings
import numpy as np from util.common import is_true, is_false
from util.evaluation import evaluate from util.evaluation import evaluate
@ -20,8 +18,9 @@ class RecurrentModel(pl.LightningModule):
""" """
def __init__(self, lPretrained, langs, output_size, hidden_size, lVocab_size, learnable_length, def __init__(self, lPretrained, langs, output_size, hidden_size, lVocab_size, learnable_length,
drop_embedding_range, drop_embedding_prop, lMuse_debug=None, multilingual_index_debug=None): drop_embedding_range, drop_embedding_prop, gpus=None):
super().__init__() super().__init__()
self.gpus = gpus
self.langs = langs self.langs = langs
self.lVocab_size = lVocab_size self.lVocab_size = lVocab_size
self.learnable_length = learnable_length self.learnable_length = learnable_length
@ -33,7 +32,7 @@ class RecurrentModel(pl.LightningModule):
self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro') self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro') self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
self.accuracy = Accuracy() self.accuracy = Accuracy()
self.customMetrics = CustomMetrics(num_classes=output_size, multilabel=True, average='micro') self.customMetrics = CustomF1(num_classes=output_size, device=self.gpus)
self.lPretrained_embeddings = nn.ModuleDict() self.lPretrained_embeddings = nn.ModuleDict()
self.lLearnable_embeddings = nn.ModuleDict() self.lLearnable_embeddings = nn.ModuleDict()
@ -42,10 +41,6 @@ class RecurrentModel(pl.LightningModule):
self.n_directions = 1 self.n_directions = 1
self.dropout = nn.Dropout(0.6) self.dropout = nn.Dropout(0.6)
# TODO: debug setting
self.lMuse = lMuse_debug
self.multilingual_index_debug = multilingual_index_debug
lstm_out = 256 lstm_out = 256
ff1 = 512 ff1 = 512
ff2 = 256 ff2 = 256
@ -111,7 +106,7 @@ class RecurrentModel(pl.LightningModule):
custom = self.customMetrics(predictions, ly) custom = self.customMetrics(predictions, ly)
self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True) self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True) self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
self.log('custom', custom, on_step=False, on_epoch=True, prog_bar=True, logger=True) self.log('custom', custom, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return {'loss': loss} return {'loss': loss}
def validation_step(self, val_batch, batch_idx): def validation_step(self, val_batch, batch_idx):
@ -139,7 +134,6 @@ class RecurrentModel(pl.LightningModule):
accuracy = self.accuracy(predictions, ly) accuracy = self.accuracy(predictions, ly)
self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True) self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
return return
# return {'pred': predictions, 'target': ly}
def embed(self, X, lang): def embed(self, X, lang):
input_list = [] input_list = []
@ -166,98 +160,56 @@ class RecurrentModel(pl.LightningModule):
return [optimizer], [scheduler] return [optimizer], [scheduler]
class CustomMetrics(Metric): class CustomF1(Metric):
def __init__( def __init__(self, num_classes, device, average='micro'):
self, """
num_classes: int, Custom F1 metric.
beta: float = 1.0, scikit-learn provides a full set of evaluation metrics, but they treat special cases differently.
threshold: float = 0.5, I.e., when the number of true positives, false positives, and false negatives amount to 0, all
average: str = "micro", affected metrics (precision, recall, and thus f1) output 0 in Scikit learn.
multilabel: bool = False, We adhere to the common practice of outputting 1 in this case since the classifier has correctly
compute_on_step: bool = True, classified all examples as negatives.
dist_sync_on_step: bool = False, :param num_classes:
process_group: Optional[Any] = None, :param device:
): :param average:
super().__init__( """
compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, super().__init__()
)
self.num_classes = num_classes self.num_classes = num_classes
self.beta = beta
self.threshold = threshold
self.average = average self.average = average
self.multilabel = multilabel self.device = 'cuda' if device else 'cpu'
self.add_state('true_positive', default=torch.zeros(self.num_classes))
self.add_state('true_negative', default=torch.zeros(self.num_classes))
self.add_state('false_positive', default=torch.zeros(self.num_classes))
self.add_state('false_negative', default=torch.zeros(self.num_classes))
allowed_average = ("micro", "macro", "weighted", None) def update(self, preds, target):
if self.average not in allowed_average: true_positive, true_negative, false_positive, false_negative = self._update(preds, target)
raise ValueError('Argument `average` expected to be one of the following:'
f' {allowed_average} but got {self.average}')
self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") self.true_positive += true_positive
self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") self.true_negative += true_negative
self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") self.false_positive += false_positive
self.false_negative += false_negative
def update(self, preds: torch.Tensor, target: torch.Tensor): def _update(self, pred, target):
""" assert pred.shape == target.shape
Update state with predictions and targets. # preparing preds and targets for count
true_pred = is_true(pred, self.device)
false_pred = is_false(pred, self.device)
true_target = is_true(target, self.device)
false_target = is_false(target, self.device)
Args: tp = torch.sum(true_pred * true_target, dim=0)
preds: Predictions from model tn = torch.sum(false_pred * false_target, dim=0)
target: Ground truth values fp = torch.sum(true_pred * false_target, dim=0)
""" fn = torch.sum(false_pred * target, dim=0)
true_positives, predicted_positives, actual_positives = _fbeta_update( return tp, tn, fp, fn
preds, target, self.num_classes, self.threshold, self.multilabel
)
self.true_positives += true_positives
self.predicted_positives += predicted_positives
self.actual_positives += actual_positives
def compute(self): def compute(self):
""" if self.average == 'micro':
Computes metrics over state. num = 2.0 * self.true_positive.sum()
""" den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum()
return _fbeta_compute(self.true_positives, self.predicted_positives, if den > 0:
self.actual_positives, self.beta, self.average) return (num / den).to(self.device)
return torch.FloatTensor([1.]).to(self.device)
if self.average == 'macro':
def _fbeta_update( raise NotImplementedError
preds: torch.Tensor,
target: torch.Tensor,
num_classes: int,
threshold: float = 0.5,
multilabel: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
preds, target = _input_format_classification_one_hot(
num_classes, preds, target, threshold, multilabel
)
true_positives = torch.sum(preds * target, dim=1)
predicted_positives = torch.sum(preds, dim=1)
actual_positives = torch.sum(target, dim=1)
return true_positives, predicted_positives, actual_positives
def _fbeta_compute(
true_positives: torch.Tensor,
predicted_positives: torch.Tensor,
actual_positives: torch.Tensor,
beta: float = 1.0,
average: str = "micro"
) -> torch.Tensor:
if average == "micro":
precision = true_positives.sum().float() / predicted_positives.sum()
recall = true_positives.sum().float() / actual_positives.sum()
else:
precision = true_positives.float() / predicted_positives
recall = true_positives.float() / actual_positives
num = (1 + beta ** 2) * precision * recall
denom = beta ** 2 * precision + recall
new_num = 2 * true_positives
new_fp = predicted_positives - true_positives
new_fn = actual_positives - true_positives
new_den = 2 * true_positives + new_fp + new_fn
if new_den.sum() == 0:
# what is the correct return type? TODO
return class_reduce(new_num, new_den, weights=actual_positives, class_reduction=average)
return class_reduce(num, denom, weights=actual_positives, class_reduction=average)

View File

@ -327,3 +327,12 @@ def index(data, vocab, known_words, analyzer, unk_index, out_of_vocabulary):
# pbar.set_description(f'[unk = {unk_count}/{knw_count}={(100.*unk_count/knw_count):.2f}%]' # pbar.set_description(f'[unk = {unk_count}/{knw_count}={(100.*unk_count/knw_count):.2f}%]'
# f'[out = {out_count}/{knw_count}={(100.*out_count/knw_count):.2f}%]') # f'[out = {out_count}/{knw_count}={(100.*out_count/knw_count):.2f}%]')
return indexes return indexes
def is_true(tensor, device):
return torch.where(tensor == 1, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))
def is_false(tensor, device):
return torch.where(tensor == 0, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))

View File

@ -22,7 +22,7 @@ from models.pl_gru import RecurrentModel
from models.pl_bert import BertModel from models.pl_bert import BertModel
from models.lstm_class import RNNMultilingualClassifier from models.lstm_class import RNNMultilingualClassifier
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from data.datamodule import GfunDataModule, BertDataModule from data.datamodule import RecurrentDataModule, BertDataModule
from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.loggers import TensorBoardLogger
import torch import torch
@ -144,7 +144,8 @@ class RecurrentGen(ViewGen):
# TODO: save model https://forums.pytorchlightning.ai/t/how-to-save-hparams-when-not-provided-as-argument-apparently-assigning-to-hparams-is-not-recomended/339/5 # TODO: save model https://forums.pytorchlightning.ai/t/how-to-save-hparams-when-not-provided-as-argument-apparently-assigning-to-hparams-is-not-recomended/339/5
# Problem: we are passing lPretrained to init the RecurrentModel -> incredibly slow at saving (checkpoint). # Problem: we are passing lPretrained to init the RecurrentModel -> incredibly slow at saving (checkpoint).
# if we do not save it is impossible to init RecurrentModel by calling RecurrentModel.load_from_checkpoint() # if we do not save it is impossible to init RecurrentModel by calling RecurrentModel.load_from_checkpoint()
def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, gpus=0, n_jobs=-1, stored_path=None): def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, nepochs=50,
gpus=0, n_jobs=-1, stored_path=None):
""" """
generates document embeddings by means of Gated Recurrent Units. The model can be generates document embeddings by means of Gated Recurrent Units. The model can be
initialized with different (multilingual/aligned) word representations (e.g., MUSE, WCE, etc.). initialized with different (multilingual/aligned) word representations (e.g., MUSE, WCE, etc.).
@ -162,6 +163,7 @@ class RecurrentGen(ViewGen):
self.gpus = gpus self.gpus = gpus
self.n_jobs = n_jobs self.n_jobs = n_jobs
self.stored_path = stored_path self.stored_path = stored_path
self.nepochs = nepochs
# EMBEDDINGS to be deployed # EMBEDDINGS to be deployed
self.pretrained = pretrained_embeddings self.pretrained = pretrained_embeddings
@ -193,7 +195,8 @@ class RecurrentGen(ViewGen):
lVocab_size=lvocab_size, lVocab_size=lvocab_size,
learnable_length=learnable_length, learnable_length=learnable_length,
drop_embedding_range=self.multilingualIndex.sup_range, drop_embedding_range=self.multilingualIndex.sup_range,
drop_embedding_prop=0.5 drop_embedding_prop=0.5,
gpus=self.gpus
) )
def fit(self, lX, ly): def fit(self, lX, ly):
@ -204,8 +207,9 @@ class RecurrentGen(ViewGen):
:param ly: :param ly:
:return: :return:
""" """
recurrentDataModule = GfunDataModule(self.multilingualIndex, batchsize=self.batch_size) recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size)
trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50, checkpoint_callback=False) trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
checkpoint_callback=False)
# vanilla_torch_model = torch.load( # vanilla_torch_model = torch.load(
# '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle') # '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')