Fixed train/validation splitting
This commit is contained in:
parent
c793bd30b5
commit
ba913f770a
|
|
@ -10,6 +10,7 @@ from util.evaluation import evaluate
|
||||||
from util.early_stop import EarlyStopping
|
from util.early_stop import EarlyStopping
|
||||||
from torch.optim.lr_scheduler import StepLR
|
from torch.optim.lr_scheduler import StepLR
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
from copy import deepcopy
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,9 +19,11 @@ def get_model(n_out):
|
||||||
model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=n_out)
|
model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=n_out)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def set_method_name():
    """Return the label that identifies this method ('mBERT')."""
    method_name = 'mBERT'
    return method_name
|
||||||
|
|
||||||
|
|
||||||
def init_optimizer(model, lr):
|
def init_optimizer(model, lr):
|
||||||
# return AdamW(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
|
# return AdamW(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ['bias', 'LayerNorm.weight']
|
||||||
|
|
@ -35,18 +38,22 @@ def init_optimizer(model, lr):
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
def init_logfile(method_name, opt):
    """Create the CSV results log and refuse to redo an already logged run.

    Args:
        method_name: value stored as the default for the 'method' column.
        opt: options object; ``log_file``, ``dataset``, ``seed`` and ``force``
            are read from it.

    Returns:
        The ``CSVLog`` instance with the 'dataset', 'run' and 'method'
        defaults set.

    Raises:
        AssertionError: if ``opt.force`` is falsy and
            ``logfile.already_calculated()`` is truthy.
    """
    columns = ['dataset', 'method', 'epoch', 'measure', 'value', 'run', 'timelapse']
    logfile = CSVLog(opt.log_file, columns)
    logfile.set_default('dataset', opt.dataset)
    logfile.set_default('run', opt.seed)
    logfile.set_default('method', method_name)

    # Raised explicitly rather than via `assert`: assert statements are removed
    # under `python -O`, which would silently disable this guard. The exception
    # type and message are kept so existing callers observe the same failure.
    if not opt.force and logfile.already_calculated():
        raise AssertionError(
            f'results for dataset {opt.dataset} method {method_name} '
            f'and run {opt.seed} already calculated'
        )
    return logfile
|
||||||
|
|
||||||
|
|
||||||
def get_lr(optimizer):
    """Return the 'lr' entry of the optimizer's first parameter group.

    Returns None when ``optimizer.param_groups`` is empty.
    """
    rates = (group['lr'] for group in optimizer.param_groups)
    return next(rates, None)
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_name(datapath):
|
def get_dataset_name(datapath):
|
||||||
possible_splits = [str(i) for i in range(10)]
|
possible_splits = [str(i) for i in range(10)]
|
||||||
splitted = datapath.split('_')
|
splitted = datapath.split('_')
|
||||||
|
|
@ -55,9 +62,10 @@ def get_dataset_name(datapath):
|
||||||
dataset_name = splitted[0].split('/')[-1]
|
dataset_name = splitted[0].split('/')[-1]
|
||||||
return f'{dataset_name}_run{id_split}'
|
return f'{dataset_name}_run{id_split}'
|
||||||
|
|
||||||
|
|
||||||
def load_datasets(datapath):
|
def load_datasets(datapath):
|
||||||
data = MultilingualDataset.load(datapath)
|
data = MultilingualDataset.load(datapath)
|
||||||
data.set_view(languages=['nl']) # Testing with just two langs
|
# data.set_view(languages=['nl', 'fr']) # Testing with less langs
|
||||||
data.show_dimensions()
|
data.show_dimensions()
|
||||||
|
|
||||||
l_devel_raw, l_devel_target = data.training(target_as_csr=False)
|
l_devel_raw, l_devel_target = data.training(target_as_csr=False)
|
||||||
|
|
@ -75,7 +83,7 @@ def do_tokenization(l_dataset, max_len=512):
|
||||||
l_tokenized[lang] = tokenizer(l_dataset[lang],
|
l_tokenized[lang] = tokenizer(l_dataset[lang],
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_len,
|
max_length=max_len,
|
||||||
add_special_tokens=True,
|
# add_special_tokens=True,
|
||||||
padding='max_length')
|
padding='max_length')
|
||||||
return l_tokenized
|
return l_tokenized
|
||||||
|
|
||||||
|
|
@ -105,7 +113,6 @@ class TrainingDataset(Dataset):
|
||||||
self.labels = np.vstack((self.labels, _labels))
|
self.labels = np.vstack((self.labels, _labels))
|
||||||
self.lang_index = np.concatenate((self.lang_index, _lang_value))
|
self.lang_index = np.concatenate((self.lang_index, _lang_value))
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
|
|
@ -115,19 +122,20 @@ class TrainingDataset(Dataset):
|
||||||
lang = self.lang_index[idx]
|
lang = self.lang_index[idx]
|
||||||
|
|
||||||
return x, torch.tensor(y, dtype=torch.float), lang
|
return x, torch.tensor(y, dtype=torch.float), lang
|
||||||
# return x, y, lang
|
|
||||||
|
|
||||||
def get_lang_ids(self):
|
def get_lang_ids(self):
|
||||||
return self.lang_ids
|
return self.lang_ids
|
||||||
|
|
||||||
|
|
||||||
def freeze_encoder(model):
    """Turn off gradient tracking for every parameter of ``model.base_model``.

    The model is modified in place; the same object is returned so the call
    can be chained.
    """
    for weight in model.base_model.parameters():
        weight.requires_grad_(False)
    return model
|
||||||
|
|
||||||
|
|
||||||
def check_param_grad_status(model):
|
def check_param_grad_status(model):
|
||||||
print('#'*50)
|
print('#'*50)
|
||||||
print('Model paramater status')
|
print('Model paramater status:')
|
||||||
for name, child in model.named_children():
|
for name, child in model.named_children():
|
||||||
trainable = False
|
trainable = False
|
||||||
for param in child.parameters():
|
for param in child.parameters():
|
||||||
|
|
@ -139,9 +147,9 @@ def check_param_grad_status(model):
|
||||||
print(f'{name} is not frozen')
|
print(f'{name} is not frozen')
|
||||||
print('#'*50)
|
print('#'*50)
|
||||||
|
|
||||||
|
|
||||||
def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile):
|
def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile):
|
||||||
_dataset_path = opt.dataset.split('/')[-1].split('_')
|
_dataset_path = opt.dataset.split('/')[-1].split('_')
|
||||||
# dataset_id = 'RCV1/2_run0_newBert'
|
|
||||||
dataset_id = _dataset_path[0] + _dataset_path[-1]
|
dataset_id = _dataset_path[0] + _dataset_path[-1]
|
||||||
|
|
||||||
loss_history = []
|
loss_history = []
|
||||||
|
|
@ -159,12 +167,13 @@ def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit,
|
||||||
if idx % opt.log_interval == 0:
|
if idx % opt.log_interval == 0:
|
||||||
interval_loss = np.mean(loss_history[-opt.log_interval:])
|
interval_loss = np.mean(loss_history[-opt.log_interval:])
|
||||||
print(
|
print(
|
||||||
f'{dataset_id} {method_name} Epoch: {epoch}, Step: {idx}, lr={get_lr(optim):.5f}, Training Loss: {interval_loss:.6f}')
|
f'{dataset_id} {method_name} Epoch: {epoch}, Step: {idx}, lr={get_lr(optim):.6f}, Training Loss: {interval_loss:.6f}')
|
||||||
|
|
||||||
mean_loss = np.mean(interval_loss)
|
mean_loss = np.mean(interval_loss)
|
||||||
logfile.add_row(epoch=epoch, measure='tr_loss', value=mean_loss, timelapse=time() - tinit)
|
logfile.add_row(epoch=epoch, measure='tr_loss', value=mean_loss, timelapse=time() - tinit)
|
||||||
return mean_loss
|
return mean_loss
|
||||||
|
|
||||||
|
|
||||||
def test(model, test_dataloader, lang_ids, tinit, epoch, logfile, criterion, measure_prefix):
|
def test(model, test_dataloader, lang_ids, tinit, epoch, logfile, criterion, measure_prefix):
|
||||||
print('# Validating model ...')
|
print('# Validating model ...')
|
||||||
loss_history = []
|
loss_history = []
|
||||||
|
|
@ -181,7 +190,7 @@ def test(model, test_dataloader, lang_ids, tinit, epoch, logfile, criterion, mea
|
||||||
prediction = predict(logits)
|
prediction = predict(logits)
|
||||||
loss_history.append(loss)
|
loss_history.append(loss)
|
||||||
|
|
||||||
# Assigning prediction to dict in predictionS and yte_stacked according to lang_idx
|
# Assigning prediction to dict in predictions and yte_stacked according to lang_idx
|
||||||
for i, pred in enumerate(prediction):
|
for i, pred in enumerate(prediction):
|
||||||
lang_pred = id_2_lang[lang_idx.numpy()[i]]
|
lang_pred = id_2_lang[lang_idx.numpy()[i]]
|
||||||
predictions[lang_pred].append(pred)
|
predictions[lang_pred].append(pred)
|
||||||
|
|
@ -208,19 +217,20 @@ def test(model, test_dataloader, lang_ids, tinit, epoch, logfile, criterion, mea
|
||||||
|
|
||||||
return Mf1
|
return Mf1
|
||||||
|
|
||||||
|
|
||||||
def get_tr_val_split(l_tokenized_tr, l_devel_target, val_prop, max_val, seed):
    """Split each language's tokenized training data into train and validation parts.

    For every language the validation size is
    ``int(min(n_docs * val_prop, max_val))`` and the documents are shuffled by
    ``train_test_split`` with ``random_state=seed``.

    Args:
        l_tokenized_tr: mapping ``{lang: encoding}``. Each encoding must support
            ``['input_ids']`` and ``.items()``
            (presumably the tokenizer output built by ``do_tokenization`` --
            TODO confirm).
        l_devel_target: mapping ``{lang: targets}``, one target per document.
        val_prop: fraction of the documents to hold out for validation.
        max_val: upper bound on the number of validation documents per language.
        seed: random state forwarded to ``train_test_split``.

    Returns:
        Tuple ``(l_split_tr, l_split_tr_target, l_split_va, l_split_val_target)``.
        The two encoding mappings are independent copies of the input;
        ``l_tokenized_tr`` itself is left untouched.
    """
    # Deep copies keep the container types of the input and guarantee that the
    # train split, the validation split and the caller's data never alias.
    l_split_tr = deepcopy(l_tokenized_tr)
    l_split_va = deepcopy(l_tokenized_tr)
    l_split_tr_target = {}
    l_split_val_target = {}

    for lang, encoding in l_tokenized_tr.items():
        n_docs = len(encoding['input_ids'])
        val_size = int(min(n_docs * val_prop, max_val))

        # Split document positions instead of 'input_ids' directly, so that the
        # very same partition can be applied to every per-document field.
        tr_idx, va_idx, l_split_tr_target[lang], l_split_val_target[lang] = train_test_split(
            list(range(n_docs)), l_devel_target[lang],
            test_size=val_size, random_state=seed, shuffle=True)

        for field, values in encoding.items():
            # Only fields holding one entry per document are re-indexed; this
            # keeps e.g. the attention mask aligned with the shuffled input_ids.
            # NOTE(review): assumes fields are plain sequences (no return_tensors).
            if len(values) != n_docs:
                continue
            l_split_tr[lang][field] = [values[i] for i in tr_idx]
            l_split_va[lang][field] = [values[i] for i in va_idx]

    return l_split_tr, l_split_tr_target, l_split_va, l_split_val_target
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print('Running main ...')
|
print('Running main ...')
|
||||||
|
|
@ -232,7 +242,9 @@ def main():
|
||||||
l_devel_raw, l_devel_target, l_test_raw, l_test_target = load_datasets(DATAPATH)
|
l_devel_raw, l_devel_target, l_test_raw, l_test_target = load_datasets(DATAPATH)
|
||||||
l_tokenized_tr = do_tokenization(l_devel_raw, max_len=512)
|
l_tokenized_tr = do_tokenization(l_devel_raw, max_len=512)
|
||||||
|
|
||||||
l_split_tr, l_split_tr_target, l_split_va, l_split_val_target = get_tr_val_split(l_tokenized_tr, l_devel_target, val_prop=0.2, max_val=2000, seed=opt.seed)
|
l_split_tr, l_split_tr_target, l_split_va, l_split_val_target = get_tr_val_split(l_tokenized_tr, l_devel_target,
|
||||||
|
val_prop=0.2, max_val=2000,
|
||||||
|
seed=opt.seed)
|
||||||
|
|
||||||
l_tokenized_te = do_tokenization(l_test_raw, max_len=512)
|
l_tokenized_te = do_tokenization(l_test_raw, max_len=512)
|
||||||
|
|
||||||
|
|
@ -264,10 +276,10 @@ def main():
|
||||||
lang_ids = va_dataset.lang_ids
|
lang_ids = va_dataset.lang_ids
|
||||||
for epoch in range(1, opt.nepochs+1):
|
for epoch in range(1, opt.nepochs+1):
|
||||||
print('# Start Training ...')
|
print('# Start Training ...')
|
||||||
train(model, tr_dataloader, epoch, criterion, optim, 'TestingBert', tinit, logfile)
|
train(model, tr_dataloader, epoch, criterion, optim, method_name, tinit, logfile)
|
||||||
# lr_scheduler.step(epoch=None) # reduces the learning rate
|
# lr_scheduler.step(epoch=None) # reduces the learning rate
|
||||||
|
|
||||||
# validation
|
# Validation
|
||||||
macrof1 = test(model, va_dataloader, lang_ids, tinit, epoch, logfile, criterion, 'va')
|
macrof1 = test(model, va_dataloader, lang_ids, tinit, epoch, logfile, criterion, 'va')
|
||||||
early_stop(macrof1, epoch)
|
early_stop(macrof1, epoch)
|
||||||
if opt.test_each>0:
|
if opt.test_each>0:
|
||||||
|
|
@ -279,7 +291,7 @@ def main():
|
||||||
if not opt.plotmode:
|
if not opt.plotmode:
|
||||||
break
|
break
|
||||||
|
|
||||||
if opt.plotmode==False:
|
if not opt.plotmode:
|
||||||
print('-' * 80)
|
print('-' * 80)
|
||||||
print('Training over. Performing final evaluation')
|
print('Training over. Performing final evaluation')
|
||||||
|
|
||||||
|
|
@ -288,7 +300,7 @@ def main():
|
||||||
if opt.val_epochs>0:
|
if opt.val_epochs>0:
|
||||||
print(f'running last {opt.val_epochs} training epochs on the validation set')
|
print(f'running last {opt.val_epochs} training epochs on the validation set')
|
||||||
for val_epoch in range(1, opt.val_epochs + 1):
|
for val_epoch in range(1, opt.val_epochs + 1):
|
||||||
train(model, va_dataloader, epoch+val_epoch, criterion, optim, 'TestingBert', tinit, logfile)
|
train(model, va_dataloader, epoch+val_epoch, criterion, optim, method_name, tinit, logfile)
|
||||||
|
|
||||||
# final test
|
# final test
|
||||||
print('Training complete: testing')
|
print('Training complete: testing')
|
||||||
|
|
@ -334,4 +346,4 @@ if __name__ == '__main__':
|
||||||
opt.patience = 5
|
opt.patience = 5
|
||||||
|
|
||||||
main()
|
main()
|
||||||
# TODO: refactor .cuda() -> .to(device) in order to check if the process is faster on CPU given the bigger batch size
|
# TODO: refactor .cuda() -> .to(device) in order to check if the process is faster on CPU given the bigger batch size
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
import torch
|
import torch
|
||||||
from time import time
|
from time import time
|
||||||
from util.file import create_if_not_exist
|
from util.file import create_if_not_exist
|
||||||
|
import warnings
|
||||||
|
|
||||||
class EarlyStopping:
|
class EarlyStopping:
|
||||||
|
|
||||||
|
|
@ -19,9 +19,10 @@ class EarlyStopping:
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self.STOP = False
|
self.STOP = False
|
||||||
|
|
||||||
def __call__(self, watch_score, epoch): #model
|
def __call__(self, watch_score, epoch):
|
||||||
|
|
||||||
if self.STOP: return #done
|
if self.STOP:
|
||||||
|
return
|
||||||
|
|
||||||
if self.best_score is None or watch_score >= self.best_score:
|
if self.best_score is None or watch_score >= self.best_score:
|
||||||
self.best_score = watch_score
|
self.best_score = watch_score
|
||||||
|
|
@ -29,10 +30,12 @@ class EarlyStopping:
|
||||||
self.stop_time = time()
|
self.stop_time = time()
|
||||||
if self.checkpoint:
|
if self.checkpoint:
|
||||||
self.print(f'[early-stop] improved, saving model in {self.checkpoint}')
|
self.print(f'[early-stop] improved, saving model in {self.checkpoint}')
|
||||||
torch.save(self.model, self.checkpoint)
|
with warnings.catch_warnings():
|
||||||
# with open(self.checkpoint)
|
warnings.simplefilter("ignore")
|
||||||
# torch.save({'state_dict': self.model.state_dict(),
|
torch.save(self.model, self.checkpoint)
|
||||||
# 'optimizer_state_dict': self.optimizer.state_dict()}, self.checkpoint)
|
# with open(self.checkpoint)
|
||||||
|
# torch.save({'state_dict': self.model.state_dict(),
|
||||||
|
# 'optimizer_state_dict': self.optimizer.state_dict()}, self.checkpoint)
|
||||||
else:
|
else:
|
||||||
self.print(f'[early-stop] improved')
|
self.print(f'[early-stop] improved')
|
||||||
self.patience = self.patience_limit
|
self.patience = self.patience_limit
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue