running comparison with refactored

This commit is contained in:
andrea 2021-02-09 11:53:59 +01:00
parent b98821d3ff
commit 6b68bb01ad
3 changed files with 29 additions and 13 deletions

View File

@ -193,13 +193,14 @@ class WordClassEmbedder:
class MBertEmbedder:
def __init__(self, doc_embed_path=None, patience=10, checkpoint_dir='../hug_checkpoint/', path_to_model=None,
nC=None):
def __init__(self, options, doc_embed_path=None, patience=10, checkpoint_dir='../hug_checkpoint/', path_to_model=None,
nC=None, ):
self.doc_embed_path = doc_embed_path
self.patience = patience
self.checkpoint_dir = checkpoint_dir
self.fitted = False
self.requires_tfidf = False
self.options = options
if path_to_model is None and nC is not None:
self.model = None
else:
@ -238,12 +239,13 @@ class MBertEmbedder:
# Training loop
logfile = '../log/log_mBert_extractor.csv'
method_name = 'mBert_feature_extractor'
logfile = init_logfile_nn(method_name, self.options)
tinit = time()
tinit = time.time()
lang_ids = va_dataset.lang_ids
for epoch in range(1, nepochs + 1):
print('# Start Training ...')
train(model, tr_dataloader, epoch, criterion, optim, method_name, tinit, logfile)
train(model, tr_dataloader, epoch, criterion, optim, method_name, tinit, logfile, self.options)
lr_scheduler.step() # reduces the learning rate # TODO arg epoch?
# Validation
@ -260,7 +262,7 @@ class MBertEmbedder:
if val_epochs > 0:
print(f'running last {val_epochs} training epochs on the validation set')
for val_epoch in range(1, val_epochs + 1):
train(self.model, va_dataloader, epoch + val_epoch, criterion, optim, method_name, tinit, logfile)
train(self.model, va_dataloader, epoch + val_epoch, criterion, optim, method_name, tinit, logfile, self.options)
self.fitted = True
return self

View File

@ -7,6 +7,10 @@ from util.results import PolylingualClassificationResults
from util.common import *
from util.parser_options import *
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
if __name__ == '__main__':
(op, args) = parser.parse_args()
dataset = op.dataset
@ -14,8 +18,8 @@ if __name__ == '__main__':
assert not (op.set_c != 1. and op.optimc), 'Parameter C cannot be defined along with optim_c option'
assert op.posteriors or op.supervised or op.pretrained or op.mbert or op.gruViewGenerator, \
'empty set of document embeddings is not allowed'
assert (op.gruWCE or op.gruMUSE) and op.gruViewGenerator, 'Initializing Gated Recurrent embedding layer without ' \
'explicit initialization of GRU View Generator'
assert not ((op.gruWCE or op.gruMUSE) and op.gruViewGenerator), 'Initializing Gated Recurrent embedding layer without ' \
'explicit initialization of GRU View Generator'
l2 = op.l2
dataset_file = os.path.basename(dataset)
@ -35,11 +39,20 @@ if __name__ == '__main__':
# load dataset
data = MultilingualDataset.load(dataset)
data.set_view(languages=['nl', 'it']) # TODO: DEBUG SETTING
# data.set_view(languages=['nl', 'it']) # TODO: DEBUG SETTING
data.show_dimensions()
lXtr, lytr = data.training()
lXte, lyte = data.test()
# DEBUGGING
ratio = 0.01
lXtr = {k:v[:50] for k,v in lXtr.items()}
lytr = {k: v[:50] for k, v in lytr.items()}
lXte = {k: v[:50] for k, v in lXte.items()}
lyte = {k: v[:50] for k, v in lyte.items()}
# text preprocessing
tfidfvectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
@ -97,8 +110,9 @@ if __name__ == '__main__':
View generator (-B): generates document embedding via mBERT model.
"""
op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0' # TODO DEBUG
op.bert_path = None
mbert = MBertEmbedder(path_to_model=op.bert_path,
nC=data.num_categories())
nC=data.num_categories(), options=op)
if op.allprob:
mbert = FeatureSet2Posteriors(mbert, l2=l2)
doc_embedder.append(mbert)

View File

@ -153,10 +153,10 @@ def do_tokenization(l_dataset, max_len=512, verbose=True):
return l_tokenized
def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile, log_interval=10):
# _dataset_path = opt.dataset.split('/')[-1].split('_')
# dataset_id = _dataset_path[0] + _dataset_path[-1]
dataset_id = 'TODO fix this!'
def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile, opt, log_interval=10):
_dataset_path = opt.dataset.split('/')[-1].split('_')
dataset_id = _dataset_path[0] + _dataset_path[-1]
# dataset_id = 'TODO fix this!'
loss_history = []
model.train()