TODO: early stop is triggered by current_score == best_score !
parent 7b6938459f
commit 612e90a584
@@ -181,7 +181,8 @@ class BertDataModule(RecurrentDataModule):
     Pytorch Lightning Datamodule to be deployed with BertGen.
     https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
     """
-    def __init__(self, multilingualIndex, batchsize=64, max_len=512, zero_shot=False, zscl_langs=None, debug=False):
+    def __init__(self, multilingualIndex, batchsize=64, max_len=512, zero_shot=False, zscl_langs=None, debug=False,
+                 max_samples=50):
         """
         Init BertDataModule.
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
@@ -197,8 +198,9 @@ class BertDataModule(RecurrentDataModule):
         self.zero_shot = zero_shot
         self.train_langs = zscl_langs
         self.debug = debug
+        self.max_samples = max_samples
         if self.debug:
-            print('\n[Running on DEBUG mode - samples per language are reduced to 50 max!]\n')
+            print(f'\n[Running on DEBUG mode - samples per language are reduced to {self.max_samples} max!]\n')

     def setup(self, stage=None):
         if stage == 'fit' or stage is None:
@@ -208,8 +210,8 @@ class BertDataModule(RecurrentDataModule):
             l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
             if self.debug:
                 # Debug settings: reducing number of samples
-                l_train_raw = {l: train[:50] for l, train in l_train_raw.items()}
-                l_train_target = {l: target[:50] for l, target in l_train_target.items()}
+                l_train_raw = {l: train[:self.max_samples] for l, train in l_train_raw.items()}
+                l_train_target = {l: target[:self.max_samples] for l, target in l_train_target.items()}

             l_train_index = tokenize(l_train_raw, max_len=self.max_len)
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
@@ -221,8 +223,8 @@ class BertDataModule(RecurrentDataModule):
             l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
             if self.debug:
                 # Debug settings: reducing number of samples
-                l_val_raw = {l: train[:50] for l, train in l_val_raw.items()}
-                l_val_target = {l: target[:50] for l, target in l_val_target.items()}
+                l_val_raw = {l: train[:self.max_samples] for l, train in l_val_raw.items()}
+                l_val_target = {l: target[:self.max_samples] for l, target in l_val_target.items()}

             l_val_index = tokenize(l_val_raw, max_len=self.max_len)
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
@@ -235,8 +237,8 @@ class BertDataModule(RecurrentDataModule):
             l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
             if self.debug:
                 # Debug settings: reducing number of samples
-                l_test_raw = {l: train[:50] for l, train in l_test_raw.items()}
-                l_test_target = {l: target[:50] for l, target in l_test_target.items()}
+                l_test_raw = {l: train[:self.max_samples] for l, train in l_test_raw.items()}
+                l_test_target = {l: target[:self.max_samples] for l, target in l_test_target.items()}

             l_test_index = tokenize(l_test_raw, max_len=self.max_len)
             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
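Note: a minimal usage sketch of the new max_samples knob (variable names are hypothetical; the MultilingualIndex is assumed to be prepared as in BertGen.fit further below):

    # Hypothetical sketch: cap every language at 50 samples while debugging.
    data_module = BertDataModule(multilingual_index, batchsize=64, max_len=512,
                                 zero_shot=False, zscl_langs=None,
                                 debug=True, max_samples=50)
    data_module.setup('fit')   # builds the (truncated) training and validation datasets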
@@ -2,6 +2,8 @@ import pytorch_lightning as pl
 import torch
 from torch.optim.lr_scheduler import StepLR
 from transformers import BertForSequenceClassification, AdamW
+import numpy as np
+import csv

 from src.util.common import define_pad_length, pad
 from src.util.pl_metrics import CustomF1, CustomK
@@ -9,7 +11,7 @@ from src.util.pl_metrics import CustomF1, CustomK

 class BertModel(pl.LightningModule):

-    def __init__(self, output_size, stored_path, gpus=None):
+    def __init__(self, output_size, stored_path, gpus=None, manual_log=False):
         """
         Init Bert model.
         :param output_size:
@@ -39,6 +41,17 @@ class BertModel(pl.LightningModule):
                                                                   output_hidden_states=True)
         self.save_hyperparameters()

+        # Manual logging settings
+        self.manual_log = manual_log
+        if self.manual_log:
+            from src.util.file import create_if_not_exist
+            self.csv_file = f'csv_logs/bert/bert_manual_log_v{self._version}.csv'
+            with open(self.csv_file, 'x') as handler:
+                writer = csv.writer(handler, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
+                writer.writerow(['tr_loss', 'va_loss', 'va_macroF1', 'va_microF1', 'va_macroK', 'va_microK'])
+            self.csv_metrics = {'tr_loss': [], 'va_loss': [], 'va_macroF1': [],
+                                'va_microF1': [], 'va_macroK': [], 'va_microK': []}
+
     def forward(self, X):
         logits = self.bert(X)
         return logits
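Note: a hedged sketch of how the new manual_log flag would be used when constructing the model (the output_size and stored_path values are placeholders, not taken from this commit). When enabled, the constructor creates the CSV file with open(..., 'x'), which raises if a log for the same version already exists:

    # Hypothetical sketch: build the model with manual CSV logging enabled.
    model = BertModel(output_size=n_labels, stored_path=None, gpus=1, manual_log=True)
    # -> writes the header row to csv_logs/bert/bert_manual_log_v<version>.csv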
@@ -54,11 +67,11 @@ class BertModel(pl.LightningModule):
         macroF1 = self.macroF1(predictions, y)
         microK = self.microK(predictions, y)
         macroK = self.macroK(predictions, y)
-        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('train-macroF1', macroF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('train-microF1', microF1, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('train-macroK', macroK, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('train-microK', microK, on_step=True, on_epoch=True, prog_bar=False, logger=True)
+        self.log('train-loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('train-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('train-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('train-macroK', macroK, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('train-microK', microK, on_step=False, on_epoch=True, prog_bar=False, logger=True)
         lX, ly = self._reconstruct_dict(predictions, y, batch_langs)
         return {'loss': loss, 'pred': lX, 'target': ly}

@@ -96,6 +109,12 @@ class BertModel(pl.LightningModule):
             self.logger.experiment.add_scalars('train-langs-macroK', {f'{lang}': avg_macroK}, self.current_epoch)
             self.logger.experiment.add_scalars('train-langs-microK', {f'{lang}': avg_microK}, self.current_epoch)

+        if self.manual_log:
+            # Manual logging epoch loss
+            tr_epoch_loss = np.average([out['loss'].item() for out in outputs])
+            self.csv_metrics['tr_loss'].append(tr_epoch_loss)
+            self.save_manual_logs()
+
     def validation_step(self, val_batch, batch_idx):
         X, y, batch_langs = val_batch
         y = y.to('cuda' if self.gpus else 'cpu')
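Note: a small sketch of what the manual epoch loss computes, under the assumption (consistent with training_step above) that each element of outputs is the dict returned by training_step for one batch:

    # Hypothetical sketch: `outputs` collects the per-batch dicts returned by training_step,
    # so the manual epoch loss is simply the mean of the per-batch losses.
    outputs = [{'loss': torch.tensor(0.7)}, {'loss': torch.tensor(0.5)}]
    tr_epoch_loss = np.average([out['loss'].item() for out in outputs])   # ~0.6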
@@ -106,11 +125,20 @@ class BertModel(pl.LightningModule):
         macroF1 = self.macroF1(predictions, y)
         microK = self.microK(predictions, y)
         macroK = self.macroK(predictions, y)
-        self.log('val-loss', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('val-loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.log('val-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.log('val-microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.log('val-macroK', macroK, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.log('val-microK', microK, on_step=False, on_epoch=True, prog_bar=True, logger=True)

+        if self.manual_log:
+            # Manual logging to csv
+            self.csv_metrics['va_loss'].append(loss.item())
+            self.csv_metrics['va_macroF1'].append(macroF1.item())
+            self.csv_metrics['va_microF1'].append(microF1.item())
+            self.csv_metrics['va_macroK'].append(macroK.item())
+            self.csv_metrics['va_microK'].append(microK.item())
+
         return {'loss': loss}

     def test_step(self, test_batch, batch_idx):
@@ -126,8 +154,8 @@ class BertModel(pl.LightningModule):
         macroK = self.macroK(predictions, y)
         self.log('test-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
         self.log('test-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
-        self.log('test-macroK', macroK, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('test-microK', microK, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('test-macroK', macroK, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('test-microK', microK, on_step=False, on_epoch=True, prog_bar=False, logger=True)
         return

     def configure_optimizers(self, lr=1e-5, weight_decay=0.01):
@@ -141,7 +169,7 @@ class BertModel(pl.LightningModule):
              'weight_decay': weight_decay}
         ]
         optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
-        scheduler = {'scheduler': StepLR(optimizer, step_size=25, gamma=1.0),  # TODO set to 1.0 to debug (prev. 0.1)
+        scheduler = {'scheduler': StepLR(optimizer, step_size=25, gamma=0.1),
                      'interval': 'epoch'}
         return [optimizer], [scheduler]

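Note: restoring gamma=0.1 re-enables the learning-rate decay that the debug value gamma=1.0 had disabled. A minimal sketch of the resulting schedule, assuming the default lr=1e-5 from configure_optimizers:

    # StepLR(step_size=25, gamma=0.1) multiplies the lr by 0.1 every 25 epochs:
    #   epochs  0-24 : 1e-5
    #   epochs 25-49 : 1e-6
    #   epochs 50-74 : 1e-7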
@@ -181,3 +209,14 @@ class BertModel(pl.LightningModule):
         for k, v in reconstructed_y.items():
             reconstructed_y[k] = torch.cat(v).view(-1, predictions.shape[1])
         return reconstructed_x, reconstructed_y
+
+    def save_manual_logs(self):
+        if self.global_step == 0:
+            return
+        with open(self.csv_file, 'a', newline='\n') as handler:
+            writer = csv.writer(handler, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
+            writer.writerow([np.average(self.csv_metrics['tr_loss']), np.average(self.csv_metrics['va_loss']),
+                             np.average(self.csv_metrics['va_macroF1']), np.average(self.csv_metrics['va_microF1']),
+                             np.average(self.csv_metrics['va_macroK']), np.average(self.csv_metrics['va_microK'])])
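Note: a minimal sketch of reading the manual log back for inspection (path and version number are hypothetical; the dialect mirrors the writer above):

    # Hypothetical sketch: load the tab-separated log written by save_manual_logs().
    import csv
    with open('csv_logs/bert/bert_manual_log_v0.csv', newline='') as handler:
        reader = csv.reader(handler, delimiter='\t', quotechar='|')
        header = next(reader)                                  # ['tr_loss', 'va_loss', ...]
        history = [[float(v) for v in row] for row in reader]  # one row per call to save_manual_logs()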
@@ -450,7 +450,7 @@ class BertGen(ViewGen):
         self.patience = patience
         self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
-                                                 patience=self.patience, verbose=False, mode='max')
+                                                 patience=self.patience, verbose=True, mode='max')

         # Zero shot parameters
         self.zero_shot = zero_shot
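Note: the commit title records a caveat about this callback: with min_delta=0.00 and mode='max', an epoch whose val-macroF1 merely equals the current best does not count as an improvement, so the patience counter keeps advancing on plateaus; verbose=True makes that counter visible. A hedged sketch of the same configuration in isolation (the patience value is illustrative):

    from pytorch_lightning.callbacks import EarlyStopping

    # Stops training once val-macroF1 has failed to improve by more than min_delta
    # for `patience` consecutive validation epochs.
    early_stop = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
                               patience=5, verbose=True, mode='max')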
@@ -475,14 +475,17 @@ class BertGen(ViewGen):
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
-                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs)
+                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs,
+                                        debug=True, max_samples=50)  # todo: debug=True -> DEBUG setting

         if self.zero_shot:
             print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')

-        trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
-                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False,
-                          overfit_batches=0.01)  # todo: overfit_batches -> DEBUG setting
+        trainer = Trainer(max_epochs=self.nepochs, gpus=self.gpus,
+                          logger=self.logger,
+                          callbacks=[self.early_stop_callback],
+                          checkpoint_callback=False)

         trainer.fit(self.model, datamodule=bertDataModule)
         trainer.test(self.model, datamodule=bertDataModule)
         return self