diff --git a/src/models/pl_bert.py b/src/models/pl_bert.py
index a9b669f..129c3b4 100644
--- a/src/models/pl_bert.py
+++ b/src/models/pl_bert.py
@@ -163,6 +163,11 @@ class BertModel(pl.LightningModule):
                 batch = pad(batch, pad_index=self.bert.config.pad_token_id, max_pad_length=max_pad_len)
                 batch = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
                 _, output = self.forward(batch)
+
+                # deleting batch from gpu to avoid cuda OOM
+                del batch
+                torch.cuda.empty_cache()
+
                 doc_embeds = output[-1][:, 0, :]
                 l_embed[lang].append(doc_embeds.cpu())
         for k, v in l_embed.items():
diff --git a/src/view_generators.py b/src/view_generators.py
index b0f70bf..452714c 100644
--- a/src/view_generators.py
+++ b/src/view_generators.py
@@ -16,7 +16,7 @@ This module contains the view generators that take care of computing the view sp
 - View generator (-b): generates document embedding via mBERT model.
 """
 from abc import ABC, abstractmethod
-from time import time
+# from time import time
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
@@ -27,6 +27,7 @@ from src.models.pl_bert import BertModel
 from src.models.pl_gru import RecurrentModel
 from src.util.common import TfidfVectorizerMultilingual, _normalize
 from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix
+# TODO: add early stop monitoring validation macroF1 + model checkpointing and loading from checkpoint
 
 
 class ViewGen(ABC):
@@ -293,10 +294,10 @@
         data = self.multilingualIndex.l_devel_index()
         self.model.to('cuda' if self.gpus else 'cpu')
         self.model.eval()
-        time_init = time.time()
+        # time_init = time.time()
         l_embeds = self.model.encode(data, l_pad, batch_size=256)
-        transform_time = round(time.time() - time_init, 3)
-        print(f'Executed! Transform took: {transform_time}')
+        # transform_time = round(time.time() - time_init, 3)
+        # print(f'Executed! Transform took: {transform_time}')
         return l_embeds
 
     def fit_transform(self, lX, ly):
@@ -362,10 +363,10 @@
         data = tokenize(data, max_len=512)
         self.model.to('cuda' if self.gpus else 'cpu')
         self.model.eval()
-        time_init = time.time()
+        # time_init = time.time()
         l_emebds = self.model.encode(data, batch_size=64)
-        transform_time = round(time.time() - time_init, 3)
-        print(f'Executed! Transform took: {transform_time}')
+        # transform_time = round(time.time() - time_init, 3)
+        # print(f'Executed! Transform took: {transform_time}')
         return l_emebds
 
     def fit_transform(self, lX, ly):