From faa387f6965aad2f50dd83cae4c60d656a10a310 Mon Sep 17 00:00:00 2001 From: andrea Date: Wed, 20 Jan 2021 14:56:06 +0100 Subject: [PATCH] Implemented custom micro and macro F1 in pl (cpu and gpu) --- refactor/models/pl_gru.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/refactor/models/pl_gru.py b/refactor/models/pl_gru.py index 690843d..c810220 100644 --- a/refactor/models/pl_gru.py +++ b/refactor/models/pl_gru.py @@ -69,8 +69,9 @@ class RecurrentModel(pl.LightningModule): self.linear2 = nn.Linear(ff1, ff2) self.label = nn.Linear(ff2, self.output_size) - lPretrained = None # TODO: setting lPretrained to None, letting it to its original value will bug first - # validation step (i.e., checkpoint will store also its ++ value, I guess, making the saving process too slow) + # TODO: setting lPretrained to None, letting it to its original value will bug first validation + # step (i.e., checkpoint will store also its ++ value, I guess, making the saving process too slow) + lPretrained = None self.save_hyperparameters() def forward(self, lX):