Implemented custom micro and macro F1 in pl (cpu and gpu) + various TODO

This commit is contained in:
andrea 2021-01-20 15:13:39 +01:00
parent d6eeabe6ab
commit 6ed7712979
6 changed files with 65 additions and 23 deletions

View File

@ -105,18 +105,16 @@ class RecurrentDataModule(pl.LightningDataModule):
def setup(self, stage=None): def setup(self, stage=None):
if stage == 'fit' or stage is None: if stage == 'fit' or stage is None:
l_train_index, l_train_target = self.multilingualIndex.l_train() l_train_index, l_train_target = self.multilingualIndex.l_train()
# Debug settings: reducing number of samples
# l_train_index = {l: train[:50] for l, train in l_train_index.items()} # l_train_index = {l: train[:50] for l, train in l_train_index.items()}
# l_train_target = {l: target[:50] for l, target in l_train_target.items()} # l_train_target = {l: target[:50] for l, target in l_train_target.items()}
self.training_dataset = RecurrentDataset(l_train_index, l_train_target, self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
lPad_index=self.multilingualIndex.l_pad()) lPad_index=self.multilingualIndex.l_pad())
l_val_index, l_val_target = self.multilingualIndex.l_val() l_val_index, l_val_target = self.multilingualIndex.l_val()
# Debug settings: reducing number of samples
# l_val_index = {l: train[:50] for l, train in l_val_index.items()} # l_val_index = {l: train[:50] for l, train in l_val_index.items()}
# l_val_target = {l: target[:50] for l, target in l_val_target.items()} # l_val_target = {l: target[:50] for l, target in l_val_target.items()}
self.val_dataset = RecurrentDataset(l_val_index, l_val_target, self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
lPad_index=self.multilingualIndex.l_pad()) lPad_index=self.multilingualIndex.l_pad())
if stage == 'test' or stage is None: if stage == 'test' or stage is None:
@ -145,14 +143,21 @@ class BertDataModule(RecurrentDataModule):
def setup(self, stage=None): def setup(self, stage=None):
if stage == 'fit' or stage is None: if stage == 'fit' or stage is None:
l_train_raw, l_train_target = self.multilingualIndex.l_train_raw() l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
# Debug settings: reducing number of samples
# l_train_raw = {l: train[:50] for l, train in l_train_raw.items()}
# l_train_target = {l: target[:50] for l, target in l_train_target.items()}
l_train_index = self.tokenize(l_train_raw, max_len=self.max_len) l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
self.training_dataset = RecurrentDataset(l_train_index, l_train_target, self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
lPad_index=self.multilingualIndex.l_pad()) lPad_index=self.multilingualIndex.l_pad())
l_val_raw, l_val_target = self.multilingualIndex.l_val_raw() l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
# Debug settings: reducing number of samples
# l_val_raw = {l: train[:50] for l, train in l_val_raw.items()}
# l_val_target = {l: target[:50] for l, target in l_val_target.items()}
l_val_index = self.tokenize(l_val_raw, max_len=self.max_len) l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
self.val_dataset = RecurrentDataset(l_val_index, l_val_target, self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
lPad_index=self.multilingualIndex.l_pad()) lPad_index=self.multilingualIndex.l_pad())
# TODO
if stage == 'test' or stage is None: if stage == 'test' or stage is None:
l_test_raw, l_test_target = self.multilingualIndex.l_test_raw() l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
l_test_index = self.tokenize(l_val_raw, max_len=self.max_len) l_test_index = self.tokenize(l_val_raw, max_len=self.max_len)

View File

@ -15,7 +15,7 @@ def main(args):
_DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle' _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings' EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
data = MultilingualDataset.load(_DATASET) data = MultilingualDataset.load(_DATASET)
# data.set_view(languages=['it'], categories=[0, 1]) data.set_view(languages=['it'], categories=[0, 1])
lX, ly = data.training() lX, ly = data.training()
lXte, lyte = data.test() lXte, lyte = data.test()
@ -28,9 +28,9 @@ def main(args):
# gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
# gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS) # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
# gFun = WordClassGen(n_jobs=N_JOBS) # gFun = WordClassGen(n_jobs=N_JOBS)
gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128, # gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128,
nepochs=100, gpus=args.gpus, n_jobs=N_JOBS) # nepochs=100, gpus=args.gpus, n_jobs=N_JOBS)
# gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS) gFun = BertGen(multilingualIndex, batch_size=4, nepochs=10, gpus=args.gpus, n_jobs=N_JOBS)
gFun.fit(lX, ly) gFun.fit(lX, ly)

View File

@ -1,15 +1,25 @@
import torch import torch
import pytorch_lightning as pl import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
from transformers import BertForSequenceClassification, BertTokenizer, AdamW, BertConfig from transformers import BertForSequenceClassification, AdamW
from pytorch_lightning.metrics import F1, Accuracy, Metric from pytorch_lightning.metrics import Accuracy
from util.pl_metrics import CustomF1
class BertModel(pl.LightningModule): class BertModel(pl.LightningModule):
def __init__(self, output_size, stored_path): def __init__(self, output_size, stored_path, gpus=None):
super().__init__() super().__init__()
self.loss = torch.nn.BCEWithLogitsLoss() self.loss = torch.nn.BCEWithLogitsLoss()
self.gpus = gpus
self.accuracy = Accuracy()
self.microF1_tr = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
self.macroF1_tr = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
self.microF1_va = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
self.macroF1_va = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
self.microF1_te = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
self.macroF1_te = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
if stored_path: if stored_path:
self.bert = BertForSequenceClassification.from_pretrained(stored_path, self.bert = BertForSequenceClassification.from_pretrained(stored_path,
num_labels=output_size, num_labels=output_size,
@ -18,7 +28,6 @@ class BertModel(pl.LightningModule):
self.bert = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', self.bert = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased',
num_labels=output_size, num_labels=output_size,
output_hidden_states=True) output_hidden_states=True)
self.accuracy = Accuracy()
self.save_hyperparameters() self.save_hyperparameters()
def forward(self, X): def forward(self, X):
@ -31,11 +40,16 @@ class BertModel(pl.LightningModule):
y = y.type(torch.cuda.FloatTensor) y = y.type(torch.cuda.FloatTensor)
logits, _ = self.forward(X) logits, _ = self.forward(X)
loss = self.loss(logits, y) loss = self.loss(logits, y)
# Squashing logits through Sigmoid in order to get confidence score
predictions = torch.sigmoid(logits) > 0.5 predictions = torch.sigmoid(logits) > 0.5
accuracy = self.accuracy(predictions, y) accuracy = self.accuracy(predictions, y)
self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) microF1 = self.microF1_tr(predictions, y)
macroF1 = self.macroF1_tr(predictions, y)
self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True) self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
return loss self.log('train-macroF1', macroF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
self.log('train-microF1', microF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
return {'loss': loss}
def validation_step(self, val_batch, batch_idx): def validation_step(self, val_batch, batch_idx):
X, y, _, batch_langs = val_batch X, y, _, batch_langs = val_batch
@ -45,9 +59,29 @@ class BertModel(pl.LightningModule):
loss = self.loss(logits, y) loss = self.loss(logits, y)
predictions = torch.sigmoid(logits) > 0.5 predictions = torch.sigmoid(logits) > 0.5
accuracy = self.accuracy(predictions, y) accuracy = self.accuracy(predictions, y)
self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) microF1 = self.microF1_va(predictions, y)
macroF1 = self.macroF1_va(predictions, y)
self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True) self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
return self.log('val-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val-microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
return {'loss': loss}
# def test_step(self, test_batch, batch_idx):
# lX, ly = test_batch
# logits = self.forward(lX)
# _ly = []
# for lang in sorted(lX.keys()):
# _ly.append(ly[lang])
# ly = torch.cat(_ly, dim=0)
# predictions = torch.sigmoid(logits) > 0.5
# accuracy = self.accuracy(predictions, ly)
# microF1 = self.microF1_te(predictions, ly)
# macroF1 = self.macroF1_te(predictions, ly)
# self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
# self.log('test-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
# self.log('test-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
# return
def configure_optimizers(self, lr=3e-5, weight_decay=0.01): def configure_optimizers(self, lr=3e-5, weight_decay=0.01):
no_decay = ['bias', 'LayerNorm.weight'] no_decay = ['bias', 'LayerNorm.weight']

View File

@ -1,18 +1,19 @@
# Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html # Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
import torch import torch
from torch import nn from torch import nn
from transformers import AdamW
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Variable from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from transformers import AdamW
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.metrics import F1, Accuracy from pytorch_lightning.metrics import F1, Accuracy
from torch.optim.lr_scheduler import StepLR
from models.helpers import init_embeddings from models.helpers import init_embeddings
from util.pl_metrics import CustomF1 from util.pl_metrics import CustomF1
from util.evaluation import evaluate from util.evaluation import evaluate
# TODO: it should also be possible to compute metrics independently for each language! # TODO: it should also be possible to compute metrics independently for each language!
class RecurrentModel(pl.LightningModule): class RecurrentModel(pl.LightningModule):
""" """
Check out for logging insight https://www.learnopencv.com/tensorboard-with-pytorch-lightning/ Check out for logging insight https://www.learnopencv.com/tensorboard-with-pytorch-lightning/

View File

@ -171,7 +171,7 @@ class MultilingualIndex:
def l_test_raw(self): def l_test_raw(self):
print('TODO: implement MultilingualIndex method to return RAW test data!') print('TODO: implement MultilingualIndex method to return RAW test data!')
return NotImplementedError return {l: index.test_raw for l, index in self.l_index.items()}
def l_devel_index(self): def l_devel_index(self):
return {l: index.devel_index for l, index in self.l_index.items()} return {l: index.devel_index for l, index in self.l_index.items()}

View File

@ -231,9 +231,10 @@ class RecurrentGen(ViewGen):
class BertGen(ViewGen): class BertGen(ViewGen):
def __init__(self, multilingualIndex, batch_size=128, gpus=0, n_jobs=-1, stored_path=None): def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, stored_path=None):
super().__init__() super().__init__()
self.multilingualIndex = multilingualIndex self.multilingualIndex = multilingualIndex
self.nepochs = nepochs
self.gpus = gpus self.gpus = gpus
self.batch_size = batch_size self.batch_size = batch_size
self.n_jobs = n_jobs self.n_jobs = n_jobs
@ -244,11 +245,12 @@ class BertGen(ViewGen):
def _init_model(self): def _init_model(self):
output_size = self.multilingualIndex.get_target_dim() output_size = self.multilingualIndex.get_target_dim()
return BertModel(output_size=output_size, stored_path=self.stored_path) return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
def fit(self, lX, ly): def fit(self, lX, ly):
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512) bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
trainer = Trainer(default_root_dir='checkpoints/bert/', gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger) trainer = Trainer(default_root_dir='checkpoints/bert/', gradient_clip_val=1e-1, max_epochs=self.nepochs,
gpus=self.gpus, logger=self.logger, checkpoint_callback=False)
trainer.fit(self.model, bertDataModule) trainer.fit(self.model, bertDataModule)
# trainer.test(self.model, bertDataModule) # trainer.test(self.model, bertDataModule)
pass pass