Implemented micro and macro K metrics in PyTorch Lightning (CPU and GPU)

andrea 2021-01-21 10:13:03 +01:00
parent 6ed7712979
commit 5ce1203942
3 changed files with 109 additions and 34 deletions


@@ -28,9 +28,9 @@ def main(args):
    # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
    # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
    # gFun = WordClassGen(n_jobs=N_JOBS)
-    # gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128,
-    #                     nepochs=100, gpus=args.gpus, n_jobs=N_JOBS)
-    gFun = BertGen(multilingualIndex, batch_size=4, nepochs=10, gpus=args.gpus, n_jobs=N_JOBS)
+    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128,
+                        nepochs=100, gpus=args.gpus, n_jobs=N_JOBS)
+    # gFun = BertGen(multilingualIndex, batch_size=4, nepochs=10, gpus=args.gpus, n_jobs=N_JOBS)

    gFun.fit(lX, ly)


@@ -6,9 +6,9 @@ from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from transformers import AdamW
import pytorch_lightning as pl
-from pytorch_lightning.metrics import F1, Accuracy
+from pytorch_lightning.metrics import Accuracy
from models.helpers import init_embeddings
-from util.pl_metrics import CustomF1
+from util.pl_metrics import CustomF1, CustomK
from util.evaluation import evaluate

# TODO: it should also be possible to compute metrics independently for each language!
@@ -33,12 +33,10 @@ class RecurrentModel(pl.LightningModule):
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.accuracy = Accuracy()
-        self.microF1_tr = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
-        self.macroF1_tr = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
-        self.microF1_va = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
-        self.macroF1_va = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
-        self.microF1_te = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
-        self.macroF1_te = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
+        self.microF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
+        self.microK = CustomK(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroK = CustomK(num_classes=output_size, average='macro', device=self.gpus)

        self.lPretrained_embeddings = nn.ModuleDict()
        self.lLearnable_embeddings = nn.ModuleDict()
@@ -110,12 +108,16 @@ class RecurrentModel(pl.LightningModule):
        # Squashing logits through Sigmoid in order to get confidence score
        predictions = torch.sigmoid(logits) > 0.5
        accuracy = self.accuracy(predictions, ly)
-        microF1 = self.microF1_tr(predictions, ly)
-        macroF1 = self.macroF1_tr(predictions, ly)
+        microF1 = self.microF1(predictions, ly)
+        macroF1 = self.macroF1(predictions, ly)
+        microK = self.microK(predictions, ly)
+        macroK = self.macroK(predictions, ly)
        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('train-macroF1', macroF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('train-microF1', microF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        self.log('train-macroK', macroK, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        self.log('train-microK', microK, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return {'loss': loss}

    def validation_step(self, val_batch, batch_idx):
@@ -128,12 +130,16 @@ class RecurrentModel(pl.LightningModule):
        loss = self.loss(logits, ly)
        predictions = torch.sigmoid(logits) > 0.5
        accuracy = self.accuracy(predictions, ly)
-        microF1 = self.microF1_va(predictions, ly)
-        macroF1 = self.macroF1_va(predictions, ly)
-        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        microF1 = self.microF1(predictions, ly)
+        macroF1 = self.macroF1(predictions, ly)
+        microK = self.microK(predictions, ly)
+        macroK = self.macroK(predictions, ly)
+        self.log('val-loss', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('val-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log('val-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val-microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('val-macroK', macroK, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('val-microK', microK, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': loss}

    def test_step(self, test_batch, batch_idx):
@@ -145,8 +151,8 @@ class RecurrentModel(pl.LightningModule):
        ly = torch.cat(_ly, dim=0)
        predictions = torch.sigmoid(logits) > 0.5
        accuracy = self.accuracy(predictions, ly)
-        microF1 = self.microF1_te(predictions, ly)
-        macroF1 = self.macroF1_te(predictions, ly)
+        microF1 = self.microF1(predictions, ly)
+        macroF1 = self.macroF1(predictions, ly)
        self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log('test-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log('test-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)


@@ -3,6 +3,21 @@ from pytorch_lightning.metrics import Metric
from util.common import is_false, is_true

+
+def _update(pred, target, device):
+    assert pred.shape == target.shape
+    # preparing preds and targets for count
+    true_pred = is_true(pred, device)
+    false_pred = is_false(pred, device)
+    true_target = is_true(target, device)
+    false_target = is_false(target, device)
+
+    tp = torch.sum(true_pred * true_target, dim=0)
+    tn = torch.sum(false_pred * false_target, dim=0)
+    fp = torch.sum(true_pred * false_target, dim=0)
+    fn = torch.sum(false_pred * target, dim=0)
+    return tp, tn, fp, fn
+

class CustomF1(Metric):
    def __init__(self, num_classes, device, average='micro'):
        """
@@ -26,27 +41,13 @@ class CustomF1(Metric):
        self.add_state('false_negative', default=torch.zeros(self.num_classes))

    def update(self, preds, target):
-        true_positive, true_negative, false_positive, false_negative = self._update(preds, target)
+        true_positive, true_negative, false_positive, false_negative = _update(preds, target, self.device)
        self.true_positive += true_positive
        self.true_negative += true_negative
        self.false_positive += false_positive
        self.false_negative += false_negative

-    def _update(self, pred, target):
-        assert pred.shape == target.shape
-        # preparing preds and targets for count
-        true_pred = is_true(pred, self.device)
-        false_pred = is_false(pred, self.device)
-        true_target = is_true(target, self.device)
-        false_target = is_false(target, self.device)
-
-        tp = torch.sum(true_pred * true_target, dim=0)
-        tn = torch.sum(false_pred * false_target, dim=0)
-        fp = torch.sum(true_pred * false_target, dim=0)
-        fn = torch.sum(false_pred * target, dim=0)
-        return tp, tn, fp, fn
-
    def compute(self):
        if self.average == 'micro':
            num = 2.0 * self.true_positive.sum()
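As a quick sanity check of the counting logic now shared by CustomF1 and CustomK through the module-level _update(), here is a small self-contained sketch. The is_true/is_false helpers live in util.common and are not part of this diff, so the stand-ins below only assume they return float 0/1 masks on the requested device.

import torch

# Hypothetical stand-ins for util.common.is_true / is_false (assumption:
# they return float masks of the same shape as the input, moved to `device`).
def is_true(t, device):
    return (t == 1).float().to(device)

def is_false(t, device):
    return (t == 0).float().to(device)

# Toy multi-label batch: 3 documents, 2 classes.
preds   = torch.tensor([[1, 0], [1, 1], [0, 0]])
targets = torch.tensor([[1, 0], [0, 1], [1, 0]])

device = 'cpu'
true_pred, false_pred = is_true(preds, device), is_false(preds, device)
true_target, false_target = is_true(targets, device), is_false(targets, device)

# Per-class counts, summed over the batch dimension, as in _update();
# the committed code uses false_pred * target for fn, which is equivalent
# to false_pred * true_target when the targets are 0/1.
tp = torch.sum(true_pred * true_target, dim=0)    # tensor([1., 1.])
tn = torch.sum(false_pred * false_target, dim=0)  # tensor([0., 2.])
fp = torch.sum(true_pred * false_target, dim=0)   # tensor([1., 0.])
fn = torch.sum(false_pred * true_target, dim=0)   # tensor([1., 0.])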
@@ -69,3 +70,71 @@ class CustomF1(Metric):
                class_specific.append(1.)
            average = torch.sum(torch.Tensor(class_specific))/self.num_classes
            return average.to(self.device)
+
+
+class CustomK(Metric):
+    def __init__(self, num_classes, device, average='micro'):
+        """
+        K metric. https://dl.acm.org/doi/10.1145/2808194.2809449
+        :param num_classes:
+        :param device:
+        :param average:
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.average = average
+        self.device = 'cuda' if device else 'cpu'
+        self.add_state('true_positive', default=torch.zeros(self.num_classes))
+        self.add_state('true_negative', default=torch.zeros(self.num_classes))
+        self.add_state('false_positive', default=torch.zeros(self.num_classes))
+        self.add_state('false_negative', default=torch.zeros(self.num_classes))
+
+    def update(self, preds, target):
+        true_positive, true_negative, false_positive, false_negative = _update(preds, target, self.device)
+        self.true_positive += true_positive
+        self.true_negative += true_negative
+        self.false_positive += false_positive
+        self.false_negative += false_negative
+
+    def compute(self):
+        if self.average == 'micro':
+            specificity, recall = 0., 0.
+            absolute_negatives = self.true_negative.sum() + self.false_positive.sum()
+            if absolute_negatives != 0:
+                specificity = self.true_negative.sum()/absolute_negatives  # Todo check if it is float
+            absolute_positives = self.true_positive.sum() + self.false_negative.sum()
+            if absolute_positives != 0:
+                recall = self.true_positive.sum()/absolute_positives  # Todo check if it is float
+
+            if absolute_positives == 0:
+                return 2. * specificity - 1
+            elif absolute_negatives == 0:
+                return 2. * recall - 1
+            else:
+                return specificity + recall - 1
+
+        if self.average == 'macro':
+            class_specific = []
+            for i in range(self.num_classes):
+                class_tp = self.true_positive[i]
+                class_tn = self.true_negative[i]
+                class_fp = self.false_positive[i]
+                class_fn = self.false_negative[i]
+                specificity, recall = 0., 0.
+                absolute_negatives = class_tn + class_fp
+                if absolute_negatives != 0:
+                    specificity = class_tn / absolute_negatives  # Todo check if it is float
+                absolute_positives = class_tp + class_fn
+                if absolute_positives != 0:
+                    recall = class_tp / absolute_positives  # Todo check if it is float
+
+                if absolute_positives == 0:
+                    class_specific.append(2. * specificity - 1)
+                elif absolute_negatives == 0:
+                    class_specific.append(2. * recall - 1)
+                else:
+                    class_specific.append(specificity + recall - 1)
+            average = torch.sum(torch.Tensor(class_specific)) / self.num_classes
+            return average.to(self.device)
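For reference, the quantity that compute() above evaluates is the K measure from the cited paper, built from specificity = TN / (TN + FP) and recall = TP / (TP + FN); the cases below only restate what the code does:

K =
\begin{cases}
\text{specificity} + \text{recall} - 1 & \text{if } TP + FN > 0 \text{ and } TN + FP > 0 \\
2 \cdot \text{specificity} - 1 & \text{if } TP + FN = 0 \\
2 \cdot \text{recall} - 1 & \text{if } TN + FP = 0
\end{cases}

With average='micro' the counts are pooled over all classes before applying the formula; with average='macro' K is computed per class and the per-class values are averaged, mirroring CustomF1.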