Fixed CUDA OOM at inference time

This commit is contained in:
andrea 2021-01-26 18:04:15 +01:00
parent 1a501949a1
commit 5cd36d27fc
2 changed files with 13 additions and 7 deletions

View File

@ -163,6 +163,11 @@ class BertModel(pl.LightningModule):
batch = pad(batch, pad_index=self.bert.config.pad_token_id, max_pad_length=max_pad_len) batch = pad(batch, pad_index=self.bert.config.pad_token_id, max_pad_length=max_pad_len)
batch = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu') batch = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
_, output = self.forward(batch) _, output = self.forward(batch)
# deleting batch from gpu to avoid cuda OOM
del batch
torch.cuda.empty_cache()
doc_embeds = output[-1][:, 0, :] doc_embeds = output[-1][:, 0, :]
l_embed[lang].append(doc_embeds.cpu()) l_embed[lang].append(doc_embeds.cpu())
for k, v in l_embed.items(): for k, v in l_embed.items():

View File

@ -16,7 +16,7 @@ This module contains the view generators that take care of computing the view sp
- View generator (-b): generates document embedding via mBERT model. - View generator (-b): generates document embedding via mBERT model.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from time import time # from time import time
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.loggers import TensorBoardLogger
@ -27,6 +27,7 @@ from src.models.pl_bert import BertModel
from src.models.pl_gru import RecurrentModel from src.models.pl_gru import RecurrentModel
from src.util.common import TfidfVectorizerMultilingual, _normalize from src.util.common import TfidfVectorizerMultilingual, _normalize
from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix
# TODO: add early stop monitoring validation macroF1 + model checkpointing and loading from checkpoint
class ViewGen(ABC): class ViewGen(ABC):
@ -293,10 +294,10 @@ class RecurrentGen(ViewGen):
data = self.multilingualIndex.l_devel_index() data = self.multilingualIndex.l_devel_index()
self.model.to('cuda' if self.gpus else 'cpu') self.model.to('cuda' if self.gpus else 'cpu')
self.model.eval() self.model.eval()
time_init = time.time() # time_init = time.time()
l_embeds = self.model.encode(data, l_pad, batch_size=256) l_embeds = self.model.encode(data, l_pad, batch_size=256)
transform_time = round(time.time() - time_init, 3) # transform_time = round(time.time() - time_init, 3)
print(f'Executed! Transform took: {transform_time}') # print(f'Executed! Transform took: {transform_time}')
return l_embeds return l_embeds
def fit_transform(self, lX, ly): def fit_transform(self, lX, ly):
@ -362,10 +363,10 @@ class BertGen(ViewGen):
data = tokenize(data, max_len=512) data = tokenize(data, max_len=512)
self.model.to('cuda' if self.gpus else 'cpu') self.model.to('cuda' if self.gpus else 'cpu')
self.model.eval() self.model.eval()
time_init = time.time() # time_init = time.time()
l_emebds = self.model.encode(data, batch_size=64) l_emebds = self.model.encode(data, batch_size=64)
transform_time = round(time.time() - time_init, 3) # transform_time = round(time.time() - time_init, 3)
print(f'Executed! Transform took: {transform_time}') # print(f'Executed! Transform took: {transform_time}')
return l_emebds return l_emebds
def fit_transform(self, lX, ly): def fit_transform(self, lX, ly):