Implemented custom micro and macro F1 metrics in PyTorch Lightning (CPU and GPU)
parent 7c73aa2149
commit a60e2cfc09
@@ -28,8 +28,8 @@ def main(args):
     # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
     # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
     # gFun = WordClassGen(n_jobs=N_JOBS)
-    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, nepochs=100,
-                        gpus=args.gpus, n_jobs=N_JOBS)
+    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128,
+                        nepochs=100, gpus=args.gpus, n_jobs=N_JOBS)
     # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
 
     gFun.fit(lX, ly)
@@ -5,10 +5,10 @@ from transformers import AdamW
 import torch.nn.functional as F
 from torch.autograd import Variable
 import pytorch_lightning as pl
-from pytorch_lightning.metrics import Metric, F1, Accuracy
+from pytorch_lightning.metrics import F1, Accuracy
 from torch.optim.lr_scheduler import StepLR
 from models.helpers import init_embeddings
-from util.common import is_true, is_false
+from util.pl_metrics import CustomF1
 from util.evaluation import evaluate
 
 
@@ -29,11 +29,14 @@ class RecurrentModel(pl.LightningModule):
         self.drop_embedding_range = drop_embedding_range
         self.drop_embedding_prop = drop_embedding_prop
         self.loss = torch.nn.BCEWithLogitsLoss()
-        # self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
-        # self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
         self.accuracy = Accuracy()
-        self.customMicroF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
-        self.customMacroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
+        self.microF1_tr = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroF1_tr = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
+        self.microF1_va = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroF1_va = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
+        self.microF1_te = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroF1_te = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
 
         self.lPretrained_embeddings = nn.ModuleDict()
         self.lLearnable_embeddings = nn.ModuleDict()
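A note on the six metric objects introduced in this hunk: a Lightning metric accumulates its state (here the TP/TN/FP/FN counters) across batches until compute() is called, so a single shared instance would mix counts from training, validation, and test batches. The sketch below illustrates that accumulate-then-compute pattern with plain tensors; the RunningMicroF1 class and the toy batches are illustrative only, not part of the repository.

import torch

class RunningMicroF1:
    """Accumulates per-class confusion counts across batches, then computes micro F1."""
    def __init__(self, num_classes):
        self.tp = torch.zeros(num_classes)
        self.fp = torch.zeros(num_classes)
        self.fn = torch.zeros(num_classes)

    def update(self, preds, target):
        preds, target = preds.float(), target.float()
        self.tp += (preds * target).sum(dim=0)          # predicted 1, target 1
        self.fp += (preds * (1 - target)).sum(dim=0)    # predicted 1, target 0
        self.fn += ((1 - preds) * target).sum(dim=0)    # predicted 0, target 1

    def compute(self):
        den = 2 * self.tp.sum() + self.fp.sum() + self.fn.sum()
        return 2 * self.tp.sum() / den if den > 0 else torch.tensor(1.0)

# One instance per split, mirroring microF1_tr / microF1_va / microF1_te above.
train_f1, val_f1 = RunningMicroF1(3), RunningMicroF1(3)
train_f1.update(torch.tensor([[1, 0, 1]]), torch.tensor([[1, 1, 0]]))
val_f1.update(torch.tensor([[0, 0, 0]]), torch.tensor([[0, 0, 0]]))
print(train_f1.compute())  # 0.5, computed from training counts only
print(val_f1.compute())    # 1.0: nothing to predict and nothing predicted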
@@ -104,12 +107,12 @@ class RecurrentModel(pl.LightningModule):
         # Squashing logits through Sigmoid in order to get confidence score
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
-        microF1 = self.customMicroF1(predictions, ly)
-        macroF1 = self.customMacroF1(predictions, ly)
+        microF1 = self.microF1_tr(predictions, ly)
+        macroF1 = self.macroF1_tr(predictions, ly)
         self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('train-macroF1', macroF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        self.log('train-microF1', microF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         return {'loss': loss}
 
     def validation_step(self, val_batch, batch_idx):
@@ -122,8 +125,12 @@ class RecurrentModel(pl.LightningModule):
         loss = self.loss(logits, ly)
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
+        microF1 = self.microF1_va(predictions, ly)
+        macroF1 = self.macroF1_va(predictions, ly)
         self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        self.log('val-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('val-microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         return {'loss': loss}
 
     def test_step(self, test_batch, batch_idx):
@@ -135,7 +142,11 @@ class RecurrentModel(pl.LightningModule):
         ly = torch.cat(_ly, dim=0)
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
+        microF1 = self.microF1_te(predictions, ly)
+        macroF1 = self.macroF1_te(predictions, ly)
         self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('test-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('test-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
         return
 
     def embed(self, X, lang):
@@ -161,71 +172,3 @@ class RecurrentModel(pl.LightningModule):
         optimizer = AdamW(self.parameters(), lr=1e-3)
         scheduler = StepLR(optimizer, step_size=25, gamma=0.5)
         return [optimizer], [scheduler]
-
-
-class CustomF1(Metric):
-    def __init__(self, num_classes, device, average='micro'):
-        """
-        Custom F1 metric.
-        Scikit-learn provides a full set of evaluation metrics, but it treats special cases differently:
-        when the number of true positives, false positives, and false negatives all amount to 0, the
-        affected metrics (precision, recall, and thus F1) output 0 in scikit-learn.
-        We adhere to the common practice of outputting 1 in this case, since the classifier has correctly
-        classified all examples as negatives.
-        :param num_classes: number of target classes
-        :param device: if truthy, run the metric on 'cuda', else on 'cpu'
-        :param average: 'micro' or 'macro'
-        """
-        super().__init__()
-        self.num_classes = num_classes
-        self.average = average
-        self.device = 'cuda' if device else 'cpu'
-        self.add_state('true_positive', default=torch.zeros(self.num_classes))
-        self.add_state('true_negative', default=torch.zeros(self.num_classes))
-        self.add_state('false_positive', default=torch.zeros(self.num_classes))
-        self.add_state('false_negative', default=torch.zeros(self.num_classes))
-
-    def update(self, preds, target):
-        true_positive, true_negative, false_positive, false_negative = self._update(preds, target)
-
-        self.true_positive += true_positive
-        self.true_negative += true_negative
-        self.false_positive += false_positive
-        self.false_negative += false_negative
-
-    def _update(self, pred, target):
-        assert pred.shape == target.shape
-        # preparing preds and targets for counting
-        true_pred = is_true(pred, self.device)
-        false_pred = is_false(pred, self.device)
-        true_target = is_true(target, self.device)
-        false_target = is_false(target, self.device)
-
-        tp = torch.sum(true_pred * true_target, dim=0)
-        tn = torch.sum(false_pred * false_target, dim=0)
-        fp = torch.sum(true_pred * false_target, dim=0)
-        fn = torch.sum(false_pred * target, dim=0)
-        return tp, tn, fp, fn
-
-    def compute(self):
-        if self.average == 'micro':
-            num = 2.0 * self.true_positive.sum()
-            den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum()
-            if den > 0:
-                return (num / den).to(self.device)
-            return torch.FloatTensor([1.]).to(self.device)
-        if self.average == 'macro':
-            class_specific = []
-            for i in range(self.num_classes):
-                class_tp = self.true_positive[i]
-                # class_tn = self.true_negative[i]
-                class_fp = self.false_positive[i]
-                class_fn = self.false_negative[i]
-                num = 2.0 * class_tp
-                den = 2.0 * class_tp + class_fp + class_fn
-                if den > 0:
-                    class_specific.append(num / den)
-                else:
-                    class_specific.append(1.)
-            average = torch.sum(torch.Tensor(class_specific)) / self.num_classes
-            return average.to(self.device)
New file (util/pl_metrics.py, inferred from the import "from util.pl_metrics import CustomF1" added above); the class removed in the previous hunk is moved here essentially unchanged:

@@ -0,0 +1,71 @@
+import torch
+from pytorch_lightning.metrics import Metric
+from util.common import is_false, is_true
+
+
+class CustomF1(Metric):
+    def __init__(self, num_classes, device, average='micro'):
+        """
+        Custom F1 metric.
+        Scikit-learn provides a full set of evaluation metrics, but it treats special cases differently:
+        when the number of true positives, false positives, and false negatives all amount to 0, the
+        affected metrics (precision, recall, and thus F1) output 0 in scikit-learn.
+        We adhere to the common practice of outputting 1 in this case, since the classifier has correctly
+        classified all examples as negatives.
+        :param num_classes: number of target classes
+        :param device: if truthy, run the metric on 'cuda', else on 'cpu'
+        :param average: 'micro' or 'macro'
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.average = average
+        self.device = 'cuda' if device else 'cpu'
+        self.add_state('true_positive', default=torch.zeros(self.num_classes))
+        self.add_state('true_negative', default=torch.zeros(self.num_classes))
+        self.add_state('false_positive', default=torch.zeros(self.num_classes))
+        self.add_state('false_negative', default=torch.zeros(self.num_classes))
+
+    def update(self, preds, target):
+        true_positive, true_negative, false_positive, false_negative = self._update(preds, target)
+
+        self.true_positive += true_positive
+        self.true_negative += true_negative
+        self.false_positive += false_positive
+        self.false_negative += false_negative
+
+    def _update(self, pred, target):
+        assert pred.shape == target.shape
+        # preparing preds and targets for counting
+        true_pred = is_true(pred, self.device)
+        false_pred = is_false(pred, self.device)
+        true_target = is_true(target, self.device)
+        false_target = is_false(target, self.device)
+
+        tp = torch.sum(true_pred * true_target, dim=0)
+        tn = torch.sum(false_pred * false_target, dim=0)
+        fp = torch.sum(true_pred * false_target, dim=0)
+        fn = torch.sum(false_pred * target, dim=0)
+        return tp, tn, fp, fn
+
+    def compute(self):
+        if self.average == 'micro':
+            num = 2.0 * self.true_positive.sum()
+            den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum()
+            if den > 0:
+                return (num / den).to(self.device)
+            return torch.FloatTensor([1.]).to(self.device)
+        if self.average == 'macro':
+            class_specific = []
+            for i in range(self.num_classes):
+                class_tp = self.true_positive[i]
+                class_tn = self.true_negative[i]
+                class_fp = self.false_positive[i]
+                class_fn = self.false_negative[i]
+                num = 2.0 * class_tp
+                den = 2.0 * class_tp + class_fp + class_fn
+                if den > 0:
+                    class_specific.append(num / den)
+                else:
+                    class_specific.append(1.)
+            average = torch.sum(torch.Tensor(class_specific)) / self.num_classes
+            return average.to(self.device)
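To make the micro/macro averaging and the zero-denominator convention in compute() concrete, here is a small hand check written with plain tensors (a standalone sketch for illustration; it does not import the repository's CustomF1). Two classes are scored over four multilabel examples, and class 1 never occurs and is never predicted:

import torch

preds  = torch.tensor([[1, 0], [0, 0], [1, 0], [0, 0]], dtype=torch.float)
target = torch.tensor([[1, 0], [1, 0], [0, 0], [0, 0]], dtype=torch.float)

tp = (preds * target).sum(dim=0)        # per-class true positives  -> [1., 0.]
fp = (preds * (1 - target)).sum(dim=0)  # per-class false positives -> [1., 0.]
fn = ((1 - preds) * target).sum(dim=0)  # per-class false negatives -> [1., 0.]

# micro: pool the counts over all classes before forming F1
micro = 2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum())  # -> 0.5

# macro: per-class F1, scoring an "empty" class (tp + fp + fn == 0) as 1.0,
# the convention documented in the CustomF1 docstring above
per_class = []
for c in range(2):
    den = 2 * tp[c] + fp[c] + fn[c]
    per_class.append(2 * tp[c] / den if den > 0 else torch.tensor(1.0))
macro = torch.stack(per_class).mean()  # -> (0.5 + 1.0) / 2 = 0.75

print(micro.item(), macro.item())

Scikit-learn's f1_score would instead score the empty class as 0 by default, giving a macro F1 of 0.25 on the same predictions, which is exactly the discrepancy the docstring describes. Note also that _update computes fn from false_pred * target; for 0/1 targets this is equivalent to false_pred * true_target, so the counts come out the same either way.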