# Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn
from torch.optim.lr_scheduler import StepLR
from transformers import AdamW
from pytorch_lightning.metrics import Metric, F1, Accuracy

from models.helpers import init_embeddings
from util.common import is_true, is_false


class RecurrentModel(pl.LightningModule):
    """
    See https://www.learnopencv.com/tensorboard-with-pytorch-lightning/ for insights on logging.
    """
    def __init__(self, lPretrained, langs, output_size, hidden_size, lVocab_size, learnable_length,
                 drop_embedding_range, drop_embedding_prop, gpus=None):
        super().__init__()
        self.gpus = gpus
        self.langs = langs
        self.lVocab_size = lVocab_size
        self.learnable_length = learnable_length
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.drop_embedding_range = drop_embedding_range
        self.drop_embedding_prop = drop_embedding_prop
        self.loss = torch.nn.BCEWithLogitsLoss()
        # self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
        # self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
        self.accuracy = Accuracy()
        self.customMicroF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
        self.customMacroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
        self.lPretrained_embeddings = nn.ModuleDict()
        self.lLearnable_embeddings = nn.ModuleDict()
        self.n_layers = 1
        self.n_directions = 1
        self.dropout = nn.Dropout(0.6)

        lstm_out = 256
        ff1 = 512
        ff2 = 256

        lpretrained_embeddings = {}
        llearnable_embeddings = {}
        for lang in self.langs:
            pretrained = lPretrained[lang] if lPretrained else None
            pretrained_embeddings, learnable_embeddings, embedding_length = init_embeddings(
                pretrained, self.lVocab_size[lang], self.learnable_length)
            lpretrained_embeddings[lang] = pretrained_embeddings
            llearnable_embeddings[lang] = learnable_embeddings
            self.embedding_length = embedding_length
        self.lPretrained_embeddings.update(lpretrained_embeddings)
        self.lLearnable_embeddings.update(llearnable_embeddings)

        self.rnn = nn.GRU(self.embedding_length, hidden_size)
        self.linear0 = nn.Linear(hidden_size * self.n_directions, lstm_out)
        self.linear1 = nn.Linear(lstm_out, ff1)
        self.linear2 = nn.Linear(ff1, ff2)
        self.label = nn.Linear(ff2, self.output_size)

        # NB: lPretrained must be set to None before calling save_hyperparameters(); otherwise the
        # (large) pretrained embedding matrix is stored together with the hyperparameters, which
        # makes saving the checkpoint at the first validation step prohibitively slow.
        lPretrained = None
        self.save_hyperparameters()

    def forward(self, lX):
        _tmp = []
        for lang in sorted(lX.keys()):
            doc_embedding = self.transform(lX[lang], lang)
            _tmp.append(doc_embedding)
        embed = torch.cat(_tmp, dim=0)
        logits = self.label(embed)
        return logits

    def transform(self, X, lang):
        batch_size = X.shape[0]
        X = self.embed(X, lang)
        X = self.embedding_dropout(X, drop_range=self.drop_embedding_range,
                                   p_drop=self.drop_embedding_prop, training=self.training)
        X = X.permute(1, 0, 2)
        h_0 = torch.zeros(self.n_layers * self.n_directions, batch_size, self.hidden_size).to(self.device)
        output, _ = self.rnn(X, h_0)
        output = output[-1, :, :]
        output = F.relu(self.linear0(output))
        output = self.dropout(F.relu(self.linear1(output)))
        output = self.dropout(F.relu(self.linear2(output)))
        return output

    def training_step(self, train_batch, batch_idx):
        lX, ly = train_batch
        logits = self.forward(lX)
        _ly = []
        for lang in sorted(lX.keys()):
            _ly.append(ly[lang])
        ly = torch.cat(_ly, dim=0)
        loss = self.loss(logits, ly)
        # Squash logits through a sigmoid to obtain confidence scores
        predictions = torch.sigmoid(logits) > 0.5
        accuracy = self.accuracy(predictions, ly)
        microF1 = self.customMicroF1(predictions, ly)
        macroF1 = self.customMacroF1(predictions, ly)
        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': loss}

    def validation_step(self, val_batch, batch_idx):
        lX, ly = val_batch
        logits = self.forward(lX)
        _ly = []
        for lang in sorted(lX.keys()):
            _ly.append(ly[lang])
        ly = torch.cat(_ly, dim=0)
        loss = self.loss(logits, ly)
        predictions = torch.sigmoid(logits) > 0.5
        accuracy = self.accuracy(predictions, ly)
        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return {'loss': loss}

    def test_step(self, test_batch, batch_idx):
        lX, ly = test_batch
        logits = self.forward(lX)
        _ly = []
        for lang in sorted(lX.keys()):
            _ly.append(ly[lang])
        ly = torch.cat(_ly, dim=0)
        predictions = torch.sigmoid(logits) > 0.5
        accuracy = self.accuracy(predictions, ly)
        self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)

    def embed(self, X, lang):
        input_list = []
        if self.lPretrained_embeddings[lang] is not None:
            input_list.append(self.lPretrained_embeddings[lang](X))
        if self.lLearnable_embeddings[lang] is not None:
            input_list.append(self.lLearnable_embeddings[lang](X))
        return torch.cat(tensors=input_list, dim=2)

    def embedding_dropout(self, X, drop_range, p_drop=0.5, training=True):
        if p_drop > 0 and training and drop_range is not None:
            p = p_drop
            drop_from, drop_to = drop_range
            m = drop_to - drop_from     # length of the supervised embedding
            l = X.shape[2]              # total embedding length
            corr = (1 - p)
            # F.dropout rescales survivors by 1/(1-p); multiplying by corr undoes that scaling,
            # and the division below restores the expected activation mass of the full embedding
            # (see the sanity-check sketch after this class).
            X[:, :, drop_from:drop_to] = corr * F.dropout(X[:, :, drop_from:drop_to], p=p)
            X /= (1 - (p * m / l))
        return X

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=1e-3)
        scheduler = StepLR(optimizer, step_size=25, gamma=0.5)
        return [optimizer], [scheduler]
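

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model): checks that the rescaling done
# in RecurrentModel.embedding_dropout approximately preserves the expected
# activation mass of the embedding. The shapes, drop probability, and the
# helper name are made-up toy values for illustration only.
# ---------------------------------------------------------------------------
def _embedding_dropout_sanity_check():
    torch.manual_seed(0)
    l, m, p = 300, 100, 0.5                # total length, dropped range, drop prob
    X = torch.ones(64, 50, l)              # toy (batch, seq_len, embedding) input
    dropped = X.clone()
    # Same two steps as embedding_dropout: undo the inverted-dropout scaling,
    # then rescale the whole embedding by the fraction of surviving mass.
    dropped[:, :, :m] = (1 - p) * F.dropout(dropped[:, :, :m], p=p)
    dropped /= (1 - p * m / l)
    # Both means should be close to 1.0 (equality holds only in expectation).
    print(f'original mean: {X.mean().item():.3f}, after dropout: {dropped.mean().item():.3f}')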


class CustomF1(Metric):
    def __init__(self, num_classes, device, average='micro'):
        """
        Custom F1 metric. scikit-learn provides a full set of evaluation metrics, but it treats
        special cases differently: when the number of true positives, false positives, and false
        negatives amounts to 0, all the affected metrics (precision, recall, and thus F1) are set
        to 0 in scikit-learn. We adhere instead to the common practice of outputting 1 in this
        case, since the classifier has correctly classified all examples as negatives.

        :param num_classes: number of target classes
        :param device: if truthy, accumulate the metric states on GPU; on CPU otherwise
        :param average: averaging scheme, either 'micro' or 'macro'
        """
        super().__init__()
        self.num_classes = num_classes
        self.average = average
        self.device = 'cuda' if device else 'cpu'
        self.add_state('true_positive', default=torch.zeros(self.num_classes))
        self.add_state('true_negative', default=torch.zeros(self.num_classes))
        self.add_state('false_positive', default=torch.zeros(self.num_classes))
        self.add_state('false_negative', default=torch.zeros(self.num_classes))

    def update(self, preds, target):
        true_positive, true_negative, false_positive, false_negative = self._update(preds, target)
        self.true_positive += true_positive
        self.true_negative += true_negative
        self.false_positive += false_positive
        self.false_negative += false_negative

    def _update(self, pred, target):
        assert pred.shape == target.shape
        # convert predictions and targets to 0/1 masks before counting
        true_pred = is_true(pred, self.device)
        false_pred = is_false(pred, self.device)
        true_target = is_true(target, self.device)
        false_target = is_false(target, self.device)
        tp = torch.sum(true_pred * true_target, dim=0)
        tn = torch.sum(false_pred * false_target, dim=0)
        fp = torch.sum(true_pred * false_target, dim=0)
        fn = torch.sum(false_pred * true_target, dim=0)
        return tp, tn, fp, fn
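
    # Per-class F1 is F1_c = 2*TP_c / (2*TP_c + FP_c + FN_c).
    # micro-F1 pools TP/FP/FN over all classes before applying the formula;
    # macro-F1 applies it per class and averages the results. An empty
    # denominator (no positives predicted and none present) counts as a
    # perfect 1.0, per the convention described in the docstring above.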
    def compute(self):
        if self.average == 'micro':
            num = 2.0 * self.true_positive.sum()
            den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum()
            if den > 0:
                return (num / den).to(self.device)
            return torch.FloatTensor([1.]).to(self.device)
        if self.average == 'macro':
            class_specific = []
            for i in range(self.num_classes):
                class_tp = self.true_positive[i]
                # class_tn = self.true_negative[i]
                class_fp = self.false_positive[i]
                class_fn = self.false_negative[i]
                num = 2.0 * class_tp
                den = 2.0 * class_tp + class_fp + class_fn
                if den > 0:
                    class_specific.append(num / den)
                else:
                    class_specific.append(1.)
            average = torch.sum(torch.Tensor(class_specific)) / self.num_classes
            return average.to(self.device)
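

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch for CustomF1. Illustrative only: the toy
# predictions and targets below are made up, and the run assumes the
# is_true/is_false helpers behave as 0/1 masks on binary inputs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    f1 = CustomF1(num_classes=4, device=None, average='macro')
    preds = torch.tensor([[1, 0, 0, 0],
                          [0, 1, 0, 0]])
    target = torch.tensor([[1, 0, 0, 0],
                           [1, 0, 0, 0]])
    f1.update(preds, target)
    # Classes 2 and 3 have no positives at all, so they score a perfect 1.0
    # under the convention above; expected macro-F1 = (2/3 + 0 + 1 + 1) / 4.
    print(f1.compute())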