running comparison

This commit is contained in:
andrea 2021-02-11 11:54:07 +01:00
parent 6b68bb01ad
commit 840293ee17
2 changed files with 1 additions and 2 deletions

View File

@ -209,7 +209,7 @@ class MBertEmbedder:
self.model = BertForSequenceClassification.from_pretrained(path_to_model, config=config).cuda() self.model = BertForSequenceClassification.from_pretrained(path_to_model, config=config).cuda()
self.fitted = True self.fitted = True
def fit(self, lX, ly, lV=None, seed=0, nepochs=200, lr=1e-5, val_epochs=1): def fit(self, lX, ly, lV=None, seed=0, nepochs=25, lr=1e-5, val_epochs=1):
print('### mBERT View Generator (B)') print('### mBERT View Generator (B)')
if self.fitted is True: if self.fitted is True:
print('Bert model already fitted!') print('Bert model already fitted!')

View File

@ -45,7 +45,6 @@ if __name__ == '__main__':
lXte, lyte = data.test() lXte, lyte = data.test()
# DEBUGGING # DEBUGGING
ratio = 0.01
lXtr = {k:v[:50] for k,v in lXtr.items()} lXtr = {k:v[:50] for k,v in lXtr.items()}
lytr = {k: v[:50] for k, v in lytr.items()} lytr = {k: v[:50] for k, v in lytr.items()}
lXte = {k: v[:50] for k, v in lXte.items()} lXte = {k: v[:50] for k, v in lXte.items()}