running comparison

parent 612e90a584
commit 421d7660f6

main.py | 12
--- a/main.py
+++ b/main.py
@@ -8,7 +8,7 @@ from src.util.results_csv import CSVlog
 from src.view_generators import *
 
 import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
 
 def main(args):
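Note on the hunk above: CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA context is initialized, which is why it sits right after `import os` at the top of main.py, ahead of any torch import. A minimal standalone sketch of the same pattern (hypothetical script, not from this repo):

    import os

    # Pin the process to physical GPU 0 *before* torch initializes CUDA;
    # inside the process that device is then addressed as cuda:0.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    import torch

    if torch.cuda.is_available():
        # Only the single visible device is enumerated.
        print(torch.cuda.device_count())       # -> 1
        print(torch.cuda.get_device_name(0))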
@@ -25,6 +25,14 @@ def main(args):
     lX, ly = data.training()
     lXte, lyte = data.test()
 
+    # TODO: debug settings
+    print(f'\n[Running on DEBUG mode - samples per language are reduced to 50 max!]\n')
+    lX = {k: v[:50] for k, v in lX.items()}
+    ly = {k: v[:50] for k, v in ly.items()}
+    lXte = {k: v[:50] for k, v in lXte.items()}
+    lyte = {k: v[:50] for k, v in lyte.items()}
+
+
     # Init multilingualIndex - mandatory when deploying Neural View Generators...
     if args.gru_embedder or args.bert_embedder:
         multilingualIndex = MultilingualIndex()
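The added block caps every language at 50 samples for debug runs. lX/ly map a language code to its list of documents/labels, so each comprehension truncates the per-language lists independently. A toy sketch of the same idea (hypothetical helper name, toy data):

    def truncate_per_language(ldict, max_samples=50):
        """Keep at most `max_samples` items per language (debug runs only)."""
        return {lang: values[:max_samples] for lang, values in ldict.items()}

    # Toy example: 3 'documents' for English, 2 for Italian.
    lX = {'en': ['doc1', 'doc2', 'doc3'], 'it': ['doc_a', 'doc_b']}
    lX = truncate_per_language(lX, max_samples=2)
    print(lX)  # {'en': ['doc1', 'doc2'], 'it': ['doc_a', 'doc_b']}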
@@ -101,6 +109,8 @@ def main(args):
     time_tr = round(time.time() - time_init, 3)
     print(f'Training completed in {time_tr} seconds!')
 
+    exit('[Exiting DEBUG session without testing overall architecture!]')
+
     # Testing ----------------------------------------
     print('\n[Testing Generalized Funnelling]')
     time_te = time.time()
@@ -2,8 +2,6 @@ import pytorch_lightning as pl
 import torch
 from torch.optim.lr_scheduler import StepLR
 from transformers import BertForSequenceClassification, AdamW
-import numpy as np
-import csv
 
 from src.util.common import define_pad_length, pad
 from src.util.pl_metrics import CustomF1, CustomK
@@ -11,7 +9,7 @@ from src.util.pl_metrics import CustomF1, CustomK
 
 class BertModel(pl.LightningModule):
 
-    def __init__(self, output_size, stored_path, gpus=None, manual_log=False):
+    def __init__(self, output_size, stored_path, gpus=None):
         """
         Init Bert model.
         :param output_size:
@@ -41,17 +39,6 @@ class BertModel(pl.LightningModule):
                                                           output_hidden_states=True)
         self.save_hyperparameters()
 
-        # Manual logging settings
-        self.manual_log = manual_log
-        if self.manual_log:
-            from src.util.file import create_if_not_exist
-            self.csv_file = f'csv_logs/bert/bert_manual_log_v{self._version}.csv'
-            with open(self.csv_file, 'x') as handler:
-                writer = csv.writer(handler, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
-                writer.writerow(['tr_loss', 'va_loss', 'va_macroF1', 'va_microF1', 'va_macroK', 'va_microK'])
-            self.csv_metrics = {'tr_loss': [], 'va_loss': [], 'va_macroF1': [],
-                                'va_microF1': [], 'va_macroK': [], 'va_microK': []}
-
     def forward(self, X):
         logits = self.bert(X)
         return logits
@@ -72,10 +59,14 @@ class BertModel(pl.LightningModule):
         self.log('train-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-macroK', macroK, on_step=False, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-microK', microK, on_step=False, on_epoch=True, prog_bar=False, logger=True)
-        lX, ly = self._reconstruct_dict(predictions, y, batch_langs)
-        return {'loss': loss, 'pred': lX, 'target': ly}
+        return {'loss': loss}
+        # lX, ly = self._reconstruct_dict(predictions, y, batch_langs)
+        # return {'loss': loss, 'pred': lX, 'target': ly}
 
+    """
     def training_epoch_end(self, outputs):
+        pass
+
         langs = []
         for output in outputs:
             langs.extend(list(output['pred'].keys()))
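With the per-language dict reconstruction commented out, training_step now returns only the loss, and the whole training_epoch_end body is disabled by wrapping it in a triple-quoted string (closed in the next hunk). This is safe in Lightning: training_step only has to return the loss tensor (or a dict with a 'loss' key); extra keys such as 'pred'/'target' are only needed when an epoch-end hook consumes them. A minimal sketch of the remaining contract (the criterion attribute name is assumed, not from the source):

    def training_step(self, train_batch, batch_idx):
        X, y, batch_langs = train_batch      # same unpacking as in BertModel
        logits = self.forward(X)
        loss = self.loss_fn(logits, y)       # hypothetical criterion attribute
        # Returning just the loss (or {'loss': loss}) is all Lightning requires;
        # extra keys matter only if an *_epoch_end hook aggregates them.
        return {'loss': loss}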
@@ -114,6 +105,7 @@ class BertModel(pl.LightningModule):
         tr_epoch_loss = np.average([out['loss'].item() for out in outputs])
         self.csv_metrics['tr_loss'].append(tr_epoch_loss)
         self.save_manual_logs()
+    """
 
     def validation_step(self, val_batch, batch_idx):
         X, y, batch_langs = val_batch
@@ -130,16 +122,23 @@ class BertModel(pl.LightningModule):
         self.log('val-microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.log('val-macroK', macroK, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.log('val-microK', microK, on_step=False, on_epoch=True, prog_bar=True, logger=True)
 
-        if self.manual_log:
-            # Manual logging to csv
-            self.csv_metrics['va_loss'].append(loss.item())
-            self.csv_metrics['va_macroF1'].append(macroF1.item())
-            self.csv_metrics['va_microF1'].append(microF1.item())
-            self.csv_metrics['va_macroK'].append(macroK.item())
-            self.csv_metrics['va_microK'].append(microK.item())
-
         return {'loss': loss}
+        # return {'loss': loss, 'pred': predictions, 'target': y}
+
+    # def validation_epoch_end(self, outputs):
+    #     all_pred = []
+    #     all_tar = []
+    #     for output in outputs:
+    #         all_pred.append(output['pred'].cpu().numpy())
+    #         all_tar.append(output['target'].cpu().numpy())
+    #     all_pred = np.vstack(all_pred)
+    #     all_tar = np.vstack(all_tar)
+    #     all_pred = {'all': all_pred}
+    #     all_tar = {'all': all_tar}
+    #     res = evaluate(all_tar, all_pred)
+    #     res = [elem for elem in res.values()]
+    #     res = np.average(res, axis=0)
+    #     print(f'\n{res}')
 
     def test_step(self, test_batch, batch_idx):
         X, y, batch_langs = test_batch
@@ -209,14 +208,3 @@ class BertModel(pl.LightningModule):
         for k, v in reconstructed_y.items():
             reconstructed_y[k] = torch.cat(v).view(-1, predictions.shape[1])
         return reconstructed_x, reconstructed_y
-
-    def save_manual_logs(self):
-        if self.global_step == 0:
-            return
-        with open(self.csv_file, 'a', newline='\n') as handler:
-            writer = csv.writer(handler, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
-            writer.writerow([np.average(self.csv_metrics['tr_loss']), np.average(self.csv_metrics['va_loss']),
-                             np.average(self.csv_metrics['va_macroF1']), np.average(self.csv_metrics['va_microF1']),
-                             np.average(self.csv_metrics['va_macroK']), np.average(self.csv_metrics['va_microK'])])
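This hunk deletes the last piece of the hand-rolled CSV path (manual_log flag, csv.writer bookkeeping, save_manual_logs). The metrics it duplicated are already emitted through self.log(..., logger=True), so once a CSV-capable logger is attached to the Trainer (see the CSVLogger change in the view_generators hunks below), the same numbers land in a CSV file with no extra code. The surviving logging path, in sketch form:

    # Inside any LightningModule hook: logger=True routes the scalar to
    # whichever logger(s) the Trainer was constructed with (TensorBoard, CSV, ...).
    self.log('val-macroF1', macroF1, on_step=False, on_epoch=True, logger=True)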
@@ -118,6 +118,7 @@ def hard_single_metric_statistics(true_labels, predicted_labels):
 
 def macro_average(true_labels, predicted_labels, metric, metric_statistics=hard_single_metric_statistics):
     true_labels, predicted_labels, nC = __check_consistency_and_adapt(true_labels, predicted_labels)
+    _tmp = [metric(metric_statistics(true_labels[:, c], predicted_labels[:, c])) for c in range(nC)]
     return np.mean([metric(metric_statistics(true_labels[:, c], predicted_labels[:, c])) for c in range(nC)])
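The added `_tmp` line materializes the per-class scores before the unchanged return recomputes them, presumably so the list can be inspected in a debugger; it does not change the result. Macro averaging means: compute the metric independently per class (column), then take the unweighted mean. A self-contained toy version with hypothetical stand-ins for the statistics and metric functions:

    import numpy as np

    def per_class_counts(true_col, pred_col):
        """Hypothetical stand-in for hard_single_metric_statistics: TP/FP/FN."""
        tp = np.sum((true_col == 1) & (pred_col == 1))
        fp = np.sum((true_col == 0) & (pred_col == 1))
        fn = np.sum((true_col == 1) & (pred_col == 0))
        return tp, fp, fn

    def f1_from_counts(counts):
        tp, fp, fn = counts
        den = 2 * tp + fp + fn
        return 2 * tp / den if den > 0 else 1.0  # empty class counts as perfect

    true = np.array([[1, 0], [1, 1], [0, 0]])
    pred = np.array([[1, 0], [0, 1], [0, 0]])
    # Macro average: one score per class (column), then an unweighted mean.
    per_class = [f1_from_counts(per_class_counts(true[:, c], pred[:, c]))
                 for c in range(true.shape[1])]
    print(per_class, np.mean(per_class))  # [0.666..., 1.0] 0.833...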
@@ -68,7 +68,7 @@ class CustomF1(Metric):
             if den > 0:
                 class_specific.append(num / den)
             else:
-                class_specific.append(1.)
+                class_specific.append(torch.FloatTensor([1.]))
         average = torch.sum(torch.Tensor(class_specific))/self.num_classes
         return average.to(self.device)
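In CustomF1, the ordinary branch appends num / den, which is a tensor, while the fallback for classes with a zero denominator appended a plain Python float; the fix appends torch.FloatTensor([1.]) instead, presumably so every element of class_specific has a uniform tensor type before the list is collapsed. A toy sketch of the same loop (hypothetical helper, using torch.cat rather than the source's torch.Tensor(...) collapse):

    import torch

    def per_class_scores(nums, dens):
        """Toy version of the CustomF1 loop: one score per class."""
        class_specific = []
        for num, den in zip(nums, dens):
            if den > 0:
                class_specific.append((num / den).reshape(1))
            else:
                # Same convention as CustomF1: a class with an empty
                # denominator (e.g. no positives) counts as a perfect 1.
                class_specific.append(torch.FloatTensor([1.]))
        # Every element is now a 1-element tensor, so they concatenate cleanly.
        return torch.cat(class_specific)

    scores = per_class_scores(torch.tensor([3., 0.]), torch.tensor([4., 0.]))
    print(scores, scores.mean())  # tensor([0.7500, 1.0000]) tensor(0.8750)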
@@ -15,12 +15,12 @@ This module contains the view generators that take care of computing the view specific document embeddings:
 
 - View generator (-b): generates document embedding via mBERT model.
 """
+import torch
 from abc import ABC, abstractmethod
-import src.util.disable_sklearn_warnings
 # from time import time
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
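The import change brings in Lightning's built-in CSVLogger, which is what appears to make the manual CSV code deleted from BertModel redundant: every value passed through self.log is written to save_dir/name/version_N/metrics.csv. A minimal sketch of wiring it up (Trainer arguments other than the logger are illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import CSVLogger

    # Metrics logged via self.log(...) end up in csv_logs/bert/version_<N>/metrics.csv
    logger = CSVLogger(save_dir='csv_logs', name='bert')
    trainer = Trainer(logger=logger, max_epochs=10, gpus=1)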
@@ -448,9 +448,13 @@ class BertGen(ViewGen):
         self.stored_path = stored_path
         self.model = self._init_model()
         self.patience = patience
-        self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
+        # self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
+        self.logger = CSVLogger(save_dir='csv_logs', name='bert')
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
-                                                 patience=self.patience, verbose=True, mode='max')
+                                                 patience=self.patience, verbose=False, mode='max')
 
+        # modifying EarlyStopping global var in order to compute >= with respect to the best score
+        self.early_stop_callback.mode_dict['max'] = torch.ge
+
         # Zero shot parameters
         self.zero_shot = zero_shot
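The mode_dict patch is the interesting part of this hunk. In Lightning, EarlyStopping.mode_dict maps 'min'/'max' to torch.lt/torch.gt; with torch.gt, a validation score that merely ties the current best counts as no improvement and burns patience. Swapping in torch.ge makes equal-to-best count as an improvement. As the diff's own comment hints ("global var"), mode_dict is a class-level dict, so the item assignment affects every EarlyStopping instance in the process, not just this one. Sketch:

    import torch
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    early_stop = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
                               patience=5, verbose=False, mode='max')

    # mode_dict is a class-level dict (roughly {'min': torch.lt, 'max': torch.gt}),
    # so this mutates it for all instances: a score equal to the current best
    # now resets patience instead of consuming it.
    early_stop.mode_dict['max'] = torch.ge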
@@ -476,7 +480,7 @@ class BertGen(ViewGen):
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
                                         zero_shot=self.zero_shot, zscl_langs=self.train_langs,
-                                        debug=True, max_samples=50)  # todo: debug=True -> DEBUG setting
+                                        debug=False, max_samples=50)
 
         if self.zero_shot:
             print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
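This last change flips the data module out of debug mode while leaving max_samples=50 in place, so the cap stays inert until debug is re-enabled. A hypothetical sketch of how such a debug/max_samples pair might gate truncation inside the DataModule (method and attribute names assumed, not from the source):

    def _maybe_truncate(self, lX, ly):
        # Hypothetical helper: only the debug flag activates the max_samples
        # cap, so shipping with debug=False leaves the full dataset untouched.
        if not self.debug:
            return lX, ly
        lX = {lang: docs[:self.max_samples] for lang, docs in lX.items()}
        ly = {lang: labels[:self.max_samples] for lang, labels in ly.items()}
        return lX, ly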