Implemented custom micro and macro F1 in PyTorch Lightning (CPU and GPU) + various TODOs
commit 6ed7712979
parent d6eeabe6ab
@@ -105,18 +105,16 @@ class RecurrentDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            l_train_index, l_train_target = self.multilingualIndex.l_train()

            # Debug settings: reducing number of samples
            # l_train_index = {l: train[:50] for l, train in l_train_index.items()}
            # l_train_target = {l: target[:50] for l, target in l_train_target.items()}

            self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                     lPad_index=self.multilingualIndex.l_pad())

            l_val_index, l_val_target = self.multilingualIndex.l_val()

            # Debug settings: reducing number of samples
            # l_val_index = {l: train[:50] for l, train in l_val_index.items()}
            # l_val_target = {l: target[:50] for l, target in l_val_target.items()}

            self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                lPad_index=self.multilingualIndex.l_pad())
        if stage == 'test' or stage is None:
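Not shown in this hunk: the dataloader hooks through which the Trainer consumes these datasets. A minimal sketch of how they are usually defined on a LightningDataModule; the self.batchsize attribute name is an assumption taken from the batchsize kwarg passed in BertGen.fit further down, and the real hooks may add a custom collate_fn.

from torch.utils.data import DataLoader

def train_dataloader(self):
    # exposes the dataset built in setup(stage='fit'); batch size assumed to be stored in __init__
    return DataLoader(self.training_dataset, batch_size=self.batchsize)

def val_dataloader(self):
    return DataLoader(self.val_dataset, batch_size=self.batchsize)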
@@ -145,14 +143,21 @@ class BertDataModule(RecurrentDataModule):
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
            # Debug settings: reducing number of samples
            # l_train_raw = {l: train[:50] for l, train in l_train_raw.items()}
            # l_train_target = {l: target[:50] for l, target in l_train_target.items()}
            l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
            self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                     lPad_index=self.multilingualIndex.l_pad())

            l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
            # Debug settings: reducing number of samples
            # l_val_raw = {l: train[:50] for l, train in l_val_raw.items()}
            # l_val_target = {l: target[:50] for l, target in l_val_target.items()}
            l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
            self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                lPad_index=self.multilingualIndex.l_pad())
        # TODO

        if stage == 'test' or stage is None:
            l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
            l_test_index = self.tokenize(l_test_raw, max_len=self.max_len)
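The tokenize helper called above is not included in this excerpt; a rough sketch of what a per-language pass with the multilingual BERT tokenizer could look like (function body and defaults here are assumptions, not the repository's actual implementation):

from transformers import BertTokenizer

def tokenize(l_raw, max_len=512):
    # illustrative only: encode each language's documents to fixed-length input ids
    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    return {lang: tokenizer(docs, truncation=True, padding='max_length',
                            max_length=max_len)['input_ids']
            for lang, docs in l_raw.items()}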
@@ -15,7 +15,7 @@ def main(args):
    _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
    EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
    data = MultilingualDataset.load(_DATASET)
    # data.set_view(languages=['it'], categories=[0, 1])
    data.set_view(languages=['it'], categories=[0, 1])
    lX, ly = data.training()
    lXte, lyte = data.test()
@@ -28,9 +28,9 @@ def main(args):
    # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
    # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
    # gFun = WordClassGen(n_jobs=N_JOBS)
    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128,
                        nepochs=100, gpus=args.gpus, n_jobs=N_JOBS)
    # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
    # gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128,
    #                     nepochs=100, gpus=args.gpus, n_jobs=N_JOBS)
    gFun = BertGen(multilingualIndex, batch_size=4, nepochs=10, gpus=args.gpus, n_jobs=N_JOBS)

    gFun.fit(lX, ly)
@@ -1,15 +1,25 @@
import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR
from transformers import BertForSequenceClassification, BertTokenizer, AdamW, BertConfig
from pytorch_lightning.metrics import F1, Accuracy, Metric
from transformers import BertForSequenceClassification, AdamW
from pytorch_lightning.metrics import Accuracy
from util.pl_metrics import CustomF1


class BertModel(pl.LightningModule):

    def __init__(self, output_size, stored_path):
    def __init__(self, output_size, stored_path, gpus=None):
        super().__init__()
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.gpus = gpus
        self.accuracy = Accuracy()
        self.microF1_tr = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
        self.macroF1_tr = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
        self.microF1_va = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
        self.macroF1_va = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
        self.microF1_te = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
        self.macroF1_te = CustomF1(num_classes=output_size, average='macro', device=self.gpus)

        if stored_path:
            self.bert = BertForSequenceClassification.from_pretrained(stored_path,
                                                                      num_labels=output_size,
@@ -18,7 +28,6 @@ class BertModel(pl.LightningModule):
            self.bert = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased',
                                                                      num_labels=output_size,
                                                                      output_hidden_states=True)
        self.accuracy = Accuracy()
        self.save_hyperparameters()

    def forward(self, X):
@@ -31,11 +40,16 @@ class BertModel(pl.LightningModule):
        y = y.type(torch.cuda.FloatTensor)
        logits, _ = self.forward(X)
        loss = self.loss(logits, y)
        # Squashing logits through Sigmoid in order to get confidence score
        predictions = torch.sigmoid(logits) > 0.5
        accuracy = self.accuracy(predictions, y)
        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        microF1 = self.microF1_tr(predictions, y)
        macroF1 = self.macroF1_tr(predictions, y)
        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss
        self.log('train-macroF1', macroF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('train-microF1', microF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return {'loss': loss}

    def validation_step(self, val_batch, batch_idx):
        X, y, _, batch_langs = val_batch
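As a reminder of what the two averages measure: micro-F1 pools true/false positives and false negatives over all classes before computing a single F1, while macro-F1 computes F1 per class and averages the results, so rare classes weigh more. A quick sanity check with scikit-learn (illustrative, not part of the commit):

import numpy as np
from sklearn.metrics import f1_score

y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 0, 0]])
y_pred = np.array([[1, 0, 0], [0, 1, 0], [1, 1, 0]])
print(f1_score(y_true, y_pred, average='micro'))  # 0.75: pooled counts over all labels
print(f1_score(y_true, y_pred, average='macro'))  # ~0.56: unweighted mean of per-label F1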
@@ -45,9 +59,29 @@ class BertModel(pl.LightningModule):
        loss = self.loss(logits, y)
        predictions = torch.sigmoid(logits) > 0.5
        accuracy = self.accuracy(predictions, y)
        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        microF1 = self.microF1_va(predictions, y)
        macroF1 = self.macroF1_va(predictions, y)
        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return
        self.log('val-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val-microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': loss}

    # def test_step(self, test_batch, batch_idx):
    #     lX, ly = test_batch
    #     logits = self.forward(lX)
    #     _ly = []
    #     for lang in sorted(lX.keys()):
    #         _ly.append(ly[lang])
    #     ly = torch.cat(_ly, dim=0)
    #     predictions = torch.sigmoid(logits) > 0.5
    #     accuracy = self.accuracy(predictions, ly)
    #     microF1 = self.microF1_te(predictions, ly)
    #     macroF1 = self.macroF1_te(predictions, ly)
    #     self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
    #     self.log('test-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
    #     self.log('test-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
    #     return

    def configure_optimizers(self, lr=3e-5, weight_decay=0.01):
        no_decay = ['bias', 'LayerNorm.weight']
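The rest of configure_optimizers is cut off by this hunk. The no_decay list usually feeds the standard grouped-weight-decay pattern for AdamW; the continuation below is a hedged sketch of that pattern (the parameter grouping and the StepLR step size are assumptions, not necessarily what the repository does):

        # assumed continuation: exclude bias/LayerNorm weights from weight decay
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.bert.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in self.bert.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
        scheduler = StepLR(optimizer, step_size=25, gamma=0.1)  # step size is an assumption
        return [optimizer], [scheduler]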
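The CustomF1 class itself lives in util/pl_metrics.py and is not part of this excerpt. Below is a minimal sketch of a multilabel micro/macro F1 built on pytorch_lightning.metrics.Metric, assuming it accumulates per-class counts; the real implementation, including how it uses the device/gpus argument, may differ.

import torch
from pytorch_lightning.metrics import Metric

class CustomF1(Metric):
    # sketch only: accumulate per-class tp/fp/fn across steps and reduce to micro or macro F1
    def __init__(self, num_classes, average='micro', device=None):
        super().__init__()
        self.average = average
        self.target_device = device  # the real class presumably uses this to place its states
        self.add_state('tp', default=torch.zeros(num_classes), dist_reduce_fx='sum')
        self.add_state('fp', default=torch.zeros(num_classes), dist_reduce_fx='sum')
        self.add_state('fn', default=torch.zeros(num_classes), dist_reduce_fx='sum')

    def update(self, preds, target):
        preds, target = preds.float(), target.float()
        self.tp += (preds * target).sum(dim=0)
        self.fp += (preds * (1 - target)).sum(dim=0)
        self.fn += ((1 - preds) * target).sum(dim=0)

    def compute(self):
        eps = 1e-12
        if self.average == 'micro':
            tp, fp, fn = self.tp.sum(), self.fp.sum(), self.fn.sum()
            return 2 * tp / (2 * tp + fp + fn + eps)
        per_class = 2 * self.tp / (2 * self.tp + self.fp + self.fn + eps)  # macro
        return per_class.mean()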
@@ -1,18 +1,19 @@
# Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
import torch
from torch import nn
from transformers import AdamW
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from transformers import AdamW
import pytorch_lightning as pl
from pytorch_lightning.metrics import F1, Accuracy
from torch.optim.lr_scheduler import StepLR
from models.helpers import init_embeddings
from util.pl_metrics import CustomF1
from util.evaluation import evaluate

# TODO: it should also be possible to compute metrics independently for each language!


class RecurrentModel(pl.LightningModule):
    """
    Check out for logging insight https://www.learnopencv.com/tensorboard-with-pytorch-lightning/
@@ -171,7 +171,7 @@ class MultilingualIndex:

    def l_test_raw(self):
        print('TODO: implement MultilingualIndex method to return RAW test data!')
        return NotImplementedError
        return {l: index.test_raw for l, index in self.l_index.items()}

    def l_devel_index(self):
        return {l: index.devel_index for l, index in self.l_index.items()}
@@ -231,9 +231,10 @@ class RecurrentGen(ViewGen):

class BertGen(ViewGen):

    def __init__(self, multilingualIndex, batch_size=128, gpus=0, n_jobs=-1, stored_path=None):
    def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, stored_path=None):
        super().__init__()
        self.multilingualIndex = multilingualIndex
        self.nepochs = nepochs
        self.gpus = gpus
        self.batch_size = batch_size
        self.n_jobs = n_jobs
@@ -244,11 +245,12 @@ class BertGen(ViewGen):

    def _init_model(self):
        output_size = self.multilingualIndex.get_target_dim()
        return BertModel(output_size=output_size, stored_path=self.stored_path)
        return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)

    def fit(self, lX, ly):
        bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
        trainer = Trainer(default_root_dir='checkpoints/bert/', gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger)
        trainer = Trainer(default_root_dir='checkpoints/bert/', gradient_clip_val=1e-1, max_epochs=self.nepochs,
                          gpus=self.gpus, logger=self.logger, checkpoint_callback=False)
        trainer.fit(self.model, bertDataModule)
        # trainer.test(self.model, bertDataModule)
        pass