implemented BertDataModule collate function

parent b2be446446
commit f579a1a7f2
@@ -181,7 +181,7 @@ class BertDataModule(RecurrentDataModule):
     Pytorch Lightning Datamodule to be deployed with BertGen.
     https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
     """
-    def __init__(self, multilingualIndex, batchsize=64, max_len=512, zero_shot=False, zscl_langs=None):
+    def __init__(self, multilingualIndex, batchsize=64, max_len=512, zero_shot=False, zscl_langs=None, debug=False):
         """
         Init BertDataModule.
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
@@ -196,28 +196,33 @@ class BertDataModule(RecurrentDataModule):
             zscl_langs = []
         self.zero_shot = zero_shot
         self.train_langs = zscl_langs
+        self.debug = debug
+        if self.debug:
+            print('\n[Running on DEBUG mode - samples per language are reduced to 50 max!]\n')

     def setup(self, stage=None):
         if stage == 'fit' or stage is None:
             if self.zero_shot:
-                l_train_raw, l_train_target = self.multilingualIndex.l_train_raw_zero_shot(langs=self.train_langs)  # todo: check this!
+                l_train_raw, l_train_target = self.multilingualIndex.l_train_raw_zero_shot(langs=self.train_langs)
             else:
                 l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
-            # Debug settings: reducing number of samples
-            # l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
-            # l_train_target = {l: target[:5] for l, target in l_train_target.items()}
+            if self.debug:
+                # Debug settings: reducing number of samples
+                l_train_raw = {l: train[:50] for l, train in l_train_raw.items()}
+                l_train_target = {l: target[:50] for l, target in l_train_target.items()}

             l_train_index = tokenize(l_train_raw, max_len=self.max_len)
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                      lPad_index=self.multilingualIndex.l_pad())

             if self.zero_shot:
-                l_val_raw, l_val_target = self.multilingualIndex.l_val_raw_zero_shot(langs=self.train_langs)  # todo: check this!
+                l_val_raw, l_val_target = self.multilingualIndex.l_val_raw_zero_shot(langs=self.train_langs)
             else:
                 l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
-            # Debug settings: reducing number of samples
-            # l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
-            # l_val_target = {l: target[:5] for l, target in l_val_target.items()}
+            if self.debug:
+                # Debug settings: reducing number of samples
+                l_val_raw = {l: train[:50] for l, train in l_val_raw.items()}
+                l_val_target = {l: target[:50] for l, target in l_val_target.items()}

             l_val_index = tokenize(l_val_raw, max_len=self.max_len)
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
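Note: the `debug=True` branch above caps every language at 50 samples via a dict comprehension over the per-language dictionaries. A self-contained sketch of that truncation, with toy data standing in for the real corpus:

```python
# Toy stand-in for the {lang: [documents]} dictionaries used above.
l_train_raw = {'en': ['doc'] * 120, 'it': ['doc'] * 30}

# Same truncation as the debug branch: keep at most 50 samples per language.
l_train_raw = {l: train[:50] for l, train in l_train_raw.items()}

print({l: len(train) for l, train in l_train_raw.items()})  # {'en': 50, 'it': 30}
```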
@@ -225,12 +230,13 @@ class BertDataModule(RecurrentDataModule):

         if stage == 'test' or stage is None:
             if self.zero_shot:
-                l_test_raw, l_test_target = self.multilingualIndex.l_test_raw_zero_shot(langs=self.train_langs)  # todo: check this!
+                l_test_raw, l_test_target = self.multilingualIndex.l_test_raw_zero_shot(langs=self.train_langs)
             else:
                 l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
-            # Debug settings: reducing number of samples
-            # l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
-            # l_test_target = {l: target[:5] for l, target in l_test_target.items()}
+            if self.debug:
+                # Debug settings: reducing number of samples
+                l_test_raw = {l: train[:50] for l, train in l_test_raw.items()}
+                l_test_target = {l: target[:50] for l, target in l_test_target.items()}

             l_test_index = tokenize(l_test_raw, max_len=self.max_len)
             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
@@ -241,10 +247,16 @@ class BertDataModule(RecurrentDataModule):
         NB: Setting n_workers to > 0 will cause "OSError: [Errno 24] Too many open files"
         :return:
         """
-        return DataLoader(self.training_dataset, batch_size=self.batchsize)
+        return DataLoader(self.training_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)

     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batchsize)
+        return DataLoader(self.val_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)

     def test_dataloader(self):
-        return DataLoader(self.test_dataset, batch_size=self.batchsize)
+        return DataLoader(self.test_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
+
+    def collate_fn_bert(self, data):
+        x_batch = np.vstack([elem[0] for elem in data])
+        y_batch = np.vstack([elem[1] for elem in data])
+        lang_batch = [elem[2] for elem in data]
+        return torch.LongTensor(x_batch), torch.FloatTensor(y_batch), lang_batch
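For reference, a standalone sketch of what the new `collate_fn_bert` returns, assuming the dataset yields `(token_ids, labels, lang)` triples; the toy token ids and labels below are illustrative, not taken from the repo:

```python
import numpy as np
import torch

def collate_fn_bert(data):
    # Stack the per-sample 1-D arrays into (batch, max_len) and (batch, n_classes).
    x_batch = np.vstack([elem[0] for elem in data])
    y_batch = np.vstack([elem[1] for elem in data])
    lang_batch = [elem[2] for elem in data]  # language tags stay a plain Python list
    return torch.LongTensor(x_batch), torch.FloatTensor(y_batch), lang_batch

# Two toy samples: four token ids and three binary labels each.
batch = [(np.array([101, 7592, 2088, 102]), np.array([1., 0., 1.]), 'en'),
         (np.array([101, 2054, 2003, 102]), np.array([0., 1., 0.]), 'it')]
X, y, langs = collate_fn_bert(batch)
print(X.shape, X.dtype)  # torch.Size([2, 4]) torch.int64
print(y.shape, y.dtype)  # torch.Size([2, 3]) torch.float32
print(langs)             # ['en', 'it']
```

This three-element batch is why the `training_step`, `validation_step` and `test_step` methods below now unpack `X, y, batch_langs` and drop the old reshaping.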
@@ -23,7 +23,7 @@ class BertModel(pl.LightningModule):
         self.macroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
         self.microK = CustomK(num_classes=output_size, average='micro', device=self.gpus)
         self.macroK = CustomK(num_classes=output_size, average='macro', device=self.gpus)
-        # Language specific metrics to compute metrics at epoch level
+        # Language specific metrics to compute at epoch level
         self.lang_macroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
         self.lang_microF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
         self.lang_macroK = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
@@ -44,9 +44,7 @@ class BertModel(pl.LightningModule):
         return logits

     def training_step(self, train_batch, batch_idx):
-        X, y, _, batch_langs = train_batch
-        X = torch.cat(X).view([X[0].shape[0], len(X)])
-        y = y.type(torch.FloatTensor)
+        X, y, batch_langs = train_batch
         y = y.to('cuda' if self.gpus else 'cpu')
         logits, _ = self.forward(X)
         loss = self.loss(logits, y)
@@ -99,9 +97,7 @@ class BertModel(pl.LightningModule):
         self.logger.experiment.add_scalars('train-langs-microK', {f'{lang}': avg_microK}, self.current_epoch)

     def validation_step(self, val_batch, batch_idx):
-        X, y, _, batch_langs = val_batch
-        X = torch.cat(X).view([X[0].shape[0], len(X)])
-        y = y.type(torch.FloatTensor)
+        X, y, batch_langs = val_batch
         y = y.to('cuda' if self.gpus else 'cpu')
         logits, _ = self.forward(X)
         loss = self.loss(logits, y)
@@ -118,12 +114,10 @@ class BertModel(pl.LightningModule):
         return {'loss': loss}

     def test_step(self, test_batch, batch_idx):
-        X, y, _, batch_langs = test_batch
-        X = torch.cat(X).view([X[0].shape[0], len(X)])
-        y = y.type(torch.FloatTensor)
+        X, y, batch_langs = test_batch
         y = y.to('cuda' if self.gpus else 'cpu')
         logits, _ = self.forward(X)
-        loss = self.loss(logits, y)
+        # loss = self.loss(logits, y)
         # Squashing logits through Sigmoid in order to get confidence score
         predictions = torch.sigmoid(logits) > 0.5
         microF1 = self.microF1(predictions, y)
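The thresholding in `test_step` treats each class as an independent binary decision: sigmoid squashes each logit into a (0, 1) confidence score, and 0.5 is the cut-off. A quick standalone check:

```python
import torch

logits = torch.tensor([[ 2.0, -1.0,  0.3],
                       [-0.5,  1.5, -2.0]])
confidence = torch.sigmoid(logits)  # per-class confidence in (0, 1)
predictions = confidence > 0.5      # independent multi-label decisions
print(predictions)
# tensor([[ True, False,  True],
#         [False,  True, False]])
```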
@@ -42,7 +42,7 @@ class RecurrentModel(pl.LightningModule):
         self.macroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
         self.microK = CustomK(num_classes=output_size, average='micro', device=self.gpus)
         self.macroK = CustomK(num_classes=output_size, average='macro', device=self.gpus)
-        # Language specific metrics to compute metrics at epoch level
+        # Language specific metrics to compute at epoch level
         self.lang_macroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
         self.lang_microF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
         self.lang_macroK = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
@@ -474,7 +474,8 @@ class BertGen(ViewGen):
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
-                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs)
+                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs,
+                                        debug=True)

         if self.zero_shot:
             print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')