Implemented custom micro-F1 metric in PyTorch Lightning (CPU and GPU)
parent 294d7c3be7
commit 8dbe48ff7a
@@ -88,7 +88,7 @@ class RecurrentDataset(Dataset):
         return index_list


-class GfunDataModule(pl.LightningDataModule):
+class RecurrentDataModule(pl.LightningDataModule):
     def __init__(self, multilingualIndex, batchsize=64):
         """
         Pytorch-lightning DataModule: https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
@@ -105,9 +105,18 @@ class GfunDataModule(pl.LightningDataModule):
     def setup(self, stage=None):
         if stage == 'fit' or stage is None:
             l_train_index, l_train_target = self.multilingualIndex.l_train()
+
+            # l_train_index = {l: train[:50] for l, train in l_train_index.items()}
+            # l_train_target = {l: target[:50] for l, target in l_train_target.items()}
+
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                      lPad_index=self.multilingualIndex.l_pad())

             l_val_index, l_val_target = self.multilingualIndex.l_val()
+
+            # l_val_index = {l: train[:50] for l, train in l_val_index.items()}
+            # l_val_target = {l: target[:50] for l, target in l_val_target.items()}
+
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
         if stage == 'test' or stage is None:
@@ -128,7 +137,7 @@ class GfunDataModule(pl.LightningDataModule):
                           collate_fn=self.test_dataset.collate_fn)


-class BertDataModule(GfunDataModule):
+class BertDataModule(RecurrentDataModule):
     def __init__(self, multilingualIndex, batchsize=64, max_len=512):
         super().__init__(multilingualIndex, batchsize)
         self.max_len = max_len
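Note on the hunks above: the following is a minimal, illustrative sketch (not part of the diff) of how a LightningDataModule such as RecurrentDataModule is typically consumed; the model, datamodule, and gpus arguments are placeholders.

import pytorch_lightning as pl

def train_with_datamodule(model: pl.LightningModule, dm: pl.LightningDataModule, gpus=0):
    # Trainer calls dm.setup('fit') and the train/val dataloader hooks internally
    trainer = pl.Trainer(gpus=gpus, max_epochs=1, checkpoint_callback=False)
    trainer.fit(model, datamodule=dm)
    # Trainer calls dm.setup('test') and dm.test_dataloader() before testing
    trainer.test(datamodule=dm)
    return model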
@@ -15,7 +15,7 @@ def main(args):
     _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
     EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
     data = MultilingualDataset.load(_DATASET)
-    data.set_view(languages=['it'], categories=[0,1])
+    data.set_view(languages=['it'], categories=[0, 1])
     lX, ly = data.training()
     lXte, lyte = data.test()
@@ -28,7 +28,8 @@ def main(args):
     # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
     # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
     # gFun = WordClassGen(n_jobs=N_JOBS)
-    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=512, gpus=args.gpus, n_jobs=N_JOBS)
+    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=5, nepochs=100,
+                        gpus=args.gpus, n_jobs=N_JOBS)
     # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
     gFun.fit(lX, ly)
@@ -5,12 +5,10 @@ from transformers import AdamW
 import torch.nn.functional as F
 from torch.autograd import Variable
 import pytorch_lightning as pl
-from pytorch_lightning.metrics import F1, Accuracy, Metric
+from pytorch_lightning.metrics import Metric, F1, Accuracy
 from torch.optim.lr_scheduler import StepLR
-from typing import Any, Optional, Tuple
-from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce
 from models.helpers import init_embeddings
-import numpy as np
+from util.common import is_true, is_false
 from util.evaluation import evaluate
@@ -20,8 +18,9 @@ class RecurrentModel(pl.LightningModule):
     """

     def __init__(self, lPretrained, langs, output_size, hidden_size, lVocab_size, learnable_length,
-                 drop_embedding_range, drop_embedding_prop, lMuse_debug=None, multilingual_index_debug=None):
+                 drop_embedding_range, drop_embedding_prop, gpus=None):
         super().__init__()
+        self.gpus = gpus
         self.langs = langs
         self.lVocab_size = lVocab_size
         self.learnable_length = learnable_length
@@ -33,7 +32,7 @@ class RecurrentModel(pl.LightningModule):
         self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
         self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
         self.accuracy = Accuracy()
-        self.customMetrics = CustomMetrics(num_classes=output_size, multilabel=True, average='micro')
+        self.customMetrics = CustomF1(num_classes=output_size, device=self.gpus)

         self.lPretrained_embeddings = nn.ModuleDict()
         self.lLearnable_embeddings = nn.ModuleDict()
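For context, a hedged illustration (not part of the diff) of how the metric objects built above behave in the pytorch-lightning version this code targets (pytorch_lightning.metrics later became torchmetrics): they are stateful, so the forward call returns the value for the current batch while compute() returns the value accumulated since the last reset().

import torch
from pytorch_lightning.metrics import Accuracy

accuracy = Accuracy()
batch_acc = accuracy(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))  # value for this batch only
epoch_acc = accuracy.compute()   # value accumulated over all batches seen so far
accuracy.reset()                 # clear the internal state between epochs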
@@ -42,10 +41,6 @@ class RecurrentModel(pl.LightningModule):
         self.n_directions = 1
         self.dropout = nn.Dropout(0.6)

-        # TODO: debug setting
-        self.lMuse = lMuse_debug
-        self.multilingual_index_debug = multilingual_index_debug
-
         lstm_out = 256
         ff1 = 512
         ff2 = 256
@@ -111,7 +106,7 @@ class RecurrentModel(pl.LightningModule):
         custom = self.customMetrics(predictions, ly)
         self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('custom', custom, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('custom', custom, on_step=True, on_epoch=True, prog_bar=True, logger=True)
         return {'loss': loss}

     def validation_step(self, val_batch, batch_idx):
@@ -139,7 +134,6 @@ class RecurrentModel(pl.LightningModule):
         accuracy = self.accuracy(predictions, ly)
         self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
         return
-        # return {'pred': predictions, 'target': ly}

     def embed(self, X, lang):
         input_list = []
@@ -166,98 +160,56 @@ class RecurrentModel(pl.LightningModule):
         return [optimizer], [scheduler]


-class CustomMetrics(Metric):
-    def __init__(
-            self,
-            num_classes: int,
-            beta: float = 1.0,
-            threshold: float = 0.5,
-            average: str = "micro",
-            multilabel: bool = False,
-            compute_on_step: bool = True,
-            dist_sync_on_step: bool = False,
-            process_group: Optional[Any] = None,
-    ):
-        super().__init__(
-            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group,
-        )
-
+class CustomF1(Metric):
+    def __init__(self, num_classes, device, average='micro'):
+        """
+        Custom F1 metric.
+        Scikit learn provides a full set of evaluation metrics, but they treat special cases differently.
+        I.e., when the number of true positives, false positives, and false negatives amount to 0, all
+        affected metrics (precision, recall, and thus f1) output 0 in Scikit learn.
+        We adhere to the common practice of outputting 1 in this case since the classifier has correctly
+        classified all examples as negatives.
+        :param num_classes:
+        :param device:
+        :param average:
+        """
+        super().__init__()
         self.num_classes = num_classes
-        self.beta = beta
-        self.threshold = threshold
         self.average = average
-        self.multilabel = multilabel
-
-        allowed_average = ("micro", "macro", "weighted", None)
-        if self.average not in allowed_average:
-            raise ValueError('Argument `average` expected to be one of the following:'
-                             f' {allowed_average} but got {self.average}')
-
-        self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-        self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-        self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-
-    def update(self, preds: torch.Tensor, target: torch.Tensor):
-        """
-        Update state with predictions and targets.
-
-        Args:
-            preds: Predictions from model
-            target: Ground truth values
-        """
-        true_positives, predicted_positives, actual_positives = _fbeta_update(
-            preds, target, self.num_classes, self.threshold, self.multilabel
-        )
-
-        self.true_positives += true_positives
-        self.predicted_positives += predicted_positives
-        self.actual_positives += actual_positives
+        self.device = 'cuda' if device else 'cpu'
+        self.add_state('true_positive', default=torch.zeros(self.num_classes))
+        self.add_state('true_negative', default=torch.zeros(self.num_classes))
+        self.add_state('false_positive', default=torch.zeros(self.num_classes))
+        self.add_state('false_negative', default=torch.zeros(self.num_classes))
+
+    def update(self, preds, target):
+        true_positive, true_negative, false_positive, false_negative = self._update(preds, target)
+        self.true_positive += true_positive
+        self.true_negative += true_negative
+        self.false_positive += false_positive
+        self.false_negative += false_negative
+
+    def _update(self, pred, target):
+        assert pred.shape == target.shape
+        # preparing preds and targets for count
+        true_pred = is_true(pred, self.device)
+        false_pred = is_false(pred, self.device)
+        true_target = is_true(target, self.device)
+        false_target = is_false(target, self.device)
+
+        tp = torch.sum(true_pred * true_target, dim=0)
+        tn = torch.sum(false_pred * false_target, dim=0)
+        fp = torch.sum(true_pred * false_target, dim=0)
+        fn = torch.sum(false_pred * target, dim=0)
+        return tp, tn, fp, fn

     def compute(self):
-        """
-        Computes metrics over state.
-        """
-        return _fbeta_compute(self.true_positives, self.predicted_positives,
-                              self.actual_positives, self.beta, self.average)
-
-
-def _fbeta_update(
-        preds: torch.Tensor,
-        target: torch.Tensor,
-        num_classes: int,
-        threshold: float = 0.5,
-        multilabel: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    preds, target = _input_format_classification_one_hot(
-        num_classes, preds, target, threshold, multilabel
-    )
-    true_positives = torch.sum(preds * target, dim=1)
-    predicted_positives = torch.sum(preds, dim=1)
-    actual_positives = torch.sum(target, dim=1)
-    return true_positives, predicted_positives, actual_positives
-
-
-def _fbeta_compute(
-        true_positives: torch.Tensor,
-        predicted_positives: torch.Tensor,
-        actual_positives: torch.Tensor,
-        beta: float = 1.0,
-        average: str = "micro"
-) -> torch.Tensor:
-    if average == "micro":
-        precision = true_positives.sum().float() / predicted_positives.sum()
-        recall = true_positives.sum().float() / actual_positives.sum()
-    else:
-        precision = true_positives.float() / predicted_positives
-        recall = true_positives.float() / actual_positives
-
-    num = (1 + beta ** 2) * precision * recall
-    denom = beta ** 2 * precision + recall
-    new_num = 2 * true_positives
-    new_fp = predicted_positives - true_positives
-    new_fn = actual_positives - true_positives
-    new_den = 2 * true_positives + new_fp + new_fn
-    if new_den.sum() == 0:
-        # whats is the correct return type ? TODO
-        return class_reduce(new_num, new_den, weights=actual_positives, class_reduction=average)
-    return class_reduce(num, denom, weights=actual_positives, class_reduction=average)
+        if self.average == 'micro':
+            num = 2.0 * self.true_positive.sum()
+            den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum()
+            if den > 0:
+                return (num / den).to(self.device)
+            return torch.FloatTensor([1.]).to(self.device)
+        if self.average == 'macro':
+            raise NotImplementedError
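A small sanity check (not part of the diff) of the convention described in the CustomF1 docstring: micro-F1 computed directly from TP/FP/FN counts, returning 1.0 when the denominator is zero. The helper name and the tensors below are illustrative only.

import torch

def micro_f1_from_counts(tp, fp, fn):
    den = 2.0 * tp.sum() + fp.sum() + fn.sum()
    return (2.0 * tp.sum() / den).item() if den > 0 else 1.0

preds   = torch.tensor([[1., 0., 1.], [0., 0., 1.]])
targets = torch.tensor([[1., 0., 0.], [0., 0., 1.]])
tp = ((preds == 1) & (targets == 1)).float().sum(dim=0)
fp = ((preds == 1) & (targets == 0)).float().sum(dim=0)
fn = ((preds == 0) & (targets == 1)).float().sum(dim=0)
print(micro_f1_from_counts(tp, fp, fn))              # 0.8
print(micro_f1_from_counts(tp * 0, fp * 0, fn * 0))  # 1.0 when tp = fp = fn = 0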
@@ -327,3 +327,12 @@ def index(data, vocab, known_words, analyzer, unk_index, out_of_vocabulary):
         # pbar.set_description(f'[unk = {unk_count}/{knw_count}={(100.*unk_count/knw_count):.2f}%]'
         #                      f'[out = {out_count}/{knw_count}={(100.*out_count/knw_count):.2f}%]')
     return indexes
+
+
+def is_true(tensor, device):
+    return torch.where(tensor == 1, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))
+
+
+def is_false(tensor, device):
+    return torch.where(tensor == 0, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))
+
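An illustrative check (not part of the diff) of what these two helpers compute on CPU: {0, 1} indicator tensors that can be multiplied and summed into per-class counts, as done in CustomF1._update. The example tensors are made up.

import torch

def is_true(tensor, device):
    return torch.where(tensor == 1, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))

def is_false(tensor, device):
    return torch.where(tensor == 0, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))

pred   = torch.tensor([[1., 0.], [1., 1.]])
target = torch.tensor([[1., 1.], [0., 1.]])
tp = torch.sum(is_true(pred, 'cpu') * is_true(target, 'cpu'), dim=0)   # per-class true positives
fp = torch.sum(is_true(pred, 'cpu') * is_false(target, 'cpu'), dim=0)  # per-class false positives
print(tp, fp)  # tensor([1., 1.]) tensor([1., 0.])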
@@ -22,7 +22,7 @@ from models.pl_gru import RecurrentModel
 from models.pl_bert import BertModel
 from models.lstm_class import RNNMultilingualClassifier
 from pytorch_lightning import Trainer
-from data.datamodule import GfunDataModule, BertDataModule
+from data.datamodule import RecurrentDataModule, BertDataModule
 from pytorch_lightning.loggers import TensorBoardLogger
 import torch
@@ -144,7 +144,8 @@ class RecurrentGen(ViewGen):
     # TODO: save model https://forums.pytorchlightning.ai/t/how-to-save-hparams-when-not-provided-as-argument-apparently-assigning-to-hparams-is-not-recomended/339/5
     # Problem: we are passing lPretrained to init the RecurrentModel -> incredible slow at saving (checkpoint).
     # if we do not save it is impossible to init RecurrentModel by calling RecurrentModel.load_from_checkpoint()
-    def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, gpus=0, n_jobs=-1, stored_path=None):
+    def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, nepochs=50,
+                 gpus=0, n_jobs=-1, stored_path=None):
         """
         generates document embedding by means of a Gated Recurrent Units. The model can be
         initialized with different (multilingual/aligned) word representations (e.g., MUSE, WCE, ecc.,).
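For readers unfamiliar with the approach described in the docstring above, here is a self-contained, illustrative sketch of what "document embedding by means of Gated Recurrent Units" means in the simplest case. This is not the project's RecurrentModel; all class names and sizes are made up.

import torch
import torch.nn as nn

class TinyGRUEncoder(nn.Module):
    def __init__(self, vocab_size=1000, emb_dim=64, hidden_size=128):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.gru = nn.GRU(emb_dim, hidden_size, batch_first=True)

    def forward(self, token_ids):
        _, h_n = self.gru(self.emb(token_ids))  # h_n: (1, batch, hidden_size), last hidden state
        return h_n.squeeze(0)                   # one fixed-size embedding per document

docs = torch.randint(1, 1000, (4, 25))          # 4 padded documents, 25 token ids each
print(TinyGRUEncoder()(docs).shape)             # torch.Size([4, 128])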
@@ -162,6 +163,7 @@ class RecurrentGen(ViewGen):
         self.gpus = gpus
         self.n_jobs = n_jobs
         self.stored_path = stored_path
+        self.nepochs = nepochs

         # EMBEDDINGS to be deployed
         self.pretrained = pretrained_embeddings
@@ -193,7 +195,8 @@ class RecurrentGen(ViewGen):
                                     lVocab_size=lvocab_size,
                                     learnable_length=learnable_length,
                                     drop_embedding_range=self.multilingualIndex.sup_range,
-                                    drop_embedding_prop=0.5
+                                    drop_embedding_prop=0.5,
+                                    gpus=self.gpus
                                     )

     def fit(self, lX, ly):
@@ -204,8 +207,9 @@ def main(args):
         :param ly:
         :return:
         """
-        recurrentDataModule = GfunDataModule(self.multilingualIndex, batchsize=self.batch_size)
-        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50, checkpoint_callback=False)
+        recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size)
+        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
+                          checkpoint_callback=False)

         # vanilla_torch_model = torch.load(
         #     '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')