fixed cuda oom at inference time

parent 1a501949a1
commit 5cd36d27fc
@@ -163,6 +163,11 @@ class BertModel(pl.LightningModule):
             batch = pad(batch, pad_index=self.bert.config.pad_token_id, max_pad_length=max_pad_len)
             batch = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
             _, output = self.forward(batch)
+
+            # deleting batch from gpu to avoid cuda OOM
+            del batch
+            torch.cuda.empty_cache()
+
             doc_embeds = output[-1][:, 0, :]
             l_embed[lang].append(doc_embeds.cpu())
         for k, v in l_embed.items():
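For context, a hypothetical reconstruction of the inference loop this hunk patches. The pad helper, the language-indexed l_embed dict, max_pad_len, and the pad_token_id come from the surrounding lines; the loop structure, the encode signature, and the final torch.cat are assumptions, not the repo's actual code. del batch drops the last live reference to the input tensor so the allocator can reuse its memory, and torch.cuda.empty_cache() returns the cached blocks to the driver. Wrapping the loop in torch.no_grad() is the other standard guard here, since it stops autograd from retaining activations:

    import torch

    def encode(self, l_index, max_pad_len, batch_size=64):
        """Sketch: embed documents per language, keeping only CPU copies."""
        l_embed = {lang: [] for lang in l_index}
        with torch.no_grad():  # no autograd graph -> much lower peak GPU memory
            for lang, docs in l_index.items():
                for i in range(0, len(docs), batch_size):
                    batch = pad(docs[i:i + batch_size],
                                pad_index=self.bert.config.pad_token_id,
                                max_pad_length=max_pad_len)
                    batch = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
                    _, output = self.forward(batch)
                    # deleting batch from gpu to avoid cuda OOM
                    del batch
                    torch.cuda.empty_cache()
                    doc_embeds = output[-1][:, 0, :]  # [CLS] embedding per document
                    l_embed[lang].append(doc_embeds.cpu())
        return {lang: torch.cat(chunks) for lang, chunks in l_embed.items()}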
|
|
@@ -16,7 +16,7 @@ This module contains the view generators that take care of computing the view sp
 - View generator (-b): generates document embedding via mBERT model.
 """
 from abc import ABC, abstractmethod
-from time import time
+# from time import time

 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
|
|
@@ -27,6 +27,7 @@ from src.models.pl_bert import BertModel
 from src.models.pl_gru import RecurrentModel
 from src.util.common import TfidfVectorizerMultilingual, _normalize
 from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix
+# TODO: add early stop monitoring validation macroF1 + model checkpointing and loading from checkpoint


 class ViewGen(ABC):
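The TODO added above maps directly onto pytorch_lightning's callback API. A minimal sketch, assuming the LightningModule logs validation macro-F1 under the name 'val-macroF1' (the metric name, the patience value, and the trainer arguments are illustrative assumptions, not taken from this repo):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    # 'val-macroF1' is an assumed metric name: it must match what the module
    # logs, e.g. self.log('val-macroF1', f1) in its validation hooks.
    early_stop = EarlyStopping(monitor='val-macroF1', mode='max', patience=5)
    checkpoint = ModelCheckpoint(monitor='val-macroF1', mode='max', save_top_k=1)

    trainer = Trainer(gpus=1, max_epochs=100, callbacks=[early_stop, checkpoint])
    trainer.fit(model)
    # reload the best weights later:
    # model = BertModel.load_from_checkpoint(checkpoint.best_model_path)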
|
|
@@ -293,10 +294,10 @@ class RecurrentGen(ViewGen):
         data = self.multilingualIndex.l_devel_index()
         self.model.to('cuda' if self.gpus else 'cpu')
         self.model.eval()
-        time_init = time.time()
+        # time_init = time.time()
         l_embeds = self.model.encode(data, l_pad, batch_size=256)
-        transform_time = round(time.time() - time_init, 3)
-        print(f'Executed! Transform took: {transform_time}')
+        # transform_time = round(time.time() - time_init, 3)
+        # print(f'Executed! Transform took: {transform_time}')
         return l_embeds

     def fit_transform(self, lX, ly):
|
|
@@ -362,10 +363,10 @@ class BertGen(ViewGen):
         data = tokenize(data, max_len=512)
         self.model.to('cuda' if self.gpus else 'cpu')
         self.model.eval()
-        time_init = time.time()
+        # time_init = time.time()
         l_emebds = self.model.encode(data, batch_size=64)
-        transform_time = round(time.time() - time_init, 3)
-        print(f'Executed! Transform took: {transform_time}')
+        # transform_time = round(time.time() - time_init, 3)
+        # print(f'Executed! Transform took: {transform_time}')
         return l_emebds

     def fit_transform(self, lX, ly):
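One note on the timing lines commented out in the two hunks above: the module imports from time import time (itself commented out by this commit), which binds time to the function, so the disabled time.time() calls would have raised AttributeError rather than returned a timestamp. If the timing is ever reinstated under that import, it should read:

    from time import time

    time_init = time()  # not time.time(): 'time' is the function here
    l_embeds = self.model.encode(data, l_pad, batch_size=256)
    transform_time = round(time() - time_init, 3)
    print(f'Executed! Transform took: {transform_time}')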