fixed train validation splitting

parent c793bd30b5
commit ba913f770a
@@ -10,6 +10,7 @@ from util.evaluation import evaluate
 from util.early_stop import EarlyStopping
 from torch.optim.lr_scheduler import StepLR
 from sklearn.model_selection import train_test_split
+from copy import deepcopy
 import argparse


@@ -18,9 +19,11 @@ def get_model(n_out):
     model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=n_out)
     return model


 def set_method_name():
     return 'mBERT'


 def init_optimizer(model, lr):
     # return AdamW(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
     no_decay = ['bias', 'LayerNorm.weight']
@@ -35,18 +38,22 @@ def init_optimizer(model, lr):
     optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
     return optimizer


 def init_logfile(method_name, opt):
     logfile = CSVLog(opt.log_file, ['dataset', 'method', 'epoch', 'measure', 'value', 'run', 'timelapse'])
     logfile.set_default('dataset', opt.dataset)
     logfile.set_default('run', opt.seed)
     logfile.set_default('method', method_name)
-    assert opt.force or not logfile.already_calculated(), f'results for dataset {opt.dataset} method {method_name} and run {opt.seed} already calculated'
+    assert opt.force or not logfile.already_calculated(), f'results for dataset {opt.dataset} method {method_name} ' \
+                                                          f'and run {opt.seed} already calculated'
     return logfile


 def get_lr(optimizer):
     for param_group in optimizer.param_groups:
         return param_group['lr']


 def get_dataset_name(datapath):
     possible_splits = [str(i) for i in range(10)]
     splitted = datapath.split('_')
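Context note: optimizer_grouped_parameters is built between these hunks and never appears in the diff. A minimal sketch of the standard no-decay grouping it presumably follows (an assumption based on the common Hugging Face recipe, not this repo's exact code):

    # Sketch (assumption): weight decay for all weights except biases and
    # LayerNorm parameters; this diff never shows the repo's actual grouping.
    from transformers import AdamW  # or torch.optim.AdamW in newer versions

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)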
@@ -55,9 +62,10 @@ def get_dataset_name(datapath):
     dataset_name = splitted[0].split('/')[-1]
     return f'{dataset_name}_run{id_split}'


 def load_datasets(datapath):
     data = MultilingualDataset.load(datapath)
     data.set_view(languages=['nl'])  # Testing with just two langs
     # data.set_view(languages=['nl', 'fr'])  # Testing with less langs
     data.show_dimensions()

     l_devel_raw, l_devel_target = data.training(target_as_csr=False)
@@ -75,7 +83,7 @@ def do_tokenization(l_dataset, max_len=512):
         l_tokenized[lang] = tokenizer(l_dataset[lang],
                                       truncation=True,
                                       max_length=max_len,
-                                      add_special_tokens=True,
+                                      # add_special_tokens=True,
                                       padding='max_length')
     return l_tokenized

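For reference, a self-contained sketch of what do_tokenization does per language (the texts are made up; the tokenizer name is an assumption matching the bert-base-multilingual-cased model loaded in get_model):

    # Hypothetical usage: per-language dict of raw texts in, per-language
    # dict of encodings (input_ids, attention_mask, ...) out.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
    l_dataset = {'nl': ['eerste document', 'tweede document']}

    l_tokenized = {lang: tokenizer(texts,
                                   truncation=True,
                                   max_length=512,
                                   padding='max_length')
                   for lang, texts in l_dataset.items()}
    print(len(l_tokenized['nl']['input_ids'][0]))  # 512: every sequence padded to max_length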
@@ -105,7 +113,6 @@ class TrainingDataset(Dataset):
         self.labels = np.vstack((self.labels, _labels))
         self.lang_index = np.concatenate((self.lang_index, _lang_value))
-

     def __len__(self):
         return len(self.data)

@@ -115,19 +122,20 @@ class TrainingDataset(Dataset):
         lang = self.lang_index[idx]

         return x, torch.tensor(y, dtype=torch.float), lang
+        # return x, y, lang

     def get_lang_ids(self):
         return self.lang_ids


 def freeze_encoder(model):
     for param in model.base_model.parameters():
         param.requires_grad = False
     return model


 def check_param_grad_status(model):
     print('#'*50)
-    print('Model paramater status')
+    print('Model paramater status:')
     for name, child in model.named_children():
         trainable = False
         for param in child.parameters():
@@ -139,9 +147,9 @@ def check_param_grad_status(model):
             print(f'{name} is not frozen')
     print('#'*50)


 def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile):
     _dataset_path = opt.dataset.split('/')[-1].split('_')
     # dataset_id = 'RCV1/2_run0_newBert'
     dataset_id = _dataset_path[0] + _dataset_path[-1]

     loss_history = []
@@ -159,12 +167,13 @@ def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile):
         if idx % opt.log_interval == 0:
             interval_loss = np.mean(loss_history[-opt.log_interval:])
             print(
-                f'{dataset_id} {method_name} Epoch: {epoch}, Step: {idx}, lr={get_lr(optim):.5f}, Training Loss: {interval_loss:.6f}')
+                f'{dataset_id} {method_name} Epoch: {epoch}, Step: {idx}, lr={get_lr(optim):.6f}, Training Loss: {interval_loss:.6f}')

     mean_loss = np.mean(interval_loss)
     logfile.add_row(epoch=epoch, measure='tr_loss', value=mean_loss, timelapse=time() - tinit)
     return mean_loss


 def test(model, test_dataloader, lang_ids, tinit, epoch, logfile, criterion, measure_prefix):
     print('# Validating model ...')
     loss_history = []
@@ -181,7 +190,7 @@ def test(model, test_dataloader, lang_ids, tinit, epoch, logfile, criterion, measure_prefix):
             prediction = predict(logits)
             loss_history.append(loss)

-            # Assigning prediction to dict in predictionS and yte_stacked according to lang_idx
+            # Assigning prediction to dict in predictions and yte_stacked according to lang_idx
             for i, pred in enumerate(prediction):
                 lang_pred = id_2_lang[lang_idx.numpy()[i]]
                 predictions[lang_pred].append(pred)
@@ -208,19 +217,20 @@ def test(model, test_dataloader, lang_ids, tinit, epoch, logfile, criterion, measure_prefix):

     return Mf1


 def get_tr_val_split(l_tokenized_tr, l_devel_target, val_prop, max_val, seed):
-    l_split_va = l_tokenized_tr
+    l_split_va = deepcopy(l_tokenized_tr)
     l_split_val_target = {l: [] for l in l_tokenized_tr.keys()}
-    l_split_tr = l_tokenized_tr
+    l_split_tr = deepcopy(l_tokenized_tr)
     l_split_tr_target = {l: [] for l in l_tokenized_tr.keys()}

     for lang in l_tokenized_tr.keys():
         val_size = int(min(len(l_tokenized_tr[lang]['input_ids']) * val_prop, max_val))

         l_split_tr[lang]['input_ids'], l_split_va[lang]['input_ids'], l_split_tr_target[lang], l_split_val_target[lang] = \
             train_test_split(l_tokenized_tr[lang]['input_ids'], l_devel_target[lang], test_size=val_size, random_state=seed, shuffle=True)

     return l_split_tr, l_split_tr_target, l_split_va, l_split_val_target


 def main():
     print('Running main ...')
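The substance of the fix: l_split_tr and l_split_va previously aliased l_tokenized_tr itself, so each per-language train_test_split assignment mutated the shared dict, and the train and validation splits clobbered one another. A minimal illustration of the aliasing, with made-up data:

    from copy import deepcopy

    d = {'nl': {'input_ids': [1, 2, 3, 4]}}
    alias = d                         # same object: writing through alias mutates d
    alias['nl']['input_ids'] = [1, 2]
    print(d['nl']['input_ids'])       # [1, 2] -- original clobbered

    safe = deepcopy(d)                # independent copy: nested dicts duplicated too
    safe['nl']['input_ids'] = [3, 4]
    print(d['nl']['input_ids'])       # still [1, 2]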
@@ -232,7 +242,9 @@ def main():
     l_devel_raw, l_devel_target, l_test_raw, l_test_target = load_datasets(DATAPATH)
     l_tokenized_tr = do_tokenization(l_devel_raw, max_len=512)

-    l_split_tr, l_split_tr_target, l_split_va, l_split_val_target = get_tr_val_split(l_tokenized_tr, l_devel_target, val_prop=0.2, max_val=2000, seed=opt.seed)
+    l_split_tr, l_split_tr_target, l_split_va, l_split_val_target = get_tr_val_split(l_tokenized_tr, l_devel_target,
+                                                                                     val_prop=0.2, max_val=2000,
+                                                                                     seed=opt.seed)

     l_tokenized_te = do_tokenization(l_test_raw, max_len=512)

@@ -264,10 +276,10 @@ def main():
     lang_ids = va_dataset.lang_ids
     for epoch in range(1, opt.nepochs+1):
         print('# Start Training ...')
-        train(model, tr_dataloader, epoch, criterion, optim, 'TestingBert', tinit, logfile)
+        train(model, tr_dataloader, epoch, criterion, optim, method_name, tinit, logfile)
         # lr_scheduler.step(epoch=None) # reduces the learning rate

-        # validation
+        # Validation
         macrof1 = test(model, va_dataloader, lang_ids, tinit, epoch, logfile, criterion, 'va')
         early_stop(macrof1, epoch)
         if opt.test_each>0:
@@ -279,7 +291,7 @@ def main():
            if not opt.plotmode:
                break

-    if opt.plotmode==False:
+    if not opt.plotmode:
        print('-' * 80)
        print('Training over. Performing final evaluation')

@@ -288,7 +300,7 @@ def main():
     if opt.val_epochs>0:
         print(f'running last {opt.val_epochs} training epochs on the validation set')
         for val_epoch in range(1, opt.val_epochs + 1):
-            train(model, va_dataloader, epoch+val_epoch, criterion, optim, 'TestingBert', tinit, logfile)
+            train(model, va_dataloader, epoch+val_epoch, criterion, optim, method_name, tinit, logfile)

     # final test
     print('Training complete: testing')
@@ -334,4 +346,4 @@ if __name__ == '__main__':
     opt.patience = 5

     main()
     # TODO: refactor .cuda() -> .to(device) in order to check if the process is faster on CPU given the bigger batch size
--- a/util/early_stop.py
+++ b/util/early_stop.py
@@ -2,7 +2,7 @@
 import torch
 from time import time
 from util.file import create_if_not_exist
-
+import warnings

 class EarlyStopping:

@@ -19,9 +19,10 @@ class EarlyStopping:
         self.optimizer = optimizer
         self.STOP = False

-    def __call__(self, watch_score, epoch): #model
+    def __call__(self, watch_score, epoch):

-        if self.STOP: return #done
+        if self.STOP:
+            return

         if self.best_score is None or watch_score >= self.best_score:
             self.best_score = watch_score
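The patience-decrement branch sits outside the hunks shown. A sketch of the counter logic this class appears to implement, assuming the usual early-stopping pattern (not the repo's exact code):

    # Sketch (assumption): reset patience on improvement, burn it otherwise,
    # and raise STOP when it runs out; the training loop checks STOP.
    class PatienceSketch:
        def __init__(self, patience_limit=5):
            self.patience_limit = patience_limit
            self.patience = patience_limit
            self.best_score = None
            self.STOP = False

        def __call__(self, watch_score, epoch):
            if self.STOP:
                return
            if self.best_score is None or watch_score >= self.best_score:
                self.best_score = watch_score
                self.patience = self.patience_limit   # reset on improvement
            else:
                self.patience -= 1                    # one more epoch without improvement
                if self.patience <= 0:
                    self.STOP = True                  # signal the loop to stop training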
@@ -29,10 +30,12 @@ class EarlyStopping:
             self.stop_time = time()
             if self.checkpoint:
                 self.print(f'[early-stop] improved, saving model in {self.checkpoint}')
-                torch.save(self.model, self.checkpoint)
-                # with open(self.checkpoint)
-                # torch.save({'state_dict': self.model.state_dict(),
-                #             'optimizer_state_dict': self.optimizer.state_dict()}, self.checkpoint)
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    torch.save(self.model, self.checkpoint)
+                # with open(self.checkpoint)
+                # torch.save({'state_dict': self.model.state_dict(),
+                #             'optimizer_state_dict': self.optimizer.state_dict()}, self.checkpoint)
             else:
                 self.print(f'[early-stop] improved')
             self.patience = self.patience_limit
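The catch_warnings wrapper presumably mutes the warnings torch.save can emit when pickling a whole nn.Module (saving a state_dict is the form PyTorch recommends). The same pattern in isolation, with a stand-in model:

    import warnings
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in model for illustration

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")          # silence warnings inside this block only
        torch.save(model, '/tmp/checkpoint.pt')  # pickles the full module object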