running comparison with refactored
This commit is contained in:
parent
b98821d3ff
commit
6b68bb01ad
|
@ -193,13 +193,14 @@ class WordClassEmbedder:
|
|||
|
||||
class MBertEmbedder:
|
||||
|
||||
def __init__(self, doc_embed_path=None, patience=10, checkpoint_dir='../hug_checkpoint/', path_to_model=None,
|
||||
nC=None):
|
||||
def __init__(self, options, doc_embed_path=None, patience=10, checkpoint_dir='../hug_checkpoint/', path_to_model=None,
|
||||
nC=None, ):
|
||||
self.doc_embed_path = doc_embed_path
|
||||
self.patience = patience
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.fitted = False
|
||||
self.requires_tfidf = False
|
||||
self.options = options
|
||||
if path_to_model is None and nC is not None:
|
||||
self.model = None
|
||||
else:
|
||||
|
@ -238,12 +239,13 @@ class MBertEmbedder:
|
|||
# Training loop
|
||||
logfile = '../log/log_mBert_extractor.csv'
|
||||
method_name = 'mBert_feature_extractor'
|
||||
logfile = init_logfile_nn(method_name, self.options)
|
||||
|
||||
tinit = time()
|
||||
tinit = time.time()
|
||||
lang_ids = va_dataset.lang_ids
|
||||
for epoch in range(1, nepochs + 1):
|
||||
print('# Start Training ...')
|
||||
train(model, tr_dataloader, epoch, criterion, optim, method_name, tinit, logfile)
|
||||
train(model, tr_dataloader, epoch, criterion, optim, method_name, tinit, logfile, self.options)
|
||||
lr_scheduler.step() # reduces the learning rate # TODO arg epoch?
|
||||
|
||||
# Validation
|
||||
|
@ -260,7 +262,7 @@ class MBertEmbedder:
|
|||
if val_epochs > 0:
|
||||
print(f'running last {val_epochs} training epochs on the validation set')
|
||||
for val_epoch in range(1, val_epochs + 1):
|
||||
train(self.model, va_dataloader, epoch + val_epoch, criterion, optim, method_name, tinit, logfile)
|
||||
train(self.model, va_dataloader, epoch + val_epoch, criterion, optim, method_name, tinit, logfile, self.options)
|
||||
|
||||
self.fitted = True
|
||||
return self
|
||||
|
|
|
@ -7,6 +7,10 @@ from util.results import PolylingualClassificationResults
|
|||
from util.common import *
|
||||
from util.parser_options import *
|
||||
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
(op, args) = parser.parse_args()
|
||||
dataset = op.dataset
|
||||
|
@ -14,8 +18,8 @@ if __name__ == '__main__':
|
|||
assert not (op.set_c != 1. and op.optimc), 'Parameter C cannot be defined along with optim_c option'
|
||||
assert op.posteriors or op.supervised or op.pretrained or op.mbert or op.gruViewGenerator, \
|
||||
'empty set of document embeddings is not allowed'
|
||||
assert (op.gruWCE or op.gruMUSE) and op.gruViewGenerator, 'Initializing Gated Recurrent embedding layer without ' \
|
||||
'explicit initialization of GRU View Generator'
|
||||
assert not ((op.gruWCE or op.gruMUSE) and op.gruViewGenerator), 'Initializing Gated Recurrent embedding layer without ' \
|
||||
'explicit initialization of GRU View Generator'
|
||||
|
||||
l2 = op.l2
|
||||
dataset_file = os.path.basename(dataset)
|
||||
|
@ -35,11 +39,20 @@ if __name__ == '__main__':
|
|||
|
||||
# load dataset
|
||||
data = MultilingualDataset.load(dataset)
|
||||
data.set_view(languages=['nl', 'it']) # TODO: DEBUG SETTING
|
||||
# data.set_view(languages=['nl', 'it']) # TODO: DEBUG SETTING
|
||||
data.show_dimensions()
|
||||
lXtr, lytr = data.training()
|
||||
lXte, lyte = data.test()
|
||||
|
||||
# DEBUGGING
|
||||
ratio = 0.01
|
||||
lXtr = {k:v[:50] for k,v in lXtr.items()}
|
||||
lytr = {k: v[:50] for k, v in lytr.items()}
|
||||
lXte = {k: v[:50] for k, v in lXte.items()}
|
||||
lyte = {k: v[:50] for k, v in lyte.items()}
|
||||
|
||||
|
||||
|
||||
# text preprocessing
|
||||
tfidfvectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
||||
|
||||
|
@ -97,8 +110,9 @@ if __name__ == '__main__':
|
|||
View generator (-B): generates document embedding via mBERT model.
|
||||
"""
|
||||
op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0' # TODO DEBUG
|
||||
op.bert_path = None
|
||||
mbert = MBertEmbedder(path_to_model=op.bert_path,
|
||||
nC=data.num_categories())
|
||||
nC=data.num_categories(), options=op)
|
||||
if op.allprob:
|
||||
mbert = FeatureSet2Posteriors(mbert, l2=l2)
|
||||
doc_embedder.append(mbert)
|
||||
|
|
|
@ -153,10 +153,10 @@ def do_tokenization(l_dataset, max_len=512, verbose=True):
|
|||
return l_tokenized
|
||||
|
||||
|
||||
def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile, log_interval=10):
|
||||
# _dataset_path = opt.dataset.split('/')[-1].split('_')
|
||||
# dataset_id = _dataset_path[0] + _dataset_path[-1]
|
||||
dataset_id = 'TODO fix this!'
|
||||
def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile, opt, log_interval=10):
|
||||
_dataset_path = opt.dataset.split('/')[-1].split('_')
|
||||
dataset_id = _dataset_path[0] + _dataset_path[-1]
|
||||
# dataset_id = 'TODO fix this!'
|
||||
|
||||
loss_history = []
|
||||
model.train()
|
||||
|
|
Loading…
Reference in New Issue