Implemented funnelling architecture
parent 8fa8ae5989
commit 93436fc596
@@ -140,6 +140,22 @@ class RecurrentDataModule(pl.LightningDataModule):
                           collate_fn=self.test_dataset.collate_fn)


+def tokenize(l_raw, max_len):
+    """
+    Run Bert tokenization on a dict {lang: list of samples}.
+    :param l_raw: dict {lang: list of raw text samples}
+    :param max_len: maximum sequence length (samples are truncated and padded to this length)
+    :return: dict {lang: list of input_ids}
+    """
+    # TODO: check BertTokenizerFast https://huggingface.co/transformers/model_doc/bert.html#berttokenizerfast
+    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
+    l_tokenized = {}
+    for lang in l_raw.keys():
+        output_tokenizer = tokenizer(l_raw[lang], truncation=True, max_length=max_len, padding='max_length')
+        l_tokenized[lang] = output_tokenizer['input_ids']
+    return l_tokenized
+
+
 class BertDataModule(RecurrentDataModule):
     def __init__(self, multilingualIndex, batchsize=64, max_len=512):
         super().__init__(multilingualIndex, batchsize)
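A minimal usage sketch of the new module-level tokenize helper (the language codes and sample texts below are invented for illustration; only the function itself is part of this commit):

from data.datamodule import tokenize

l_raw = {'en': ['a short english document', 'another english document'],
         'it': ['un breve documento italiano']}
l_index = tokenize(l_raw, max_len=512)
# l_index['en'] -> list of input_ids lists, each truncated/padded to 512 tokens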
@@ -152,7 +168,7 @@ class BertDataModule(RecurrentDataModule):
         l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
         l_train_target = {l: target[:5] for l, target in l_train_target.items()}

-        l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
+        l_train_index = tokenize(l_train_raw, max_len=self.max_len)
         self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                  lPad_index=self.multilingualIndex.l_pad())

@@ -161,7 +177,7 @@ class BertDataModule(RecurrentDataModule):
         l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
         l_val_target = {l: target[:5] for l, target in l_val_target.items()}

-        l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
+        l_val_index = tokenize(l_val_raw, max_len=self.max_len)
         self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                             lPad_index=self.multilingualIndex.l_pad())

@@ -171,20 +187,10 @@ class BertDataModule(RecurrentDataModule):
         l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
         l_test_target = {l: target[:5] for l, target in l_test_target.items()}

-        l_test_index = self.tokenize(l_test_raw, max_len=self.max_len)
+        l_test_index = tokenize(l_test_raw, max_len=self.max_len)
         self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                              lPad_index=self.multilingualIndex.l_pad())

-    @staticmethod
-    def tokenize(l_raw, max_len):
-        # TODO: check BertTokenizerFast https://huggingface.co/transformers/model_doc/bert.html#berttokenizerfast
-        tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
-        l_tokenized = {}
-        for lang in l_raw.keys():
-            output_tokenizer = tokenizer(l_raw[lang], truncation=True, max_length=max_len, padding='max_length')
-            l_tokenized[lang] = output_tokenizer['input_ids']
-        return l_tokenized
-
     def train_dataloader(self):
         """
         NB: Setting n_workers to > 0 will cause "OSError: [Errno 24] Too many open files"
@@ -28,12 +28,14 @@ def main(args):
     multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())

     # posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
-    museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
-    wceEmbedder = WordClassGen(n_jobs=N_JOBS)
+    # museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
+    # wceEmbedder = WordClassGen(n_jobs=N_JOBS)
     # rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
     #                            nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
-    # bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
+    bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
+    bertEmbedder.transform(lX)

+    exit()
     docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder])

     gfun = Funnelling(first_tier=docEmbedders)
@@ -3,6 +3,7 @@ import pytorch_lightning as pl
 from torch.optim.lr_scheduler import StepLR
 from transformers import BertForSequenceClassification, AdamW
 from util.pl_metrics import CustomF1, CustomK
+from util.common import define_pad_length, pad


 class BertModel(pl.LightningModule):
@@ -70,7 +71,7 @@ class BertModel(pl.LightningModule):
         langs = set(langs)
         # outputs is a of n dicts of m elements, where n is equal to the number of epoch steps and m is batchsize.
         # here we save epoch level metric values and compute them specifically for each language
-        # TODO: this is horrible...
+        # TODO: make this a function (reused in pl_gru epoch_end)
         res_macroF1 = {lang: [] for lang in langs}
         res_microF1 = {lang: [] for lang in langs}
         res_macroK = {lang: [] for lang in langs}
@@ -150,6 +151,25 @@ class BertModel(pl.LightningModule):
         scheduler = StepLR(optimizer, step_size=25, gamma=0.1)
         return [optimizer], [scheduler]

+    def encode(self, lX, batch_size=64):
+        with torch.no_grad():
+            l_embed = {lang: [] for lang in lX.keys()}
+            for lang in sorted(lX.keys()):
+                for i in range(0, len(lX[lang]), batch_size):
+                    if i + batch_size > len(lX[lang]):
+                        batch = lX[lang][i:len(lX[lang])]
+                    else:
+                        batch = lX[lang][i:i + batch_size]
+                    max_pad_len = define_pad_length(batch)
+                    batch = pad(batch, pad_index='101', max_pad_length=max_pad_len)  # TODO: check pad index!
+                    batch = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
+                    _, output = self.forward(batch)
+                    doc_embeds = output[-1][:, 0, :]
+                    l_embed[lang].append(doc_embeds.cpu())
+            for k, v in l_embed.items():
+                l_embed[k] = torch.cat(v, dim=0).numpy()
+            return l_embed
+
     @staticmethod
     def _reconstruct_dict(predictions, y, batch_langs):
         reconstructed_x = {lang: [] for lang in set(batch_langs)}
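A sketch of how the new BertModel.encode is meant to be driven (mirroring BertGen.transform further below; model and l_raw are stand-ins for a fitted BertModel and a {lang: raw texts} dict, used here only for illustration). The pad_index='101' passed to pad above is presumably what the TODO refers to: 101 is the [CLS] id in bert-base-multilingual-cased, while the [PAD] id is 0.

l_index = tokenize(l_raw, max_len=512)            # {lang: list of input_ids}
l_embeds = model.encode(l_index, batch_size=64)   # {lang: numpy array, one row per document}
# each row comes from output[-1][:, 0, :], i.e. the hidden state of the first ([CLS]) token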
@@ -137,9 +137,9 @@ class RecurrentModel(pl.LightningModule):
             output = output[-1, :, :]
             output = F.relu(self.linear0(output))
             output = self.dropout(F.relu(self.linear1(output)))
-            l_embed[lang].append(output)
+            l_embed[lang].append(output.cpu())
         for k, v in l_embed.items():
-            l_embed[k] = torch.cat(v, dim=0).cpu().numpy()
+            l_embed[k] = torch.cat(v, dim=0).numpy()
         return l_embed

     def training_step(self, train_batch, batch_idx):
@@ -164,6 +164,9 @@ class MultilingualIndex:
     def l_test_raw_index(self):
         return {l: index.test_raw for l, index in self.l_index.items()}

+    def l_devel_raw_index(self):
+        return {l: index.devel_raw for l, index in self.l_index.items()}
+
     def l_val_target(self):
         return {l: index.val_target for l, index in self.l_index.items()}

@@ -197,6 +200,9 @@ class MultilingualIndex:
     def l_test_raw(self):
         return self.l_test_raw_index(), self.l_test_target()

+    def l_devel_raw(self):
+        return self.l_devel_raw_index(), self.l_devel_target()
+
     def get_l_pad_index(self):
         return {l: index.get_pad_index() for l, index in self.l_index.items()}

@@ -21,7 +21,7 @@ from util.common import TfidfVectorizerMultilingual, _normalize
 from models.pl_gru import RecurrentModel
 from models.pl_bert import BertModel
 from pytorch_lightning import Trainer
-from data.datamodule import RecurrentDataModule, BertDataModule
+from data.datamodule import RecurrentDataModule, BertDataModule, tokenize
 from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
 from time import time

@@ -271,14 +271,14 @@ class BertGen(ViewGen):

     def transform(self, lX):
         # lX is raw text data. It has to be first indexed via Bert Tokenizer.
-        data = 'TOKENIZE THIS'
+        data = self.multilingualIndex.l_devel_raw_index()
+        data = tokenize(data, max_len=512)
         self.model.to('cuda' if self.gpus else 'cpu')
         self.model.eval()
         time_init = time()
-        l_emebds = self.model.encode(data)  # TODO
+        l_emebds = self.model.encode(data, batch_size=64)
         transform_time = round(time() - time_init, 3)
         print(f'Executed! Transform took: {transform_time}')
-        exit('BERT VIEWGEN TRANSFORM NOT IMPLEMENTED!')
         return l_emebds

     def fit_transform(self, lX, ly):
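For context, a sketch of the full BERT view-generation path as wired up by this commit (names as in the diff; the snippet is illustrative, not an additional change):

bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
l_embeds = bertEmbedder.transform(lX)
# transform() internally runs:
#   data = self.multilingualIndex.l_devel_raw_index()   # {lang: raw devel documents}
#   data = tokenize(data, max_len=512)                  # {lang: input_ids}
#   l_embeds = self.model.encode(data, batch_size=64)   # {lang: document embeddings}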