Implemented funnelling architecture

andrea 2021-01-25 17:20:17 +01:00
parent 8fa8ae5989
commit 93436fc596
6 changed files with 57 additions and 23 deletions

View File

@@ -140,6 +140,22 @@ class RecurrentDataModule(pl.LightningDataModule):
                                  collate_fn=self.test_dataset.collate_fn)
+def tokenize(l_raw, max_len):
+    """
+    Run Bert tokenization on a dict {lang: list of samples}.
+    :param l_raw: dict {lang: list of raw text samples}
+    :param max_len: maximum sequence length passed to the tokenizer
+    :return: dict {lang: list of input_ids}
+    """
+    # TODO: check BertTokenizerFast https://huggingface.co/transformers/model_doc/bert.html#berttokenizerfast
+    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
+    l_tokenized = {}
+    for lang in l_raw.keys():
+        output_tokenizer = tokenizer(l_raw[lang], truncation=True, max_length=max_len, padding='max_length')
+        l_tokenized[lang] = output_tokenizer['input_ids']
+    return l_tokenized
+
+
class BertDataModule(RecurrentDataModule):
    def __init__(self, multilingualIndex, batchsize=64, max_len=512):
        super().__init__(multilingualIndex, batchsize)
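The TODO in the new module-level tokenize helper points at BertTokenizerFast. A minimal sketch of what the swap might look like, assuming the same call contract as above (tokenize_fast is a hypothetical name, not part of this commit):

    from transformers import BertTokenizerFast

    def tokenize_fast(l_raw, max_len):
        # Rust-backed tokenizer; same arguments and 'input_ids' output as the slow BertTokenizer
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
        return {lang: tokenizer(samples, truncation=True, max_length=max_len,
                                padding='max_length')['input_ids']
                for lang, samples in l_raw.items()}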
@@ -152,7 +168,7 @@ class BertDataModule(RecurrentDataModule):
        l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
        l_train_target = {l: target[:5] for l, target in l_train_target.items()}
-        l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
+        l_train_index = tokenize(l_train_raw, max_len=self.max_len)
        self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
@@ -161,7 +177,7 @@ class BertDataModule(RecurrentDataModule):
        l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
        l_val_target = {l: target[:5] for l, target in l_val_target.items()}
-        l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
+        l_val_index = tokenize(l_val_raw, max_len=self.max_len)
        self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                            lPad_index=self.multilingualIndex.l_pad())
@@ -171,20 +187,10 @@ class BertDataModule(RecurrentDataModule):
        l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
        l_test_target = {l: target[:5] for l, target in l_test_target.items()}
-        l_test_index = self.tokenize(l_test_raw, max_len=self.max_len)
+        l_test_index = tokenize(l_test_raw, max_len=self.max_len)
        self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                             lPad_index=self.multilingualIndex.l_pad())
-
-    @staticmethod
-    def tokenize(l_raw, max_len):
-        # TODO: check BertTokenizerFast https://huggingface.co/transformers/model_doc/bert.html#berttokenizerfast
-        tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
-        l_tokenized = {}
-        for lang in l_raw.keys():
-            output_tokenizer = tokenizer(l_raw[lang], truncation=True, max_length=max_len, padding='max_length')
-            l_tokenized[lang] = output_tokenizer['input_ids']
-        return l_tokenized

    def train_dataloader(self):
        """
        NB: Setting n_workers to > 0 will cause "OSError: [Errno 24] Too many open files"
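The "Too many open files" error mentioned above is a known PyTorch DataLoader issue when num_workers > 0. A commonly suggested workaround, not applied in this commit, is to switch the tensor sharing strategy (or to raise the process file-descriptor limit with ulimit -n):

    import torch.multiprocessing

    # the default 'file_descriptor' strategy can exhaust descriptors with many worker batches in flight
    torch.multiprocessing.set_sharing_strategy('file_system')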

View File

@@ -28,12 +28,14 @@ def main(args):
    multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())
    # posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
-    museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
-    wceEmbedder = WordClassGen(n_jobs=N_JOBS)
+    # museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
+    # wceEmbedder = WordClassGen(n_jobs=N_JOBS)
    # rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
    #                            nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
-    # bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
+    bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
+    bertEmbedder.transform(lX)
+    exit()
    docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder])
    gfun = Funnelling(first_tier=docEmbedders)

View File

@@ -3,6 +3,7 @@ import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR
from transformers import BertForSequenceClassification, AdamW
from util.pl_metrics import CustomF1, CustomK
+from util.common import define_pad_length, pad


class BertModel(pl.LightningModule):
@@ -70,7 +71,7 @@ class BertModel(pl.LightningModule):
        langs = set(langs)
        # outputs is a list of n dicts of m elements, where n is equal to the number of epoch steps and m is batchsize.
        # here we save epoch-level metric values and compute them specifically for each language
-        # TODO: this is horrible...
+        # TODO: make this a function (reused in pl_gru epoch_end)
        res_macroF1 = {lang: [] for lang in langs}
        res_microF1 = {lang: [] for lang in langs}
        res_macroK = {lang: [] for lang in langs}
@@ -150,6 +151,25 @@ class BertModel(pl.LightningModule):
        scheduler = StepLR(optimizer, step_size=25, gamma=0.1)
        return [optimizer], [scheduler]
+
+    def encode(self, lX, batch_size=64):
+        with torch.no_grad():
+            l_embed = {lang: [] for lang in lX.keys()}
+            for lang in sorted(lX.keys()):
+                for i in range(0, len(lX[lang]), batch_size):
+                    if i + batch_size > len(lX[lang]):
+                        batch = lX[lang][i:len(lX[lang])]
+                    else:
+                        batch = lX[lang][i:i + batch_size]
+                    max_pad_len = define_pad_length(batch)
+                    batch = pad(batch, pad_index='101', max_pad_length=max_pad_len)  # TODO: check pad index!
+                    batch = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
+                    _, output = self.forward(batch)
+                    doc_embeds = output[-1][:, 0, :]
+                    l_embed[lang].append(doc_embeds.cpu())
+            for k, v in l_embed.items():
+                l_embed[k] = torch.cat(v, dim=0).numpy()
+            return l_embed
    @staticmethod
    def _reconstruct_dict(predictions, y, batch_langs):
        reconstructed_x = {lang: [] for lang in set(batch_langs)}
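One note on the pad-index TODO in encode above: for bert-base-multilingual-cased the [PAD] token id is 0, while 101 is the [CLS] id, so pad_index='101' likely wants revisiting. A minimal sketch of deriving the id from the tokenizer instead of hard-coding it, assuming pad from util.common keeps the signature used above:

    from transformers import BertTokenizer

    pad_id = BertTokenizer.from_pretrained('bert-base-multilingual-cased').pad_token_id  # 0 for this checkpoint
    batch = pad(batch, pad_index=pad_id, max_pad_length=max_pad_len)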

View File

@@ -137,9 +137,9 @@ class RecurrentModel(pl.LightningModule):
                output = output[-1, :, :]
                output = F.relu(self.linear0(output))
                output = self.dropout(F.relu(self.linear1(output)))
-                l_embed[lang].append(output)
+                l_embed[lang].append(output.cpu())
            for k, v in l_embed.items():
-                l_embed[k] = torch.cat(v, dim=0).cpu().numpy()
+                l_embed[k] = torch.cat(v, dim=0).numpy()
            return l_embed

    def training_step(self, train_batch, batch_idx):

View File

@@ -164,6 +164,9 @@ class MultilingualIndex:
    def l_test_raw_index(self):
        return {l: index.test_raw for l, index in self.l_index.items()}

+    def l_devel_raw_index(self):
+        return {l: index.devel_raw for l, index in self.l_index.items()}
+
    def l_val_target(self):
        return {l: index.val_target for l, index in self.l_index.items()}
@@ -197,6 +200,9 @@ class MultilingualIndex:
    def l_test_raw(self):
        return self.l_test_raw_index(), self.l_test_target()

+    def l_devel_raw(self):
+        return self.l_devel_raw_index(), self.l_devel_target()
+
    def get_l_pad_index(self):
        return {l: index.get_pad_index() for l, index in self.l_index.items()}

View File

@@ -21,7 +21,7 @@ from util.common import TfidfVectorizerMultilingual, _normalize
from models.pl_gru import RecurrentModel
from models.pl_bert import BertModel
from pytorch_lightning import Trainer
-from data.datamodule import RecurrentDataModule, BertDataModule
+from data.datamodule import RecurrentDataModule, BertDataModule, tokenize
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from time import time
@@ -271,14 +271,14 @@ class BertGen(ViewGen):
    def transform(self, lX):
        # lX is raw text data. It has to be first indexed via Bert Tokenizer.
-        data = 'TOKENIZE THIS'
+        data = self.multilingualIndex.l_devel_raw_index()
+        data = tokenize(data, max_len=512)
        self.model.to('cuda' if self.gpus else 'cpu')
        self.model.eval()
        time_init = time()
-        l_emebds = self.model.encode(data)  # TODO
+        l_emebds = self.model.encode(data, batch_size=64)
        transform_time = round(time() - time_init, 3)
        print(f'Executed! Transform took: {transform_time}')
-        exit('BERT VIEWGEN TRANSFORM NOT IMPLEMENTED!')
        return l_emebds

    def fit_transform(self, lX, ly):
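A minimal usage sketch of the updated transform, reusing names from the main script above; the 768-dimensional width is an assumption based on the bert-base-multilingual-cased hidden size:

    bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
    l_embeds = bertEmbedder.transform(lX)  # dict: lang -> numpy array with one 768-dim row per devel document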