Commit 6b68bb01ad ("running comparison with refactored")
Parent: b98821d3ff
@@ -193,13 +193,14 @@ class WordClassEmbedder:
 class MBertEmbedder:
 
-    def __init__(self, doc_embed_path=None, patience=10, checkpoint_dir='../hug_checkpoint/', path_to_model=None,
-                 nC=None):
+    def __init__(self, options, doc_embed_path=None, patience=10, checkpoint_dir='../hug_checkpoint/', path_to_model=None,
+                 nC=None, ):
         self.doc_embed_path = doc_embed_path
         self.patience = patience
         self.checkpoint_dir = checkpoint_dir
         self.fitted = False
         self.requires_tfidf = False
+        self.options = options
         if path_to_model is None and nC is not None:
             self.model = None
         else:
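Net effect of this hunk: MBertEmbedder now takes the parsed command-line options as a required argument and stores them on the instance for later use in logging and training. A minimal usage sketch, assuming op is the optparse options object parsed in the main script below; the nC value is assumed for illustration:

(op, args) = parser.parse_args()
embedder = MBertEmbedder(options=op, nC=73)   # nC=73 assumed, not taken from this commit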
@@ -238,12 +239,13 @@ class MBertEmbedder:
         # Training loop
         logfile = '../log/log_mBert_extractor.csv'
         method_name = 'mBert_feature_extractor'
+        logfile = init_logfile_nn(method_name, self.options)
 
-        tinit = time()
+        tinit = time.time()
         lang_ids = va_dataset.lang_ids
         for epoch in range(1, nepochs + 1):
             print('# Start Training ...')
-            train(model, tr_dataloader, epoch, criterion, optim, method_name, tinit, logfile)
+            train(model, tr_dataloader, epoch, criterion, optim, method_name, tinit, logfile, self.options)
             lr_scheduler.step()  # reduces the learning rate # TODO arg epoch?
 
         # Validation
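Two things change here: logfile is now built by init_logfile_nn from the stored options (overriding the hardcoded CSV path on the line above it, which is now dead), and tinit = time() becomes time.time(). The latter is only correct if the module-level import changed accordingly; a sketch of the two spellings:

from time import time
tinit = time()          # valid with the function import

import time
tinit = time.time()     # valid with the module import; bare time() would raise
                        # TypeError: 'module' object is not callable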
@@ -260,7 +262,7 @@ class MBertEmbedder:
         if val_epochs > 0:
             print(f'running last {val_epochs} training epochs on the validation set')
             for val_epoch in range(1, val_epochs + 1):
-                train(self.model, va_dataloader, epoch + val_epoch, criterion, optim, method_name, tinit, logfile)
+                train(self.model, va_dataloader, epoch + val_epoch, criterion, optim, method_name, tinit, logfile, self.options)
 
         self.fitted = True
         return self
@@ -7,6 +7,10 @@ from util.results import PolylingualClassificationResults
 from util.common import *
 from util.parser_options import *
 
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
+
 if __name__ == '__main__':
     (op, args) = parser.parse_args()
     dataset = op.dataset
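Setting CUDA_VISIBLE_DEVICES at module import time pins the run to physical GPU 1. The ordering matters: the variable must be set before the CUDA runtime is initialized. A sketch of the constraint, assuming PyTorch as the backend:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"   # set before CUDA is touched

import torch
print(torch.cuda.device_count())   # 1 -> only physical GPU 1 is visible, exposed as cuda:0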
@@ -14,8 +18,8 @@ if __name__ == '__main__':
     assert not (op.set_c != 1. and op.optimc), 'Parameter C cannot be defined along with optim_c option'
     assert op.posteriors or op.supervised or op.pretrained or op.mbert or op.gruViewGenerator, \
         'empty set of document embeddings is not allowed'
-    assert (op.gruWCE or op.gruMUSE) and op.gruViewGenerator, 'Initializing Gated Recurrent embedding layer without ' \
+    assert not ((op.gruWCE or op.gruMUSE) and op.gruViewGenerator), 'Initializing Gated Recurrent embedding layer without ' \
         'explicit initialization of GRU View Generator'
 
     l2 = op.l2
     dataset_file = os.path.basename(dataset)
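The un-negated form raised unless a GRU flag and the view generator were both set, which made every non-GRU run fail at startup. A quick check of what the rewritten guard accepts (gruMUSE omitted; it behaves like gruWCE):

for gruWCE in (False, True):
    for gruViewGenerator in (False, True):
        ok = not (gruWCE and gruViewGenerator)
        print(gruWCE, gruViewGenerator, '->', 'ok' if ok else 'AssertionError')
# only gruWCE=True together with gruViewGenerator=True now trips the assert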
@@ -35,11 +39,20 @@ if __name__ == '__main__':
 
     # load dataset
     data = MultilingualDataset.load(dataset)
-    data.set_view(languages=['nl', 'it'])  # TODO: DEBUG SETTING
+    # data.set_view(languages=['nl', 'it'])  # TODO: DEBUG SETTING
     data.show_dimensions()
     lXtr, lytr = data.training()
     lXte, lyte = data.test()
 
+    # DEBUGGING
+    ratio = 0.01
+    lXtr = {k:v[:50] for k,v in lXtr.items()}
+    lytr = {k: v[:50] for k, v in lytr.items()}
+    lXte = {k: v[:50] for k, v in lXte.items()}
+    lyte = {k: v[:50] for k, v in lyte.items()}
+
+
+
     # text preprocessing
     tfidfvectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
 
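The new debug block replaces the commented-out two-language view with a cheaper cut: every per-language split keeps only its first 50 documents (ratio is declared but unused in this hunk). A sketch with made-up data:

lXtr = {'nl': ['doc%d' % i for i in range(1000)], 'it': ['doc%d' % i for i in range(800)]}
lXtr = {k: v[:50] for k, v in lXtr.items()}
print({k: len(v) for k, v in lXtr.items()})   # {'nl': 50, 'it': 50}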
@@ -97,8 +110,9 @@ if __name__ == '__main__':
     View generator (-B): generates document embedding via mBERT model.
     """
     op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0'  # TODO DEBUG
+    op.bert_path = None
     mbert = MBertEmbedder(path_to_model=op.bert_path,
-                          nC=data.num_categories())
+                          nC=data.num_categories(), options=op)
     if op.allprob:
         mbert = FeatureSet2Posteriors(mbert, l2=l2)
     doc_embedder.append(mbert)
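Adding op.bert_path = None right after the hardcoded checkpoint path makes that debug path dead code. Combined with the constructor logic from the first hunk, the instantiation now takes the train-from-scratch branch; a sketch of the resulting state:

op.bert_path = None
mbert = MBertEmbedder(path_to_model=op.bert_path,
                      nC=data.num_categories(), options=op)
# path_to_model is None and nC is not None -> __init__ sets self.model = None,
# so the mBERT model is presumably built and fine-tuned later, during fitting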
@@ -153,10 +153,10 @@ def do_tokenization(l_dataset, max_len=512, verbose=True):
     return l_tokenized
 
 
-def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile, log_interval=10):
-    # _dataset_path = opt.dataset.split('/')[-1].split('_')
-    # dataset_id = _dataset_path[0] + _dataset_path[-1]
-    dataset_id = 'TODO fix this!'
+def train(model, train_dataloader, epoch, criterion, optim, method_name, tinit, logfile, opt, log_interval=10):
+    _dataset_path = opt.dataset.split('/')[-1].split('_')
+    dataset_id = _dataset_path[0] + _dataset_path[-1]
+    # dataset_id = 'TODO fix this!'
 
     loss_history = []
     model.train()
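With opt now passed in, the previously commented-out derivation of dataset_id is live again and the 'TODO fix this!' placeholder is retired. A quick sketch of what it produces, with an assumed dataset path (the real one is not shown in this diff):

dataset = '../data/rcv1-2_trByLang1000_run0.pickle'   # assumed example path
_dataset_path = dataset.split('/')[-1].split('_')     # ['rcv1-2', 'trByLang1000', 'run0.pickle']
dataset_id = _dataset_path[0] + _dataset_path[-1]     # 'rcv1-2run0.pickle'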