From 93436fc596507470a83da1197564f94ffef52c71 Mon Sep 17 00:00:00 2001
From: andrea
Date: Mon, 25 Jan 2021 17:20:17 +0100
Subject: [PATCH] Implemented funnelling architecture

---
 refactor/data/datamodule.py | 32 +++++++++++++++++++-------------
 refactor/main.py            |  8 +++++---
 refactor/models/pl_bert.py  | 22 +++++++++++++++++++++-
 refactor/models/pl_gru.py   |  4 ++--
 refactor/util/common.py     |  6 ++++++
 refactor/view_generators.py |  8 ++++----
 6 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/refactor/data/datamodule.py b/refactor/data/datamodule.py
index 7329f08..711d5a3 100644
--- a/refactor/data/datamodule.py
+++ b/refactor/data/datamodule.py
@@ -140,6 +140,22 @@ class RecurrentDataModule(pl.LightningDataModule):
                           collate_fn=self.test_dataset.collate_fn)
 
 
+def tokenize(l_raw, max_len):
+    """
+    Run BERT tokenization on a dict {lang: list of samples}.
+    :param l_raw: dict {lang: list of raw text samples}
+    :param max_len: maximum sequence length (longer samples are truncated, shorter ones padded)
+    :return: dict {lang: list of token-id sequences}
+    """
+    # TODO: check BertTokenizerFast https://huggingface.co/transformers/model_doc/bert.html#berttokenizerfast
+    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
+    l_tokenized = {}
+    for lang in l_raw.keys():
+        output_tokenizer = tokenizer(l_raw[lang], truncation=True, max_length=max_len, padding='max_length')
+        l_tokenized[lang] = output_tokenizer['input_ids']
+    return l_tokenized
+
+
 class BertDataModule(RecurrentDataModule):
     def __init__(self, multilingualIndex, batchsize=64, max_len=512):
         super().__init__(multilingualIndex, batchsize)
@@ -152,7 +168,7 @@ class BertDataModule(RecurrentDataModule):
         l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
         l_train_target = {l: target[:5] for l, target in l_train_target.items()}
 
-        l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
+        l_train_index = tokenize(l_train_raw, max_len=self.max_len)
 
         self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                  lPad_index=self.multilingualIndex.l_pad())
@@ -161,7 +177,7 @@ class BertDataModule(RecurrentDataModule):
         l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
         l_val_target = {l: target[:5] for l, target in l_val_target.items()}
 
-        l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
+        l_val_index = tokenize(l_val_raw, max_len=self.max_len)
 
         self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                             lPad_index=self.multilingualIndex.l_pad())
@@ -171,20 +187,10 @@ class BertDataModule(RecurrentDataModule):
         l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
         l_test_target = {l: target[:5] for l, target in l_test_target.items()}
 
-        l_test_index = self.tokenize(l_test_raw, max_len=self.max_len)
+        l_test_index = tokenize(l_test_raw, max_len=self.max_len)
 
         self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                              lPad_index=self.multilingualIndex.l_pad())
-    @staticmethod
-    def tokenize(l_raw, max_len):
-        # TODO: check BertTokenizerFast https://huggingface.co/transformers/model_doc/bert.html#berttokenizerfast
-        tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
-        l_tokenized = {}
-        for lang in l_raw.keys():
-            output_tokenizer = tokenizer(l_raw[lang], truncation=True, max_length=max_len, padding='max_length')
-            l_tokenized[lang] = output_tokenizer['input_ids']
-        return l_tokenized
-
     def train_dataloader(self):
         """
         NB: Setting n_workers to > 0 will cause "OSError: [Errno 24] Too many open files"
diff --git a/refactor/main.py b/refactor/main.py
index d2ab71b..17f5a95 100644
--- a/refactor/main.py
+++ b/refactor/main.py
@@ -28,12 +28,14 @@ def main(args):
     multilingualIndex.index(lX, ly, lXte,
                             lyte, l_pretrained_vocabulary=lMuse.vocabulary())
 
     # posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
-    museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
-    wceEmbedder = WordClassGen(n_jobs=N_JOBS)
+    # museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
+    # wceEmbedder = WordClassGen(n_jobs=N_JOBS)
     # rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
     #                            nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
-    # bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
+    bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
+    bertEmbedder.transform(lX)
+    exit()
 
     docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder])
     gfun = Funnelling(first_tier=docEmbedders)
diff --git a/refactor/models/pl_bert.py b/refactor/models/pl_bert.py
index c19f455..11fe0ce 100644
--- a/refactor/models/pl_bert.py
+++ b/refactor/models/pl_bert.py
@@ -3,6 +3,7 @@ import pytorch_lightning as pl
 from torch.optim.lr_scheduler import StepLR
 from transformers import BertForSequenceClassification, AdamW
 from util.pl_metrics import CustomF1, CustomK
+from util.common import define_pad_length, pad
 
 
 class BertModel(pl.LightningModule):
@@ -70,7 +71,7 @@ class BertModel(pl.LightningModule):
         langs = set(langs)
         # outputs is a of n dicts of m elements, where n is equal to the number of epoch steps and m is batchsize.
         # here we save epoch level metric values and compute them specifically for each language
-        # TODO: this is horrible...
+        # TODO: make this a function (reused in pl_gru epoch_end)
         res_macroF1 = {lang: [] for lang in langs}
         res_microF1 = {lang: [] for lang in langs}
         res_macroK = {lang: [] for lang in langs}
@@ -150,6 +151,25 @@ class BertModel(pl.LightningModule):
         scheduler = StepLR(optimizer, step_size=25, gamma=0.1)
         return [optimizer], [scheduler]
 
+    def encode(self, lX, batch_size=64):
+        with torch.no_grad():
+            l_embed = {lang: [] for lang in lX.keys()}
+            for lang in sorted(lX.keys()):
+                for i in range(0, len(lX[lang]), batch_size):
+                    if i + batch_size > len(lX[lang]):
+                        batch = lX[lang][i:len(lX[lang])]
+                    else:
+                        batch = lX[lang][i:i + batch_size]
+                    max_pad_len = define_pad_length(batch)
+                    batch = pad(batch, pad_index='101', max_pad_length=max_pad_len)  # TODO: check pad index!
+                    batch = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
+                    _, output = self.forward(batch)
+                    doc_embeds = output[-1][:, 0, :]
+                    l_embed[lang].append(doc_embeds.cpu())
+            for k, v in l_embed.items():
+                l_embed[k] = torch.cat(v, dim=0).numpy()
+            return l_embed
+
     @staticmethod
     def _reconstruct_dict(predictions, y, batch_langs):
         reconstructed_x = {lang: [] for lang in set(batch_langs)}
diff --git a/refactor/models/pl_gru.py b/refactor/models/pl_gru.py
index a13990c..ca4f8da 100644
--- a/refactor/models/pl_gru.py
+++ b/refactor/models/pl_gru.py
@@ -137,9 +137,9 @@ class RecurrentModel(pl.LightningModule):
             output = output[-1, :, :]
             output = F.relu(self.linear0(output))
             output = self.dropout(F.relu(self.linear1(output)))
-            l_embed[lang].append(output)
+            l_embed[lang].append(output.cpu())
         for k, v in l_embed.items():
-            l_embed[k] = torch.cat(v, dim=0).cpu().numpy()
+            l_embed[k] = torch.cat(v, dim=0).numpy()
         return l_embed
 
     def training_step(self, train_batch, batch_idx):
diff --git a/refactor/util/common.py b/refactor/util/common.py
index 1b84d60..575570a 100644
--- a/refactor/util/common.py
+++ b/refactor/util/common.py
@@ -164,6 +164,9 @@ class MultilingualIndex:
     def l_test_raw_index(self):
         return {l: index.test_raw for l, index in self.l_index.items()}
 
+    def l_devel_raw_index(self):
+        return {l: index.devel_raw for l, index in self.l_index.items()}
+
     def l_val_target(self):
         return {l: index.val_target for l, index in self.l_index.items()}
 
@@ -197,6 +200,9 @@ class MultilingualIndex:
     def l_test_raw(self):
         return self.l_test_raw_index(), self.l_test_target()
 
+    def l_devel_raw(self):
+        return self.l_devel_raw_index(), self.l_devel_target()
+
     def get_l_pad_index(self):
         return {l: index.get_pad_index() for l, index in self.l_index.items()}
 
diff --git a/refactor/view_generators.py b/refactor/view_generators.py
index 3b3d811..ca4ff93 100644
--- a/refactor/view_generators.py
+++ b/refactor/view_generators.py
@@ -21,7 +21,7 @@ from util.common import TfidfVectorizerMultilingual, _normalize
 from models.pl_gru import RecurrentModel
 from models.pl_bert import BertModel
 from pytorch_lightning import Trainer
-from data.datamodule import RecurrentDataModule, BertDataModule
+from data.datamodule import RecurrentDataModule, BertDataModule, tokenize
 from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
 from time import time
 
@@ -271,14 +271,14 @@ class BertGen(ViewGen):
 
     def transform(self, lX):
         # lX is raw text data. It has to be first indexed via Bert Tokenizer.
-        data = 'TOKENIZE THIS'
+        data = self.multilingualIndex.l_devel_raw_index()
+        data = tokenize(data, max_len=512)
         self.model.to('cuda' if self.gpus else 'cpu')
         self.model.eval()
         time_init = time()
-        l_emebds = self.model.encode(data)  # TODO
+        l_emebds = self.model.encode(data, batch_size=64)
         transform_time = round(time() - time_init, 3)
         print(f'Executed! Transform took: {transform_time}')
-        exit('BERT VIEWGEN TRANSFORM NOT IMPLEMENTED!')
         return l_emebds
 
     def fit_transform(self, lX, ly):
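
Note (illustrative sketch, not part of the patch): the snippet below is a minimal, standalone rendition of the tokenize-then-encode flow that BertGen.transform now performs, i.e. per-language mBERT tokenization followed by taking the [CLS] vector of the last hidden layer as the document embedding. It assumes only the torch and transformers packages; HFBert is a local alias (to avoid clashing with the patch's own BertModel class) and the toy lX dict stands in for the raw multilingual data. Like BertModel.encode above, it feeds only input ids (no attention mask).

import torch
from transformers import BertTokenizer, BertModel as HFBert


def tokenize(l_raw, max_len=512):
    # mBERT tokenization per language, mirroring data.datamodule.tokenize
    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    return {lang: tokenizer(docs, truncation=True, max_length=max_len,
                            padding='max_length')['input_ids']
            for lang, docs in l_raw.items()}


def encode(l_index, batch_size=64, device='cpu'):
    # Mini-batched forward pass; the [CLS] vector of the last hidden layer is
    # used as the document embedding (cf. output[-1][:, 0, :] in BertModel.encode)
    model = HFBert.from_pretrained('bert-base-multilingual-cased').to(device).eval()
    l_embed = {}
    with torch.no_grad():
        for lang, ids in l_index.items():
            chunks = []
            for i in range(0, len(ids), batch_size):
                batch = torch.LongTensor(ids[i:i + batch_size]).to(device)
                last_hidden = model(batch)[0]              # (batch, seq_len, 768)
                chunks.append(last_hidden[:, 0, :].cpu())  # [CLS] embedding per document
            l_embed[lang] = torch.cat(chunks, dim=0).numpy()
    return l_embed


if __name__ == '__main__':
    lX = {'en': ['a short document', 'another document'],
          'it': ['un breve documento']}
    l_embed = encode(tokenize(lX, max_len=32), batch_size=2)
    print({lang: m.shape for lang, m in l_embed.items()})  # e.g. {'en': (2, 768), 'it': (1, 768)}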