diff --git a/refactor/main.py b/refactor/main.py index eb48cb1..2c88f7d 100644 --- a/refactor/main.py +++ b/refactor/main.py @@ -28,9 +28,9 @@ def main(args): # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS) # gFun = WordClassGen(n_jobs=N_JOBS) - # gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128, - # nepochs=100, gpus=args.gpus, n_jobs=N_JOBS) - gFun = BertGen(multilingualIndex, batch_size=4, nepochs=10, gpus=args.gpus, n_jobs=N_JOBS) + gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=True, batch_size=128, + nepochs=100, gpus=args.gpus, n_jobs=N_JOBS) + # gFun = BertGen(multilingualIndex, batch_size=4, nepochs=10, gpus=args.gpus, n_jobs=N_JOBS) gFun.fit(lX, ly) diff --git a/refactor/models/pl_gru.py b/refactor/models/pl_gru.py index a0584f2..0fe5c6a 100644 --- a/refactor/models/pl_gru.py +++ b/refactor/models/pl_gru.py @@ -6,9 +6,9 @@ from torch.autograd import Variable from torch.optim.lr_scheduler import StepLR from transformers import AdamW import pytorch_lightning as pl -from pytorch_lightning.metrics import F1, Accuracy +from pytorch_lightning.metrics import Accuracy from models.helpers import init_embeddings -from util.pl_metrics import CustomF1 +from util.pl_metrics import CustomF1, CustomK from util.evaluation import evaluate # TODO: it should also be possible to compute metrics independently for each language! @@ -33,12 +33,10 @@ class RecurrentModel(pl.LightningModule): self.loss = torch.nn.BCEWithLogitsLoss() self.accuracy = Accuracy() - self.microF1_tr = CustomF1(num_classes=output_size, average='micro', device=self.gpus) - self.macroF1_tr = CustomF1(num_classes=output_size, average='macro', device=self.gpus) - self.microF1_va = CustomF1(num_classes=output_size, average='micro', device=self.gpus) - self.macroF1_va = CustomF1(num_classes=output_size, average='macro', device=self.gpus) - self.microF1_te = CustomF1(num_classes=output_size, average='micro', device=self.gpus) - self.macroF1_te = CustomF1(num_classes=output_size, average='macro', device=self.gpus) + self.microF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus) + self.macroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus) + self.microK = CustomK(num_classes=output_size, average='micro', device=self.gpus) + self.macroK = CustomK(num_classes=output_size, average='macro', device=self.gpus) self.lPretrained_embeddings = nn.ModuleDict() self.lLearnable_embeddings = nn.ModuleDict() @@ -110,12 +108,16 @@ class RecurrentModel(pl.LightningModule): # Squashing logits through Sigmoid in order to get confidence score predictions = torch.sigmoid(logits) > 0.5 accuracy = self.accuracy(predictions, ly) - microF1 = self.microF1_tr(predictions, ly) - macroF1 = self.macroF1_tr(predictions, ly) + microF1 = self.microF1(predictions, ly) + macroF1 = self.macroF1(predictions, ly) + microK = self.microK(predictions, ly) + macroK = self.macroK(predictions, ly) self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True) self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True) self.log('train-macroF1', macroF1, on_step=True, on_epoch=True, prog_bar=False, logger=True) self.log('train-microF1', microF1, on_step=True, on_epoch=True, prog_bar=False, logger=True) + self.log('train-macroK', macroK, on_step=True, on_epoch=True, prog_bar=False, logger=True) + self.log('train-microK', microK, on_step=True, on_epoch=True, prog_bar=False, logger=True) return {'loss': loss} def validation_step(self, val_batch, batch_idx): @@ -128,12 +130,16 @@ class RecurrentModel(pl.LightningModule): loss = self.loss(logits, ly) predictions = torch.sigmoid(logits) > 0.5 accuracy = self.accuracy(predictions, ly) - microF1 = self.microF1_va(predictions, ly) - macroF1 = self.macroF1_va(predictions, ly) - self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True) - self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True) + microF1 = self.microF1(predictions, ly) + macroF1 = self.macroF1(predictions, ly) + microK = self.microK(predictions, ly) + macroK = self.macroK(predictions, ly) + self.log('val-loss', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True) + self.log('val-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True) self.log('val-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True) self.log('val-microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val-macroK', macroK, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val-microK', microK, on_step=False, on_epoch=True, prog_bar=True, logger=True) return {'loss': loss} def test_step(self, test_batch, batch_idx): @@ -145,8 +151,8 @@ class RecurrentModel(pl.LightningModule): ly = torch.cat(_ly, dim=0) predictions = torch.sigmoid(logits) > 0.5 accuracy = self.accuracy(predictions, ly) - microF1 = self.microF1_te(predictions, ly) - macroF1 = self.macroF1_te(predictions, ly) + microF1 = self.microF1(predictions, ly) + macroF1 = self.macroF1(predictions, ly) self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True) self.log('test-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=True) self.log('test-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True) diff --git a/refactor/util/pl_metrics.py b/refactor/util/pl_metrics.py index a54bacb..6781d09 100644 --- a/refactor/util/pl_metrics.py +++ b/refactor/util/pl_metrics.py @@ -3,6 +3,21 @@ from pytorch_lightning.metrics import Metric from util.common import is_false, is_true +def _update(pred, target, device): + assert pred.shape == target.shape + # preparing preds and targets for count + true_pred = is_true(pred, device) + false_pred = is_false(pred, device) + true_target = is_true(target, device) + false_target = is_false(target, device) + + tp = torch.sum(true_pred * true_target, dim=0) + tn = torch.sum(false_pred * false_target, dim=0) + fp = torch.sum(true_pred * false_target, dim=0) + fn = torch.sum(false_pred * target, dim=0) + return tp, tn, fp, fn + + class CustomF1(Metric): def __init__(self, num_classes, device, average='micro'): """ @@ -26,27 +41,13 @@ class CustomF1(Metric): self.add_state('false_negative', default=torch.zeros(self.num_classes)) def update(self, preds, target): - true_positive, true_negative, false_positive, false_negative = self._update(preds, target) + true_positive, true_negative, false_positive, false_negative = _update(preds, target, self.device) self.true_positive += true_positive self.true_negative += true_negative self.false_positive += false_positive self.false_negative += false_negative - def _update(self, pred, target): - assert pred.shape == target.shape - # preparing preds and targets for count - true_pred = is_true(pred, self.device) - false_pred = is_false(pred, self.device) - true_target = is_true(target, self.device) - false_target = is_false(target, self.device) - - tp = torch.sum(true_pred * true_target, dim=0) - tn = torch.sum(false_pred * false_target, dim=0) - fp = torch.sum(true_pred * false_target, dim=0) - fn = torch.sum(false_pred * target, dim=0) - return tp, tn, fp, fn - def compute(self): if self.average == 'micro': num = 2.0 * self.true_positive.sum() @@ -69,3 +70,71 @@ class CustomF1(Metric): class_specific.append(1.) average = torch.sum(torch.Tensor(class_specific))/self.num_classes return average.to(self.device) + + +class CustomK(Metric): + def __init__(self, num_classes, device, average='micro'): + """ + K metric. https://dl.acm.org/doi/10.1145/2808194.2809449 + :param num_classes: + :param device: + :param average: + """ + super().__init__() + self.num_classes = num_classes + self.average = average + self.device = 'cuda' if device else 'cpu' + self.add_state('true_positive', default=torch.zeros(self.num_classes)) + self.add_state('true_negative', default=torch.zeros(self.num_classes)) + self.add_state('false_positive', default=torch.zeros(self.num_classes)) + self.add_state('false_negative', default=torch.zeros(self.num_classes)) + + def update(self, preds, target): + true_positive, true_negative, false_positive, false_negative = _update(preds, target, self.device) + + self.true_positive += true_positive + self.true_negative += true_negative + self.false_positive += false_positive + self.false_negative += false_negative + + def compute(self): + if self.average == 'micro': + specificity, recall = 0., 0. + absolute_negatives = self.true_negative.sum() + self.false_positive.sum() + if absolute_negatives != 0: + specificity = self.true_negative.sum()/absolute_negatives # Todo check if it is float + absolute_positives = self.true_positive.sum() + self.false_negative.sum() + if absolute_positives != 0: + recall = self.true_positive.sum()/absolute_positives # Todo check if it is float + + if absolute_positives == 0: + return 2. * specificity - 1 + elif absolute_negatives == 0: + return 2. * recall - 1 + else: + return specificity + recall - 1 + + if self.average == 'macro': + class_specific = [] + for i in range(self.num_classes): + class_tp = self.true_positive[i] + class_tn = self.true_negative[i] + class_fp = self.false_positive[i] + class_fn = self.false_negative[i] + + specificity, recall = 0., 0. + absolute_negatives = class_tn + class_fp + if absolute_negatives != 0: + specificity = class_tn / absolute_negatives # Todo check if it is float + absolute_positives = class_tp + class_fn + if absolute_positives != 0: + recall = class_tp / absolute_positives # Todo check if it is float + + if absolute_positives == 0: + class_specific.append(2. * specificity - 1) + elif absolute_negatives == 0: + class_specific.append(2. * recall - 1) + else: + class_specific.append(specificity + recall - 1) + average = torch.sum(torch.Tensor(class_specific)) / self.num_classes + return average.to(self.device)