fixed cuda oom at inference time
This commit is contained in:
parent
1a501949a1
commit
5cd36d27fc
|
|
@ -163,6 +163,11 @@ class BertModel(pl.LightningModule):
|
|||
batch = pad(batch, pad_index=self.bert.config.pad_token_id, max_pad_length=max_pad_len)
|
||||
batch = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
|
||||
_, output = self.forward(batch)
|
||||
|
||||
# deleting batch from gpu to avoid cuda OOM
|
||||
del batch
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
doc_embeds = output[-1][:, 0, :]
|
||||
l_embed[lang].append(doc_embeds.cpu())
|
||||
for k, v in l_embed.items():
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ This module contains the view generators that take care of computing the view sp
|
|||
- View generator (-b): generates document embedding via mBERT model.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from time import time
|
||||
# from time import time
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
|
|
@ -27,6 +27,7 @@ from src.models.pl_bert import BertModel
|
|||
from src.models.pl_gru import RecurrentModel
|
||||
from src.util.common import TfidfVectorizerMultilingual, _normalize
|
||||
from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix
|
||||
# TODO: add early stop monitoring validation macroF1 + model checkpointing and loading from checkpoint
|
||||
|
||||
|
||||
class ViewGen(ABC):
|
||||
|
|
@ -293,10 +294,10 @@ class RecurrentGen(ViewGen):
|
|||
data = self.multilingualIndex.l_devel_index()
|
||||
self.model.to('cuda' if self.gpus else 'cpu')
|
||||
self.model.eval()
|
||||
time_init = time.time()
|
||||
# time_init = time.time()
|
||||
l_embeds = self.model.encode(data, l_pad, batch_size=256)
|
||||
transform_time = round(time.time() - time_init, 3)
|
||||
print(f'Executed! Transform took: {transform_time}')
|
||||
# transform_time = round(time.time() - time_init, 3)
|
||||
# print(f'Executed! Transform took: {transform_time}')
|
||||
return l_embeds
|
||||
|
||||
def fit_transform(self, lX, ly):
|
||||
|
|
@ -362,10 +363,10 @@ class BertGen(ViewGen):
|
|||
data = tokenize(data, max_len=512)
|
||||
self.model.to('cuda' if self.gpus else 'cpu')
|
||||
self.model.eval()
|
||||
time_init = time.time()
|
||||
# time_init = time.time()
|
||||
l_emebds = self.model.encode(data, batch_size=64)
|
||||
transform_time = round(time.time() - time_init, 3)
|
||||
print(f'Executed! Transform took: {transform_time}')
|
||||
# transform_time = round(time.time() - time_init, 3)
|
||||
# print(f'Executed! Transform took: {transform_time}')
|
||||
return l_emebds
|
||||
|
||||
def fit_transform(self, lX, ly):
|
||||
|
|
|
|||
Loading…
Reference in New Issue