diff --git a/src/models/pl_bert.py b/src/models/pl_bert.py index 129c3b4..1da9c69 100644 --- a/src/models/pl_bert.py +++ b/src/models/pl_bert.py @@ -136,7 +136,7 @@ class BertModel(pl.LightningModule): self.log('test-microK', microK, on_step=False, on_epoch=True, prog_bar=True, logger=True) return - def configure_optimizers(self, lr=3e-5, weight_decay=0.01): + def configure_optimizers(self, lr=1e-5, weight_decay=0.01): no_decay = ['bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ {'params': [p for n, p in self.bert.named_parameters()