Implemented custom micro-F1 in PyTorch Lightning (CPU and GPU)
This commit is contained in:
parent 294d7c3be7
commit 8dbe48ff7a
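In essence, the commit replaces the Lightning-derived F1/_fbeta code with a CustomF1 metric that accumulates per-class TP/TN/FP/FN counts and computes micro-F1 as 2·TP / (2·TP + FP + FN), returning 1 when the denominator is 0 (every example correctly classified as negative, where scikit-learn would return 0). A minimal standalone sketch of those semantics in plain PyTorch (the function name is hypothetical, not part of the commit):

    import torch

    def micro_f1(pred, target):
        # pred, target: binary {0,1} tensors of shape (batch, num_classes)
        tp = (pred * target).sum()
        fp = (pred * (1 - target)).sum()
        fn = ((1 - pred) * target).sum()
        den = 2 * tp + fp + fn
        if den == 0:
            # no positives predicted or present: all examples were correctly
            # classified as negatives, so score 1 (scikit-learn outputs 0 here)
            return torch.tensor(1.)
        return 2 * tp / den

    # micro_f1(torch.zeros(4, 3), torch.zeros(4, 3)) -> tensor(1.)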
@@ -88,7 +88,7 @@ class RecurrentDataset(Dataset):
         return index_list
 
 
-class GfunDataModule(pl.LightningDataModule):
+class RecurrentDataModule(pl.LightningDataModule):
     def __init__(self, multilingualIndex, batchsize=64):
         """
         Pytorch-lightning DataModule: https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
@@ -105,9 +105,18 @@ class GfunDataModule(pl.LightningDataModule):
     def setup(self, stage=None):
         if stage == 'fit' or stage is None:
             l_train_index, l_train_target = self.multilingualIndex.l_train()
+
+            # l_train_index = {l: train[:50] for l, train in l_train_index.items()}
+            # l_train_target = {l: target[:50] for l, target in l_train_target.items()}
+
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                      lPad_index=self.multilingualIndex.l_pad())
+
             l_val_index, l_val_target = self.multilingualIndex.l_val()
+
+            # l_val_index = {l: train[:50] for l, train in l_val_index.items()}
+            # l_val_target = {l: target[:50] for l, target in l_val_target.items()}
+
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
         if stage == 'test' or stage is None:
@@ -128,7 +137,7 @@ class GfunDataModule(pl.LightningDataModule):
                           collate_fn=self.test_dataset.collate_fn)
 
 
-class BertDataModule(GfunDataModule):
+class BertDataModule(RecurrentDataModule):
     def __init__(self, multilingualIndex, batchsize=64, max_len=512):
         super().__init__(multilingualIndex, batchsize)
         self.max_len = max_len
@@ -15,7 +15,7 @@ def main(args):
     _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
     EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
     data = MultilingualDataset.load(_DATASET)
-    data.set_view(languages=['it'], categories=[0,1])
+    data.set_view(languages=['it'], categories=[0, 1])
     lX, ly = data.training()
     lXte, lyte = data.test()
 
@@ -28,7 +28,8 @@ def main(args):
     # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
     # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
     # gFun = WordClassGen(n_jobs=N_JOBS)
-    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=512, gpus=args.gpus, n_jobs=N_JOBS)
+    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=5, nepochs=100,
+                        gpus=args.gpus, n_jobs=N_JOBS)
     # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
 
     gFun.fit(lX, ly)
@@ -5,12 +5,10 @@ from transformers import AdamW
 import torch.nn.functional as F
 from torch.autograd import Variable
 import pytorch_lightning as pl
-from pytorch_lightning.metrics import F1, Accuracy, Metric
+from pytorch_lightning.metrics import Metric, F1, Accuracy
 from torch.optim.lr_scheduler import StepLR
-from typing import Any, Optional, Tuple
-from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce
 from models.helpers import init_embeddings
 import numpy as np
+from util.common import is_true, is_false
 from util.evaluation import evaluate
 
 
@@ -20,8 +18,9 @@ class RecurrentModel(pl.LightningModule):
     """
 
     def __init__(self, lPretrained, langs, output_size, hidden_size, lVocab_size, learnable_length,
-                 drop_embedding_range, drop_embedding_prop, lMuse_debug=None, multilingual_index_debug=None):
+                 drop_embedding_range, drop_embedding_prop, gpus=None):
         super().__init__()
+        self.gpus = gpus
         self.langs = langs
         self.lVocab_size = lVocab_size
         self.learnable_length = learnable_length
@@ -33,7 +32,7 @@ class RecurrentModel(pl.LightningModule):
         self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
         self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
         self.accuracy = Accuracy()
-        self.customMetrics = CustomMetrics(num_classes=output_size, multilabel=True, average='micro')
+        self.customMetrics = CustomF1(num_classes=output_size, device=self.gpus)
 
         self.lPretrained_embeddings = nn.ModuleDict()
         self.lLearnable_embeddings = nn.ModuleDict()
@@ -42,10 +41,6 @@ class RecurrentModel(pl.LightningModule):
         self.n_directions = 1
         self.dropout = nn.Dropout(0.6)
 
-        # TODO: debug setting
-        self.lMuse = lMuse_debug
-        self.multilingual_index_debug = multilingual_index_debug
-
         lstm_out = 256
         ff1 = 512
         ff2 = 256
@@ -111,7 +106,7 @@ class RecurrentModel(pl.LightningModule):
         custom = self.customMetrics(predictions, ly)
         self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('custom', custom, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('custom', custom, on_step=True, on_epoch=True, prog_bar=True, logger=True)
         return {'loss': loss}
 
     def validation_step(self, val_batch, batch_idx):
@@ -139,7 +134,6 @@ class RecurrentModel(pl.LightningModule):
         accuracy = self.accuracy(predictions, ly)
         self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
         return
-        # return {'pred': predictions, 'target': ly}
 
     def embed(self, X, lang):
         input_list = []
@@ -166,98 +160,56 @@ class RecurrentModel(pl.LightningModule):
         return [optimizer], [scheduler]
 
 
-class CustomMetrics(Metric):
-    def __init__(
-            self,
-            num_classes: int,
-            beta: float = 1.0,
-            threshold: float = 0.5,
-            average: str = "micro",
-            multilabel: bool = False,
-            compute_on_step: bool = True,
-            dist_sync_on_step: bool = False,
-            process_group: Optional[Any] = None,
-    ):
-        super().__init__(
-            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group,
-        )
-        self.num_classes = num_classes
-        self.beta = beta
-        self.threshold = threshold
-        self.average = average
-        self.multilabel = multilabel
-
-        allowed_average = ("micro", "macro", "weighted", None)
-        if self.average not in allowed_average:
-            raise ValueError('Argument `average` expected to be one of the following:'
-                             f' {allowed_average} but got {self.average}')
-
-        self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-        self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-        self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-
-    def update(self, preds: torch.Tensor, target: torch.Tensor):
-        """
-        Update state with predictions and targets.
-
-        Args:
-            preds: Predictions from model
-            target: Ground truth values
-        """
-        true_positives, predicted_positives, actual_positives = _fbeta_update(
-            preds, target, self.num_classes, self.threshold, self.multilabel
-        )
-
-        self.true_positives += true_positives
-        self.predicted_positives += predicted_positives
-        self.actual_positives += actual_positives
-
-    def compute(self):
-        """
-        Computes metrics over state.
-        """
-        return _fbeta_compute(self.true_positives, self.predicted_positives,
-                              self.actual_positives, self.beta, self.average)
-
-
-def _fbeta_update(
-        preds: torch.Tensor,
-        target: torch.Tensor,
-        num_classes: int,
-        threshold: float = 0.5,
-        multilabel: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    preds, target = _input_format_classification_one_hot(
-        num_classes, preds, target, threshold, multilabel
-    )
-    true_positives = torch.sum(preds * target, dim=1)
-    predicted_positives = torch.sum(preds, dim=1)
-    actual_positives = torch.sum(target, dim=1)
-    return true_positives, predicted_positives, actual_positives
-
-
-def _fbeta_compute(
-        true_positives: torch.Tensor,
-        predicted_positives: torch.Tensor,
-        actual_positives: torch.Tensor,
-        beta: float = 1.0,
-        average: str = "micro"
-) -> torch.Tensor:
-    if average == "micro":
-        precision = true_positives.sum().float() / predicted_positives.sum()
-        recall = true_positives.sum().float() / actual_positives.sum()
-    else:
-        precision = true_positives.float() / predicted_positives
-        recall = true_positives.float() / actual_positives
-
-    num = (1 + beta ** 2) * precision * recall
-    denom = beta ** 2 * precision + recall
-    new_num = 2 * true_positives
-    new_fp = predicted_positives - true_positives
-    new_fn = actual_positives - true_positives
-    new_den = 2 * true_positives + new_fp + new_fn
-    if new_den.sum() == 0:
-        # whats is the correct return type ? TODO
-        return class_reduce(new_num, new_den, weights=actual_positives, class_reduction=average)
-    return class_reduce(num, denom, weights=actual_positives, class_reduction=average)
+class CustomF1(Metric):
+    def __init__(self, num_classes, device, average='micro'):
+        """
+        Custom F1 metric.
+        scikit-learn provides a full set of evaluation metrics, but it treats special cases differently:
+        when the numbers of true positives, false positives, and false negatives all amount to 0, every
+        affected metric (precision, recall, and thus F1) is output as 0. We adhere instead to the common
+        practice of outputting 1 in this case, since the classifier has correctly classified all
+        examples as negatives.
+        :param num_classes: number of target classes
+        :param device: truthy to accumulate counts on 'cuda', falsy for 'cpu'
+        :param average: averaging scheme; 'micro' (default) is supported, 'macro' is not yet implemented
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.average = average
+        self.device = 'cuda' if device else 'cpu'
+        self.add_state('true_positive', default=torch.zeros(self.num_classes))
+        self.add_state('true_negative', default=torch.zeros(self.num_classes))
+        self.add_state('false_positive', default=torch.zeros(self.num_classes))
+        self.add_state('false_negative', default=torch.zeros(self.num_classes))
+
+    def update(self, preds, target):
+        true_positive, true_negative, false_positive, false_negative = self._update(preds, target)
+
+        self.true_positive += true_positive
+        self.true_negative += true_negative
+        self.false_positive += false_positive
+        self.false_negative += false_negative
+
+    def _update(self, pred, target):
+        assert pred.shape == target.shape
+        # preparing preds and targets for count: binary indicator masks on the target device
+        true_pred = is_true(pred, self.device)
+        false_pred = is_false(pred, self.device)
+        true_target = is_true(target, self.device)
+        false_target = is_false(target, self.device)
+
+        # per-class counts, summed over the batch dimension
+        tp = torch.sum(true_pred * true_target, dim=0)
+        tn = torch.sum(false_pred * false_target, dim=0)
+        fp = torch.sum(true_pred * false_target, dim=0)
+        fn = torch.sum(false_pred * true_target, dim=0)
+        return tp, tn, fp, fn
+
+    def compute(self):
+        if self.average == 'micro':
+            num = 2.0 * self.true_positive.sum()
+            den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum()
+            if den > 0:
+                return (num / den).to(self.device)
+            # no positives predicted or present: every example correctly classified as negative
+            return torch.FloatTensor([1.]).to(self.device)
+        if self.average == 'macro':
+            raise NotImplementedError
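A usage sketch of the CustomF1 defined above, assuming predictions are already thresholded to binary {0,1} multilabel tensors of shape [batch, num_classes] and that update()/compute() behave as in the pytorch_lightning.metrics Metric API of this era; the input values are illustrative, not from the commit:

    metric = CustomF1(num_classes=3, device=None)  # falsy device -> 'cpu'
    preds = torch.tensor([[1., 0., 0.],
                          [1., 1., 0.]])
    target = torch.tensor([[1., 0., 0.],
                           [0., 1., 0.]])
    metric.update(preds, target)
    print(metric.compute())  # tp=2, fp=1, fn=0 -> 2*2 / (2*2 + 1 + 0) = 0.8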
@@ -327,3 +327,12 @@ def index(data, vocab, known_words, analyzer, unk_index, out_of_vocabulary):
             # pbar.set_description(f'[unk = {unk_count}/{knw_count}={(100.*unk_count/knw_count):.2f}%]'
             #                      f'[out = {out_count}/{knw_count}={(100.*out_count/knw_count):.2f}%]')
     return indexes
+
+
+def is_true(tensor, device):
+    return torch.where(tensor == 1, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))
+
+
+def is_false(tensor, device):
+    return torch.where(tensor == 0, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))
+
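These helpers build float indicator masks on the requested device; elementwise products of such masks are what count TP/TN/FP/FN in CustomF1._update. For example (illustrative values):

    t = torch.tensor([1., 0., 1.])
    is_true(t, 'cpu')   # -> tensor([1., 0., 1.])
    is_false(t, 'cpu')  # -> tensor([0., 1., 0.])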
@@ -22,7 +22,7 @@ from models.pl_gru import RecurrentModel
 from models.pl_bert import BertModel
 from models.lstm_class import RNNMultilingualClassifier
 from pytorch_lightning import Trainer
-from data.datamodule import GfunDataModule, BertDataModule
+from data.datamodule import RecurrentDataModule, BertDataModule
 from pytorch_lightning.loggers import TensorBoardLogger
 import torch
@@ -144,7 +144,8 @@ class RecurrentGen(ViewGen):
     # TODO: save model https://forums.pytorchlightning.ai/t/how-to-save-hparams-when-not-provided-as-argument-apparently-assigning-to-hparams-is-not-recomended/339/5
     # Problem: we are passing lPretrained to init the RecurrentModel -> incredibly slow at saving (checkpoint).
     # If we do not save it, it is impossible to init RecurrentModel by calling RecurrentModel.load_from_checkpoint()
-    def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, gpus=0, n_jobs=-1, stored_path=None):
+    def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, nepochs=50,
+                 gpus=0, n_jobs=-1, stored_path=None):
         """
         Generates document embeddings by means of a Gated Recurrent Unit. The model can be
         initialized with different (multilingual/aligned) word representations (e.g., MUSE, WCE, etc.).
@@ -162,6 +163,7 @@ class RecurrentGen(ViewGen):
         self.gpus = gpus
         self.n_jobs = n_jobs
         self.stored_path = stored_path
+        self.nepochs = nepochs
 
         # EMBEDDINGS to be deployed
         self.pretrained = pretrained_embeddings
@@ -193,7 +195,8 @@ class RecurrentGen(ViewGen):
             lVocab_size=lvocab_size,
             learnable_length=learnable_length,
             drop_embedding_range=self.multilingualIndex.sup_range,
-            drop_embedding_prop=0.5
+            drop_embedding_prop=0.5,
+            gpus=self.gpus
         )
 
     def fit(self, lX, ly):
@@ -204,8 +207,9 @@ class RecurrentGen(ViewGen):
         :param ly:
         :return:
         """
-        recurrentDataModule = GfunDataModule(self.multilingualIndex, batchsize=self.batch_size)
-        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50, checkpoint_callback=False)
+        recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size)
+        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
+                          checkpoint_callback=False)
 
         # vanilla_torch_model = torch.load(
         #     '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')