fixed train validation splitting

This commit is contained in:
andrea 2020-07-27 17:15:06 +02:00
parent c793bd30b5
commit ba913f770a
2 changed files with 41 additions and 26 deletions

View File

@ -10,6 +10,7 @@ from util.evaluation import evaluate
from util.early_stop import EarlyStopping
from torch.optim.lr_scheduler import StepLR
from sklearn.model_selection import train_test_split
from copy import deepcopy
import argparse
@ -18,9 +19,11 @@ def get_model(n_out):
model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=n_out)
return model
def set_method_name():
return 'mBERT'
def init_optimizer(model, lr):
# return AdamW(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
no_decay = ['bias', 'LayerNorm.weight']
@ -35,18 +38,22 @@ def init_optimizer(model, lr):
optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
return optimizer
def init_logfile(method_name, opt):
logfile = CSVLog(opt.log_file, ['dataset', 'method', 'epoch', 'measure', 'value', 'run', 'timelapse'])
logfile.set_default('dataset', opt.dataset)
logfile.set_default('run', opt.seed)
logfile.set_default('method', method_name)
assert opt.force or not logfile.already_calculated(), f'results for dataset {opt.dataset} method {method_name} and run {opt.seed} already calculated'
assert opt.force or not logfile.already_calculated(), f'results for dataset {opt.dataset} method {method_name} ' \
f'and run {opt.seed} already calculated'
return logfile
def get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group['lr']
def get_dataset_name(datapath):
possible_splits = [str(i) for i in range(10)]
splitted = datapath.split('_')
@ -55,9 +62,10 @@ def get_dataset_name(datapath):
dataset_name = splitted[0].split('/')[-1]
return f'{dataset_name}_run{id_split}'
def load_datasets(datapath):
data = MultilingualDataset.load(datapath)
data.set_view(languages=['nl']) # Testing with just two langs
# data.set_view(languages=['nl', 'fr']) # Testing with less langs
data.show_dimensions()
l_devel_raw, l_devel_target = data.training(target_as_csr=False)
@ -75,7 +83,7 @@ def do_tokenization(l_dataset, max_len=512):
l_tokenized[lang] = tokenizer(l_dataset[lang],
truncation=True,
max_length=max_len,
add_special_tokens=True,
# add_special_tokens=True,
padding='max_length')
return l_tokenized
@ -105,7 +113,6 @@ class TrainingDataset(Dataset):
self.labels = np.vstack((self.labels, _labels))
self.lang_index = np.concatenate((self.lang_index, _lang_value))
def __len__(self):
return len(self.data)
@ -115,19 +122,20 @@ class TrainingDataset(Dataset):
lang = self.lang_index[idx]
return x, torch.tensor(y, dtype=torch.float), lang
# return x, y, lang
def get_lang_ids(self):
return self.lang_ids
def freeze_encoder(model):
for param in model.base_model.parameters():
param.requires_grad = False
return model
def check_param_grad_status(model):
print('#'*50)
print('Model paramater status')
print('Model paramater status:')
for name, child in model.named_children():
trainable = False
for param in child.parameters():
@ -139,9 +147,9 @@ def check_param_grad_status(model):
print(f'{name} is not frozen')
print('#'*50)
def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile):
_dataset_path = opt.dataset.split('/')[-1].split('_')
# dataset_id = 'RCV1/2_run0_newBert'
dataset_id = _dataset_path[0] + _dataset_path[-1]
loss_history = []
@ -159,12 +167,13 @@ def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit,
if idx % opt.log_interval == 0:
interval_loss = np.mean(loss_history[-opt.log_interval:])
print(
f'{dataset_id} {method_name} Epoch: {epoch}, Step: {idx}, lr={get_lr(optim):.5f}, Training Loss: {interval_loss:.6f}')
f'{dataset_id} {method_name} Epoch: {epoch}, Step: {idx}, lr={get_lr(optim):.6f}, Training Loss: {interval_loss:.6f}')
mean_loss = np.mean(interval_loss)
logfile.add_row(epoch=epoch, measure='tr_loss', value=mean_loss, timelapse=time() - tinit)
return mean_loss
def test(model, test_dataloader, lang_ids, tinit, epoch, logfile, criterion, measure_prefix):
print('# Validating model ...')
loss_history = []
@ -181,7 +190,7 @@ def test(model, test_dataloader, lang_ids, tinit, epoch, logfile, criterion, mea
prediction = predict(logits)
loss_history.append(loss)
# Assigning prediction to dict in predictionS and yte_stacked according to lang_idx
# Assigning prediction to dict in predictions and yte_stacked according to lang_idx
for i, pred in enumerate(prediction):
lang_pred = id_2_lang[lang_idx.numpy()[i]]
predictions[lang_pred].append(pred)
@ -208,19 +217,20 @@ def test(model, test_dataloader, lang_ids, tinit, epoch, logfile, criterion, mea
return Mf1
def get_tr_val_split(l_tokenized_tr, l_devel_target, val_prop, max_val, seed):
l_split_va = l_tokenized_tr
l_split_va = deepcopy(l_tokenized_tr)
l_split_val_target = {l: [] for l in l_tokenized_tr.keys()}
l_split_tr = l_tokenized_tr
l_split_tr = deepcopy(l_tokenized_tr)
l_split_tr_target = {l: [] for l in l_tokenized_tr.keys()}
for lang in l_tokenized_tr.keys():
val_size = int(min(len(l_tokenized_tr[lang]['input_ids']) * val_prop, max_val))
l_split_tr[lang]['input_ids'], l_split_va[lang]['input_ids'], l_split_tr_target[lang], l_split_val_target[lang] = \
train_test_split(l_tokenized_tr[lang]['input_ids'], l_devel_target[lang], test_size=val_size, random_state=seed, shuffle=True)
return l_split_tr, l_split_tr_target, l_split_va, l_split_val_target
return l_split_tr, l_split_tr_target, l_split_va, l_split_val_target
def main():
print('Running main ...')
@ -232,7 +242,9 @@ def main():
l_devel_raw, l_devel_target, l_test_raw, l_test_target = load_datasets(DATAPATH)
l_tokenized_tr = do_tokenization(l_devel_raw, max_len=512)
l_split_tr, l_split_tr_target, l_split_va, l_split_val_target = get_tr_val_split(l_tokenized_tr, l_devel_target, val_prop=0.2, max_val=2000, seed=opt.seed)
l_split_tr, l_split_tr_target, l_split_va, l_split_val_target = get_tr_val_split(l_tokenized_tr, l_devel_target,
val_prop=0.2, max_val=2000,
seed=opt.seed)
l_tokenized_te = do_tokenization(l_test_raw, max_len=512)
@ -264,10 +276,10 @@ def main():
lang_ids = va_dataset.lang_ids
for epoch in range(1, opt.nepochs+1):
print('# Start Training ...')
train(model, tr_dataloader, epoch, criterion, optim, 'TestingBert', tinit, logfile)
train(model, tr_dataloader, epoch, criterion, optim, method_name, tinit, logfile)
# lr_scheduler.step(epoch=None) # reduces the learning rate
# validation
# Validation
macrof1 = test(model, va_dataloader, lang_ids, tinit, epoch, logfile, criterion, 'va')
early_stop(macrof1, epoch)
if opt.test_each>0:
@ -279,7 +291,7 @@ def main():
if not opt.plotmode:
break
if opt.plotmode==False:
if not opt.plotmode:
print('-' * 80)
print('Training over. Performing final evaluation')
@ -288,7 +300,7 @@ def main():
if opt.val_epochs>0:
print(f'running last {opt.val_epochs} training epochs on the validation set')
for val_epoch in range(1, opt.val_epochs + 1):
train(model, va_dataloader, epoch+val_epoch, criterion, optim, 'TestingBert', tinit, logfile)
train(model, va_dataloader, epoch+val_epoch, criterion, optim, method_name, tinit, logfile)
# final test
print('Training complete: testing')
@ -334,4 +346,4 @@ if __name__ == '__main__':
opt.patience = 5
main()
# TODO: refactor .cuda() -> .to(device) in order to check if the process is faster on CPU given the bigger batch size
# TODO: refactor .cuda() -> .to(device) in order to check if the process is faster on CPU given the bigger batch size

View File

@ -2,7 +2,7 @@
import torch
from time import time
from util.file import create_if_not_exist
import warnings
class EarlyStopping:
@ -19,9 +19,10 @@ class EarlyStopping:
self.optimizer = optimizer
self.STOP = False
def __call__(self, watch_score, epoch): #model
def __call__(self, watch_score, epoch):
if self.STOP: return #done
if self.STOP:
return
if self.best_score is None or watch_score >= self.best_score:
self.best_score = watch_score
@ -29,10 +30,12 @@ class EarlyStopping:
self.stop_time = time()
if self.checkpoint:
self.print(f'[early-stop] improved, saving model in {self.checkpoint}')
torch.save(self.model, self.checkpoint)
# with open(self.checkpoint)
# torch.save({'state_dict': self.model.state_dict(),
# 'optimizer_state_dict': self.optimizer.state_dict()}, self.checkpoint)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
torch.save(self.model, self.checkpoint)
# with open(self.checkpoint)
# torch.save({'state_dict': self.model.state_dict(),
# 'optimizer_state_dict': self.optimizer.state_dict()}, self.checkpoint)
else:
self.print(f'[early-stop] improved')
self.patience = self.patience_limit