running comparison
This commit is contained in:
parent
6b68bb01ad
commit
840293ee17
|
@ -209,7 +209,7 @@ class MBertEmbedder:
|
||||||
self.model = BertForSequenceClassification.from_pretrained(path_to_model, config=config).cuda()
|
self.model = BertForSequenceClassification.from_pretrained(path_to_model, config=config).cuda()
|
||||||
self.fitted = True
|
self.fitted = True
|
||||||
|
|
||||||
def fit(self, lX, ly, lV=None, seed=0, nepochs=200, lr=1e-5, val_epochs=1):
|
def fit(self, lX, ly, lV=None, seed=0, nepochs=25, lr=1e-5, val_epochs=1):
|
||||||
print('### mBERT View Generator (B)')
|
print('### mBERT View Generator (B)')
|
||||||
if self.fitted is True:
|
if self.fitted is True:
|
||||||
print('Bert model already fitted!')
|
print('Bert model already fitted!')
|
||||||
|
|
|
@ -45,7 +45,6 @@ if __name__ == '__main__':
|
||||||
lXte, lyte = data.test()
|
lXte, lyte = data.test()
|
||||||
|
|
||||||
# DEBUGGING
|
# DEBUGGING
|
||||||
ratio = 0.01
|
|
||||||
lXtr = {k:v[:50] for k,v in lXtr.items()}
|
lXtr = {k:v[:50] for k,v in lXtr.items()}
|
||||||
lytr = {k: v[:50] for k, v in lytr.items()}
|
lytr = {k: v[:50] for k, v in lytr.items()}
|
||||||
lXte = {k: v[:50] for k, v in lXte.items()}
|
lXte = {k: v[:50] for k, v in lXte.items()}
|
||||||
|
|
Loading…
Reference in New Issue