running comparison

This commit is contained in:
andrea 2021-02-11 12:44:32 +01:00
parent 840293ee17
commit d273691223
2 changed files with 5 additions and 4 deletions

View File

@ -8,7 +8,7 @@ from util.common import *
from util.parser_options import * from util.parser_options import *
import os import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1" os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if __name__ == '__main__': if __name__ == '__main__':
@ -108,8 +108,8 @@ if __name__ == '__main__':
""" """
View generator (-B): generates document embedding via mBERT model. View generator (-B): generates document embedding via mBERT model.
""" """
op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0' # TODO DEBUG op.bert_path = '/home/andreapdr/gfun/hug_checkpoint/pytorch_model.bin'
op.bert_path = None # op.bert_path = None
mbert = MBertEmbedder(path_to_model=op.bert_path, mbert = MBertEmbedder(path_to_model=op.bert_path,
nC=data.num_categories(), options=op) nC=data.num_categories(), options=op)
if op.allprob: if op.allprob:

View File

@ -100,7 +100,8 @@ class ExtractorDataset(Dataset):
def get_model(n_out): def get_model(n_out):
print('# Initializing model ...') print('# Initializing model ...')
model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=n_out) model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=n_out,
output_hidden_states=True)
return model return model