From a60e2cfc0952cdb99afcbcf00f196794536b0392 Mon Sep 17 00:00:00 2001
From: andrea
Date: Wed, 20 Jan 2021 14:55:09 +0100
Subject: [PATCH] Implemented custom micro and macro F1 in pl (cpu and gpu)

---
 refactor/main.py            |   4 +-
 refactor/models/pl_gru.py   | 105 +++++++++---------------------------
 refactor/util/pl_metrics.py |  71 ++++++++++++++++++++++++
 3 files changed, 97 insertions(+), 83 deletions(-)
 create mode 100644 refactor/util/pl_metrics.py

diff --git a/refactor/main.py b/refactor/main.py
index e44adcd..8791d6d 100644
--- a/refactor/main.py
+++ b/refactor/main.py
@@ -28,8 +28,8 @@ def main(args):
     # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
     # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
     # gFun = WordClassGen(n_jobs=N_JOBS)
-    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, nepochs=100,
-                        gpus=args.gpus, n_jobs=N_JOBS)
+    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128,
+                        nepochs=100, gpus=args.gpus, n_jobs=N_JOBS)
     # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
 
     gFun.fit(lX, ly)
diff --git a/refactor/models/pl_gru.py b/refactor/models/pl_gru.py
index 1ed8314..690843d 100644
--- a/refactor/models/pl_gru.py
+++ b/refactor/models/pl_gru.py
@@ -5,10 +5,10 @@ from transformers import AdamW
 import torch.nn.functional as F
 from torch.autograd import Variable
 import pytorch_lightning as pl
-from pytorch_lightning.metrics import Metric, F1, Accuracy
+from pytorch_lightning.metrics import F1, Accuracy
 from torch.optim.lr_scheduler import StepLR
 from models.helpers import init_embeddings
-from util.common import is_true, is_false
+from util.pl_metrics import CustomF1
 from util.evaluation import evaluate
 
 
@@ -29,11 +29,14 @@ class RecurrentModel(pl.LightningModule):
         self.drop_embedding_range = drop_embedding_range
         self.drop_embedding_prop = drop_embedding_prop
         self.loss = torch.nn.BCEWithLogitsLoss()
-        # self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
-        # self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
+
         self.accuracy = Accuracy()
-        self.customMicroF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
-        self.customMacroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
+        self.microF1_tr = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroF1_tr = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
+        self.microF1_va = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroF1_va = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
+        self.microF1_te = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.macroF1_te = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
 
         self.lPretrained_embeddings = nn.ModuleDict()
         self.lLearnable_embeddings = nn.ModuleDict()
@@ -104,12 +107,12 @@ class RecurrentModel(pl.LightningModule):
         # Squashing logits through Sigmoid in order to get confidence score
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
-        microF1 = self.customMicroF1(predictions, ly)
-        macroF1 = self.customMacroF1(predictions, ly)
-        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        microF1 = self.microF1_tr(predictions, ly)
+        macroF1 = self.macroF1_tr(predictions, ly)
+        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('train-macroF1', macroF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        self.log('train-microF1', microF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         return {'loss': loss}
 
     def validation_step(self, val_batch, batch_idx):
@@ -122,8 +125,12 @@ class RecurrentModel(pl.LightningModule):
         loss = self.loss(logits, ly)
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
-        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        microF1 = self.microF1_va(predictions, ly)
+        macroF1 = self.macroF1_va(predictions, ly)
+        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        self.log('val-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('val-microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         return {'loss': loss}
 
     def test_step(self, test_batch, batch_idx):
@@ -135,7 +142,11 @@ class RecurrentModel(pl.LightningModule):
         ly = torch.cat(_ly, dim=0)
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
-        self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        microF1 = self.microF1_te(predictions, ly)
+        macroF1 = self.macroF1_te(predictions, ly)
+        self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('test-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('test-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
         return
 
     def embed(self, X, lang):
@@ -161,71 +172,3 @@ class RecurrentModel(pl.LightningModule):
         optimizer = AdamW(self.parameters(), lr=1e-3)
         scheduler = StepLR(optimizer, step_size=25, gamma=0.5)
         return [optimizer], [scheduler]
-
-
-class CustomF1(Metric):
-    def __init__(self, num_classes, device, average='micro'):
-        """
-        Custom F1 metric.
-        Scikit learn provides a full set of evaluation metrics, but they treat special cases differently.
-        I.e., when the number of true positives, false positives, and false negatives amount to 0, all
-        affected metrics (precision, recall, and thus f1) output 0 in Scikit learn.
-        We adhere to the common practice of outputting 1 in this case since the classifier has correctly
-        classified all examples as negatives.
-        :param num_classes:
-        :param device:
-        :param average:
-        """
-        super().__init__()
-        self.num_classes = num_classes
-        self.average = average
-        self.device = 'cuda' if device else 'cpu'
-        self.add_state('true_positive', default=torch.zeros(self.num_classes))
-        self.add_state('true_negative', default=torch.zeros(self.num_classes))
-        self.add_state('false_positive', default=torch.zeros(self.num_classes))
-        self.add_state('false_negative', default=torch.zeros(self.num_classes))
-
-    def update(self, preds, target):
-        true_positive, true_negative, false_positive, false_negative = self._update(preds, target)
-
-        self.true_positive += true_positive
-        self.true_negative += true_negative
-        self.false_positive += false_positive
-        self.false_negative += false_negative
-
-    def _update(self, pred, target):
-        assert pred.shape == target.shape
-        # preparing preds and targets for count
-        true_pred = is_true(pred, self.device)
-        false_pred = is_false(pred, self.device)
-        true_target = is_true(target, self.device)
-        false_target = is_false(target, self.device)
-
-        tp = torch.sum(true_pred * true_target, dim=0)
-        tn = torch.sum(false_pred * false_target, dim=0)
-        fp = torch.sum(true_pred * false_target, dim=0)
-        fn = torch.sum(false_pred * target, dim=0)
-        return tp, tn, fp, fn
-
-    def compute(self):
-        if self.average == 'micro':
-            num = 2.0 * self.true_positive.sum()
-            den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum()
-            if den > 0:
-                return (num / den).to(self.device)
-            return torch.FloatTensor([1.]).to(self.device)
-        if self.average == 'macro':
-            class_specific = []
-            for i in range(self.num_classes):
-                class_tp = self.true_positive[i]
-                # class_tn = self.true_negative[i]
-                class_fp = self.false_positive[i]
-                class_fn = self.false_negative[i]
-                num = 2.0 * class_tp
-                den = 2.0 * class_tp + class_fp + class_fn
-                if den > 0:
-                    class_specific.append(num / den)
-                else:
-                    class_specific.append(1.)
-            average = torch.sum(torch.Tensor(class_specific))/self.num_classes
-            return average.to(self.device)
diff --git a/refactor/util/pl_metrics.py b/refactor/util/pl_metrics.py
new file mode 100644
index 0000000..a54bacb
--- /dev/null
+++ b/refactor/util/pl_metrics.py
@@ -0,0 +1,71 @@
+import torch
+from pytorch_lightning.metrics import Metric
+from util.common import is_false, is_true
+
+
+class CustomF1(Metric):
+    def __init__(self, num_classes, device, average='micro'):
+        """
+        Custom F1 metric.
+        scikit-learn provides a full set of evaluation metrics, but it treats one special case
+        differently: when true positives, false positives, and false negatives all amount to 0,
+        the affected metrics (precision, recall, and thus F1) evaluate to 0 in scikit-learn.
+        We adhere to the common practice of outputting 1 in this case, since the classifier has
+        correctly classified all examples as negatives.
+        :param num_classes: number of target classes.
+        :param device: truthy to accumulate counts on 'cuda', falsy for 'cpu'.
+        :param average: aggregation mode, either 'micro' or 'macro'.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.average = average
+        self.device = 'cuda' if device else 'cpu'
+        self.add_state('true_positive', default=torch.zeros(self.num_classes))
+        self.add_state('true_negative', default=torch.zeros(self.num_classes))
+        self.add_state('false_positive', default=torch.zeros(self.num_classes))
+        self.add_state('false_negative', default=torch.zeros(self.num_classes))
+
+    def update(self, preds, target):
+        true_positive, true_negative, false_positive, false_negative = self._update(preds, target)
+
+        self.true_positive += true_positive
+        self.true_negative += true_negative
+        self.false_positive += false_positive
+        self.false_negative += false_negative
+
+    def _update(self, pred, target):
+        assert pred.shape == target.shape
+        # preparing preds and targets for count
+        true_pred = is_true(pred, self.device)
+        false_pred = is_false(pred, self.device)
+        true_target = is_true(target, self.device)
+        false_target = is_false(target, self.device)
+
+        tp = torch.sum(true_pred * true_target, dim=0)
+        tn = torch.sum(false_pred * false_target, dim=0)
+        fp = torch.sum(true_pred * false_target, dim=0)
+        fn = torch.sum(false_pred * true_target, dim=0)  # true_target (not raw target) keeps all four counts consistent
+        return tp, tn, fp, fn
+
+    def compute(self):
+        if self.average == 'micro':
+            num = 2.0 * self.true_positive.sum()
+            den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum()
+            if den > 0:
+                return (num / den).to(self.device)
+            return torch.FloatTensor([1.]).to(self.device)
+        if self.average == 'macro':
+            class_specific = []
+            for i in range(self.num_classes):
+                class_tp = self.true_positive[i]
+                class_tn = self.true_negative[i]  # unused by F1, kept for completeness
+                class_fp = self.false_positive[i]
+                class_fn = self.false_negative[i]
+                num = 2.0 * class_tp
+                den = 2.0 * class_tp + class_fp + class_fn
+                if den > 0:
+                    class_specific.append(num / den)
+                else:
+                    class_specific.append(1.)
+            average = torch.sum(torch.Tensor(class_specific))/self.num_classes
+            return average.to(self.device)
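
Usage sketch (reviewer's note, not part of the commit): a minimal CPU check of
CustomF1 on a toy multilabel batch. It assumes `refactor/` is on the PYTHONPATH
and that util.common's is_true/is_false accept a (tensor, device) pair and
return the binary indicator tensors that _update expects; the expected values
in the comments are derived by hand from the counting rules in compute().

    import torch
    from util.pl_metrics import CustomF1

    # 3 samples x 4 classes, binary multilabel predictions and targets
    preds = torch.tensor([[1., 0., 1., 0.],
                          [0., 1., 1., 0.],
                          [1., 1., 0., 0.]])
    target = torch.tensor([[1., 0., 0., 0.],
                           [0., 1., 1., 0.],
                           [1., 0., 0., 0.]])

    micro = CustomF1(num_classes=4, device=False, average='micro')  # device=False -> 'cpu'
    macro = CustomF1(num_classes=4, device=False, average='macro')

    # calling the metric runs update() (accumulating tp/tn/fp/fn per class)
    # and then compute() on the accumulated state
    print(micro(preds, target))  # tp=4, fp=2, fn=0 -> 2*4 / (2*4 + 2 + 0) = 0.8
    print(macro(preds, target))  # per-class F1 [1, 2/3, 2/3, 1] -> mean ~0.833

Note that class 3 never occurs in preds or target (tp=fp=fn=0), so it scores 1
under this metric, where scikit-learn's default would score it 0; this is
exactly the special case the docstring describes, and it is why the macro score
here can exceed the scikit-learn value on sparse label sets.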