Implemented custom micro F1 in PyTorch Lightning (CPU and GPU)

andrea 2021-01-20 11:47:51 +01:00
parent 294d7c3be7
commit 8dbe48ff7a
5 changed files with 83 additions and 108 deletions

View File

@@ -88,7 +88,7 @@ class RecurrentDataset(Dataset):
        return index_list

class GfunDataModule(pl.LightningDataModule):
class RecurrentDataModule(pl.LightningDataModule):
    def __init__(self, multilingualIndex, batchsize=64):
        """
        Pytorch-lightning DataModule: https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
@@ -105,9 +105,18 @@ class GfunDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            l_train_index, l_train_target = self.multilingualIndex.l_train()
            # l_train_index = {l: train[:50] for l, train in l_train_index.items()}
            # l_train_target = {l: target[:50] for l, target in l_train_target.items()}
            self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                     lPad_index=self.multilingualIndex.l_pad())
            l_val_index, l_val_target = self.multilingualIndex.l_val()
            # l_val_index = {l: train[:50] for l, train in l_val_index.items()}
            # l_val_target = {l: target[:50] for l, target in l_val_target.items()}
            self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                lPad_index=self.multilingualIndex.l_pad())
        if stage == 'test' or stage is None:
@@ -128,7 +137,7 @@ class GfunDataModule(pl.LightningDataModule):
                          collate_fn=self.test_dataset.collate_fn)

class BertDataModule(GfunDataModule):
class BertDataModule(RecurrentDataModule):
    def __init__(self, multilingualIndex, batchsize=64, max_len=512):
        super().__init__(multilingualIndex, batchsize)
        self.max_len = max_len
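For context, a LightningDataModule such as the renamed RecurrentDataModule above is consumed by a Trainer, which drives its setup and dataloader hooks. A minimal sketch (the model variable and the hyperparameter values here are placeholders, not part of this commit):

import pytorch_lightning as pl

dm = RecurrentDataModule(multilingualIndex, batchsize=64)
trainer = pl.Trainer(gpus=1, max_epochs=100)
trainer.fit(model, datamodule=dm)   # triggers dm.setup('fit'), then the train/val dataloaders
trainer.test(datamodule=dm)         # triggers dm.setup('test'), then test_dataloader()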

View File

@@ -15,7 +15,7 @@ def main(args):
    _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
    EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
    data = MultilingualDataset.load(_DATASET)
    data.set_view(languages=['it'], categories=[0,1])
    data.set_view(languages=['it'], categories=[0, 1])
    lX, ly = data.training()
    lXte, lyte = data.test()
@@ -28,7 +28,8 @@ def main(args):
    # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
    # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
    # gFun = WordClassGen(n_jobs=N_JOBS)
    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=512, gpus=args.gpus, n_jobs=N_JOBS)
    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=5, nepochs=100,
                        gpus=args.gpus, n_jobs=N_JOBS)
    # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
    gFun.fit(lX, ly)

View File

@@ -5,12 +5,10 @@ from transformers import AdamW
import torch.nn.functional as F
from torch.autograd import Variable
import pytorch_lightning as pl
from pytorch_lightning.metrics import F1, Accuracy, Metric
from pytorch_lightning.metrics import Metric, F1, Accuracy
from torch.optim.lr_scheduler import StepLR
from typing import Any, Optional, Tuple
from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce
from models.helpers import init_embeddings
import numpy as np
from util.common import is_true, is_false
from util.evaluation import evaluate
@@ -20,8 +18,9 @@ class RecurrentModel(pl.LightningModule):
    """
    def __init__(self, lPretrained, langs, output_size, hidden_size, lVocab_size, learnable_length,
                 drop_embedding_range, drop_embedding_prop, lMuse_debug=None, multilingual_index_debug=None):
                 drop_embedding_range, drop_embedding_prop, gpus=None):
        super().__init__()
        self.gpus = gpus
        self.langs = langs
        self.lVocab_size = lVocab_size
        self.learnable_length = learnable_length
@@ -33,7 +32,7 @@ class RecurrentModel(pl.LightningModule):
        self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
        self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
        self.accuracy = Accuracy()
        self.customMetrics = CustomMetrics(num_classes=output_size, multilabel=True, average='micro')
        self.customMetrics = CustomF1(num_classes=output_size, device=self.gpus)
        self.lPretrained_embeddings = nn.ModuleDict()
        self.lLearnable_embeddings = nn.ModuleDict()
@@ -42,10 +41,6 @@ class RecurrentModel(pl.LightningModule):
        self.n_directions = 1
        self.dropout = nn.Dropout(0.6)
        # TODO: debug setting
        self.lMuse = lMuse_debug
        self.multilingual_index_debug = multilingual_index_debug
        lstm_out = 256
        ff1 = 512
        ff2 = 256
@@ -111,7 +106,7 @@ class RecurrentModel(pl.LightningModule):
        custom = self.customMetrics(predictions, ly)
        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('custom', custom, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('custom', custom, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': loss}

    def validation_step(self, val_batch, batch_idx):
@@ -139,7 +134,6 @@ class RecurrentModel(pl.LightningModule):
        accuracy = self.accuracy(predictions, ly)
        self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return
        # return {'pred': predictions, 'target': ly}

    def embed(self, X, lang):
        input_list = []
@@ -166,98 +160,56 @@ class RecurrentModel(pl.LightningModule):
        return [optimizer], [scheduler]
class CustomMetrics(Metric):
    def __init__(
            self,
            num_classes: int,
            beta: float = 1.0,
            threshold: float = 0.5,
            average: str = "micro",
            multilabel: bool = False,
            compute_on_step: bool = True,
            dist_sync_on_step: bool = False,
            process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group,
        )
        self.num_classes = num_classes
        self.beta = beta
        self.threshold = threshold
        self.average = average
        self.multilabel = multilabel
        allowed_average = ("micro", "macro", "weighted", None)
        if self.average not in allowed_average:
            raise ValueError('Argument `average` expected to be one of the following:'
                             f' {allowed_average} but got {self.average}')
        self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        true_positives, predicted_positives, actual_positives = _fbeta_update(
            preds, target, self.num_classes, self.threshold, self.multilabel
        )
        self.true_positives += true_positives
        self.predicted_positives += predicted_positives
        self.actual_positives += actual_positives

    def compute(self):
        """
        Computes metrics over state.
        """
        return _fbeta_compute(self.true_positives, self.predicted_positives,
                              self.actual_positives, self.beta, self.average)

def _fbeta_update(
        preds: torch.Tensor,
        target: torch.Tensor,
        num_classes: int,
        threshold: float = 0.5,
        multilabel: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    preds, target = _input_format_classification_one_hot(
        num_classes, preds, target, threshold, multilabel
    )
    true_positives = torch.sum(preds * target, dim=1)
    predicted_positives = torch.sum(preds, dim=1)
    actual_positives = torch.sum(target, dim=1)
    return true_positives, predicted_positives, actual_positives

def _fbeta_compute(
        true_positives: torch.Tensor,
        predicted_positives: torch.Tensor,
        actual_positives: torch.Tensor,
        beta: float = 1.0,
        average: str = "micro"
) -> torch.Tensor:
    if average == "micro":
        precision = true_positives.sum().float() / predicted_positives.sum()
        recall = true_positives.sum().float() / actual_positives.sum()
    else:
        precision = true_positives.float() / predicted_positives
        recall = true_positives.float() / actual_positives
    num = (1 + beta ** 2) * precision * recall
    denom = beta ** 2 * precision + recall
    new_num = 2 * true_positives
    new_fp = predicted_positives - true_positives
    new_fn = actual_positives - true_positives
    new_den = 2 * true_positives + new_fp + new_fn
    if new_den.sum() == 0:
        # what is the correct return type? TODO
        return class_reduce(new_num, new_den, weights=actual_positives, class_reduction=average)
    return class_reduce(num, denom, weights=actual_positives, class_reduction=average)

class CustomF1(Metric):
    def __init__(self, num_classes, device, average='micro'):
        """
        Custom F1 metric.
        scikit-learn provides a full set of evaluation metrics, but it treats special cases differently.
        I.e., when the number of true positives, false positives, and false negatives amounts to 0, all
        the affected metrics (precision, recall, and thus F1) output 0 in scikit-learn.
        We adhere to the common practice of outputting 1 in this case, since the classifier has correctly
        classified all examples as negatives.
        :param num_classes: number of target classes.
        :param device: whether to run the computation on the GPU ('cuda') or on the CPU.
        :param average: averaging method; currently only 'micro' is implemented.
        """
        super().__init__()
        self.num_classes = num_classes
        self.average = average
        self.device = 'cuda' if device else 'cpu'
        self.add_state('true_positive', default=torch.zeros(self.num_classes))
        self.add_state('true_negative', default=torch.zeros(self.num_classes))
        self.add_state('false_positive', default=torch.zeros(self.num_classes))
        self.add_state('false_negative', default=torch.zeros(self.num_classes))

    def update(self, preds, target):
        true_positive, true_negative, false_positive, false_negative = self._update(preds, target)
        self.true_positive += true_positive
        self.true_negative += true_negative
        self.false_positive += false_positive
        self.false_negative += false_negative

    def _update(self, pred, target):
        assert pred.shape == target.shape
        # preparing preds and targets for counting: binary {0, 1} masks
        true_pred = is_true(pred, self.device)
        false_pred = is_false(pred, self.device)
        true_target = is_true(target, self.device)
        false_target = is_false(target, self.device)
        # per-class counts, summed over the batch dimension
        tp = torch.sum(true_pred * true_target, dim=0)
        tn = torch.sum(false_pred * false_target, dim=0)
        fp = torch.sum(true_pred * false_target, dim=0)
        fn = torch.sum(false_pred * true_target, dim=0)
        return tp, tn, fp, fn

    def compute(self):
        if self.average == 'micro':
            num = 2.0 * self.true_positive.sum()
            den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum()
            if den > 0:
                return (num / den).to(self.device)
            # no positives anywhere: by the convention above, this counts as a perfect score
            return torch.FloatTensor([1.]).to(self.device)
        if self.average == 'macro':
            raise NotImplementedError
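To make the special case concrete, here is a minimal usage sketch of CustomF1 on an all-negative batch (CPU; the shapes and values are made up, and calling metric(preds, target) relies on the pytorch_lightning.metrics Metric.forward of this era, which updates and computes in one step):

import torch
from sklearn.metrics import f1_score

metric = CustomF1(num_classes=3, device=None)  # device=None -> runs on CPU
preds = torch.zeros(4, 3)    # the classifier predicts every label as negative
target = torch.zeros(4, 3)   # and every label is indeed negative
print(metric(preds, target))  # tensor([1.]): all-correct negatives count as a perfect score
print(f1_score(target.numpy(), preds.numpy(), average='micro', zero_division=0))  # 0.0 in scikit-learn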

View File

@@ -327,3 +327,12 @@ def index(data, vocab, known_words, analyzer, unk_index, out_of_vocabulary):
        # pbar.set_description(f'[unk = {unk_count}/{knw_count}={(100.*unk_count/knw_count):.2f}%]'
        #                      f'[out = {out_count}/{knw_count}={(100.*out_count/knw_count):.2f}%]')
    return indexes

def is_true(tensor, device):
    # float mask of the positive entries: 1.0 where tensor == 1, 0.0 elsewhere
    return torch.where(tensor == 1, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))

def is_false(tensor, device):
    # float mask of the negative entries: 1.0 where tensor == 0, 0.0 elsewhere
    return torch.where(tensor == 0, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))
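A quick behavioural sketch of these helpers (CPU; the tensor values are illustrative only). They turn a binary tensor into a float mask of its positive, respectively negative, entries, which is what CustomF1._update multiplies to count tp/tn/fp/fn:

import torch

t = torch.tensor([[1., 0., 1.],
                  [0., 0., 1.]])
print(is_true(t, 'cpu'))   # tensor([[1., 0., 1.], [0., 0., 1.]])
print(is_false(t, 'cpu'))  # tensor([[0., 1., 0.], [1., 1., 0.]])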

View File

@@ -22,7 +22,7 @@ from models.pl_gru import RecurrentModel
from models.pl_bert import BertModel
from models.lstm_class import RNNMultilingualClassifier
from pytorch_lightning import Trainer
from data.datamodule import GfunDataModule, BertDataModule
from data.datamodule import RecurrentDataModule, BertDataModule
from pytorch_lightning.loggers import TensorBoardLogger
import torch
@@ -144,7 +144,8 @@ class RecurrentGen(ViewGen):
    # TODO: save model https://forums.pytorchlightning.ai/t/how-to-save-hparams-when-not-provided-as-argument-apparently-assigning-to-hparams-is-not-recomended/339/5
    # Problem: we are passing lPretrained to init the RecurrentModel -> incredibly slow at saving (checkpoint).
    # If we do not save it, it is impossible to init RecurrentModel by calling RecurrentModel.load_from_checkpoint().
    def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, gpus=0, n_jobs=-1, stored_path=None):
    def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, nepochs=50,
                 gpus=0, n_jobs=-1, stored_path=None):
        """
        Generates document embeddings by means of Gated Recurrent Units. The model can be
        initialized with different (multilingual/aligned) word representations (e.g., MUSE, WCE, etc.).
@@ -162,6 +163,7 @@ class RecurrentGen(ViewGen):
        self.gpus = gpus
        self.n_jobs = n_jobs
        self.stored_path = stored_path
        self.nepochs = nepochs

        # EMBEDDINGS to be deployed
        self.pretrained = pretrained_embeddings
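Regarding the checkpointing TODO above: one common workaround is to keep the heavy lPretrained tensors out of the checkpoint entirely and re-supply them when restoring, since load_from_checkpoint forwards extra keyword arguments to __init__. A sketch under that assumption (the checkpoint path is hypothetical, the remaining constructor arguments are omitted, and this is not what the commit implements):

model = RecurrentModel.load_from_checkpoint(
    'recurrent.ckpt',         # hypothetical checkpoint path
    lPretrained=lPretrained,  # re-loaded from disk rather than from the checkpoint
)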
@@ -193,7 +195,8 @@ class RecurrentGen(ViewGen):
            lVocab_size=lvocab_size,
            learnable_length=learnable_length,
            drop_embedding_range=self.multilingualIndex.sup_range,
            drop_embedding_prop=0.5
            drop_embedding_prop=0.5,
            gpus=self.gpus
        )

    def fit(self, lX, ly):
@@ -204,8 +207,9 @@ class RecurrentGen(ViewGen):
        :param ly: language-indexed target labels.
        :return:
        """
        recurrentDataModule = GfunDataModule(self.multilingualIndex, batchsize=self.batch_size)
        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50, checkpoint_callback=False)
        recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size)
        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
                          checkpoint_callback=False)

        # vanilla_torch_model = torch.load(
        #     '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')