Implemented custom micro and macro F1 in pl (cpu and gpu) + various TODO
parent d6eeabe6ab
commit 6ed7712979
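The CustomF1 metric referenced throughout this diff lives in util/pl_metrics.py, which is not part of the changeset shown here. As a rough sketch only, this is what a micro/macro F1 built on the pytorch_lightning.metrics.Metric API might look like; the class name, constructor arguments and call signature are taken from the call sites below, while everything else (state layout, epsilon, handling of the device argument) is an assumption, not the actual implementation.

import torch
from pytorch_lightning.metrics import Metric


class CustomF1(Metric):
    # Sketch only, not the code in util/pl_metrics.py: accumulate per-class
    # true positives, false positives and false negatives across batches,
    # then reduce them to micro- or macro-averaged F1 in compute().
    def __init__(self, num_classes, average='micro', device=None):
        super().__init__()
        self.average = average
        # `device` is accepted to mirror the call sites (device=self.gpus);
        # Lightning moves metric state to the module's device, so this sketch
        # does not use it further.
        self.add_state('tp', default=torch.zeros(num_classes), dist_reduce_fx='sum')
        self.add_state('fp', default=torch.zeros(num_classes), dist_reduce_fx='sum')
        self.add_state('fn', default=torch.zeros(num_classes), dist_reduce_fx='sum')

    def update(self, preds, target):
        preds, target = preds.float(), target.float()
        self.tp += (preds * target).sum(dim=0)
        self.fp += (preds * (1 - target)).sum(dim=0)
        self.fn += ((1 - preds) * target).sum(dim=0)

    def compute(self):
        if self.average == 'micro':
            tp, fp, fn = self.tp.sum(), self.fp.sum(), self.fn.sum()
            return 2 * tp / (2 * tp + fp + fn + 1e-10)
        # macro: per-class F1, then an unweighted mean over classes
        f1 = 2 * self.tp / (2 * self.tp + self.fp + self.fn + 1e-10)
        return f1.mean()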
@@ -105,18 +105,16 @@ class RecurrentDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            l_train_index, l_train_target = self.multilingualIndex.l_train()
+            # Debug settings: reducing number of samples
            # l_train_index = {l: train[:50] for l, train in l_train_index.items()}
            # l_train_target = {l: target[:50] for l, target in l_train_target.items()}

            self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                     lPad_index=self.multilingualIndex.l_pad())

            l_val_index, l_val_target = self.multilingualIndex.l_val()
+            # Debug settings: reducing number of samples
            # l_val_index = {l: train[:50] for l, train in l_val_index.items()}
            # l_val_target = {l: target[:50] for l, target in l_val_target.items()}

            self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                lPad_index=self.multilingualIndex.l_pad())
        if stage == 'test' or stage is None:
@@ -145,14 +143,21 @@ class BertDataModule(RecurrentDataModule):
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
+            # Debug settings: reducing number of samples
+            # l_train_raw = {l: train[:50] for l, train in l_train_raw.items()}
+            # l_train_target = {l: target[:50] for l, target in l_train_target.items()}
            l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
            self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                     lPad_index=self.multilingualIndex.l_pad())

            l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
+            # Debug settings: reducing number of samples
+            # l_val_raw = {l: train[:50] for l, train in l_val_raw.items()}
+            # l_val_target = {l: target[:50] for l, target in l_val_target.items()}
            l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
            self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                lPad_index=self.multilingualIndex.l_pad())
-            # TODO
        if stage == 'test' or stage is None:
            l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
            l_test_index = self.tokenize(l_val_raw, max_len=self.max_len)
@@ -15,7 +15,7 @@ def main(args):
    _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
    EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
    data = MultilingualDataset.load(_DATASET)
-    # data.set_view(languages=['it'], categories=[0, 1])
+    data.set_view(languages=['it'], categories=[0, 1])
    lX, ly = data.training()
    lXte, lyte = data.test()

@@ -28,9 +28,9 @@ def main(args):
    # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
    # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
    # gFun = WordClassGen(n_jobs=N_JOBS)
-    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128,
-                        nepochs=100, gpus=args.gpus, n_jobs=N_JOBS)
+    # gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128,
+    #                     nepochs=100, gpus=args.gpus, n_jobs=N_JOBS)
-    # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
+    gFun = BertGen(multilingualIndex, batch_size=4, nepochs=10, gpus=args.gpus, n_jobs=N_JOBS)

    gFun.fit(lX, ly)
@@ -1,15 +1,25 @@
import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR
-from transformers import BertForSequenceClassification, BertTokenizer, AdamW, BertConfig
+from transformers import BertForSequenceClassification, AdamW
-from pytorch_lightning.metrics import F1, Accuracy, Metric
+from pytorch_lightning.metrics import Accuracy
+from util.pl_metrics import CustomF1


class BertModel(pl.LightningModule):

-    def __init__(self, output_size, stored_path):
+    def __init__(self, output_size, stored_path, gpus=None):
        super().__init__()
        self.loss = torch.nn.BCEWithLogitsLoss()
+        self.gpus = gpus
+        self.accuracy = Accuracy()
+        self.microF1_tr = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroF1_tr = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
+        self.microF1_va = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroF1_va = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
+        self.microF1_te = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroF1_te = CustomF1(num_classes=output_size, average='macro', device=self.gpus)

        if stored_path:
            self.bert = BertForSequenceClassification.from_pretrained(stored_path,
                                                                      num_labels=output_size,
@@ -18,7 +28,6 @@ class BertModel(pl.LightningModule):
            self.bert = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased',
                                                                      num_labels=output_size,
                                                                      output_hidden_states=True)
-        self.accuracy = Accuracy()
        self.save_hyperparameters()

    def forward(self, X):
@@ -31,11 +40,16 @@ class BertModel(pl.LightningModule):
        y = y.type(torch.cuda.FloatTensor)
        logits, _ = self.forward(X)
        loss = self.loss(logits, y)
+        # Squashing logits through Sigmoid in order to get confidence score
        predictions = torch.sigmoid(logits) > 0.5
        accuracy = self.accuracy(predictions, y)
-        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        microF1 = self.microF1_tr(predictions, y)
+        macroF1 = self.macroF1_tr(predictions, y)
+        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        return loss
+        self.log('train-macroF1', macroF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        self.log('train-microF1', microF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        return {'loss': loss}

    def validation_step(self, val_batch, batch_idx):
        X, y, _, batch_langs = val_batch
@@ -45,9 +59,29 @@ class BertModel(pl.LightningModule):
        loss = self.loss(logits, y)
        predictions = torch.sigmoid(logits) > 0.5
        accuracy = self.accuracy(predictions, y)
-        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        microF1 = self.microF1_va(predictions, y)
+        macroF1 = self.macroF1_va(predictions, y)
+        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        return
+        self.log('val-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('val-microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        return {'loss': loss}

+    # def test_step(self, test_batch, batch_idx):
+    #     lX, ly = test_batch
+    #     logits = self.forward(lX)
+    #     _ly = []
+    #     for lang in sorted(lX.keys()):
+    #         _ly.append(ly[lang])
+    #     ly = torch.cat(_ly, dim=0)
+    #     predictions = torch.sigmoid(logits) > 0.5
+    #     accuracy = self.accuracy(predictions, ly)
+    #     microF1 = self.microF1_te(predictions, ly)
+    #     macroF1 = self.macroF1_te(predictions, ly)
+    #     self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+    #     self.log('test-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+    #     self.log('test-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+    #     return

    def configure_optimizers(self, lr=3e-5, weight_decay=0.01):
        no_decay = ['bias', 'LayerNorm.weight']
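A background note on the metric objects above (this describes the general pytorch_lightning.metrics Metric lifecycle, not code from this diff): every call such as self.microF1_tr(predictions, y) runs update(), so each instance keeps accumulating counts across batches until reset() is called, which is why training, validation and test hold their own _tr/_va/_te instances of CustomF1 rather than sharing one object across stages. A rough illustration (names and sizes are hypothetical):

metric = CustomF1(num_classes=10, average='micro')   # hypothetical sizes
for preds, target in batches:                        # one epoch
    batch_f1 = metric(preds, target)                 # forward() -> update() plus a per-batch value
epoch_f1 = metric.compute()                          # aggregated over all update() calls so far
metric.reset()                                       # clear accumulated state before the next epoch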
@@ -1,18 +1,19 @@
# Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
import torch
from torch import nn
-from transformers import AdamW
import torch.nn.functional as F
from torch.autograd import Variable
+from torch.optim.lr_scheduler import StepLR
+from transformers import AdamW
import pytorch_lightning as pl
from pytorch_lightning.metrics import F1, Accuracy
-from torch.optim.lr_scheduler import StepLR
from models.helpers import init_embeddings
from util.pl_metrics import CustomF1
from util.evaluation import evaluate

# TODO: it should also be possible to compute metrics independently for each language!


class RecurrentModel(pl.LightningModule):
    """
    Check out for logging insight https://www.learnopencv.com/tensorboard-with-pytorch-lightning/
@@ -171,7 +171,7 @@ class MultilingualIndex:

    def l_test_raw(self):
        print('TODO: implement MultilingualIndex method to return RAW test data!')
-        return NotImplementedError
+        return {l: index.test_raw for l, index in self.l_index.items()}

    def l_devel_index(self):
        return {l: index.devel_index for l, index in self.l_index.items()}
@@ -231,9 +231,10 @@ class RecurrentGen(ViewGen):

class BertGen(ViewGen):

-    def __init__(self, multilingualIndex, batch_size=128, gpus=0, n_jobs=-1, stored_path=None):
+    def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, stored_path=None):
        super().__init__()
        self.multilingualIndex = multilingualIndex
+        self.nepochs = nepochs
        self.gpus = gpus
        self.batch_size = batch_size
        self.n_jobs = n_jobs
@@ -244,11 +245,12 @@ class BertGen(ViewGen):

    def _init_model(self):
        output_size = self.multilingualIndex.get_target_dim()
-        return BertModel(output_size=output_size, stored_path=self.stored_path)
+        return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)

    def fit(self, lX, ly):
        bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
-        trainer = Trainer(default_root_dir='checkpoints/bert/', gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger)
+        trainer = Trainer(default_root_dir='checkpoints/bert/', gradient_clip_val=1e-1, max_epochs=self.nepochs,
+                          gpus=self.gpus, logger=self.logger, checkpoint_callback=False)
        trainer.fit(self.model, bertDataModule)
        # trainer.test(self.model, bertDataModule)
        pass