Implemented custom micro and macro F1 in pl (cpu and gpu)

parent 8dbe48ff7a
commit 91666bd263
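For reference (not part of the diff below): micro-F1 pools the true-positive, false-positive and false-negative counts over all classes before computing a single F1, whereas macro-F1 computes one F1 per class and averages the results, so rare classes weigh as much as frequent ones. A tiny worked sketch with made-up counts, using the same empty-class convention as the metric in this commit (F1 = 1.0 when a class has no positives at all):

# hypothetical per-class (tp, fp, fn) counts for three classes
counts = [(90, 10, 10), (5, 0, 15), (1, 4, 0)]

def f1(tp, fp, fn):
    den = 2.0 * tp + fp + fn
    return 2.0 * tp / den if den > 0 else 1.0  # empty class -> 1.0, as in CustomF1

tp, fp, fn = map(sum, zip(*counts))
micro = f1(tp, fp, fn)                              # pool counts first: dominated by the frequent class
macro = sum(f1(*c) for c in counts) / len(counts)   # per-class F1, then mean: rare classes count equally
print(round(micro, 3), round(macro, 3))             # 0.831 0.544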
@@ -15,7 +15,7 @@ def main(args):
     _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
     EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
     data = MultilingualDataset.load(_DATASET)
-    data.set_view(languages=['it'], categories=[0, 1])
+    # data.set_view(languages=['it'], categories=[0, 1])
     lX, ly = data.training()
     lXte, lyte = data.test()
 
@@ -28,7 +28,7 @@ def main(args):
     # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
     # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
     # gFun = WordClassGen(n_jobs=N_JOBS)
-    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=5, nepochs=100,
+    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, nepochs=50,
                         gpus=args.gpus, n_jobs=N_JOBS)
     # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
 
@@ -29,10 +29,11 @@ class RecurrentModel(pl.LightningModule):
         self.drop_embedding_range = drop_embedding_range
         self.drop_embedding_prop = drop_embedding_prop
         self.loss = torch.nn.BCEWithLogitsLoss()
-        self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
-        self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
+        # self.microf1 = F1(num_classes=output_size, multilabel=True, average='micro')
+        # self.macrof1 = F1(num_classes=output_size, multilabel=True, average='macro')
         self.accuracy = Accuracy()
-        self.customMetrics = CustomF1(num_classes=output_size, device=self.gpus)
+        self.customMicroF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
+        self.customMacroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
 
         self.lPretrained_embeddings = nn.ModuleDict()
         self.lLearnable_embeddings = nn.ModuleDict()
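The hunk above swaps the torchmetrics F1 objects for two instances of the repository's CustomF1, parametrised by average. The class definition itself is not shown in this hunk; the sketch below only illustrates how such a metric can be written on top of torchmetrics.Metric (class name, update logic and import path are assumptions, not the commit's code), keeping the per-class counters that the compute() hunk further down relies on:

import torch
from torchmetrics import Metric  # older code may import Metric from pytorch_lightning.metrics

class PerClassF1(Metric):  # illustrative stand-in for CustomF1
    def __init__(self, num_classes, average='micro'):
        super().__init__()
        self.num_classes = num_classes
        self.average = average
        # per-class counters, summed over batches (and over processes when run distributed)
        self.add_state('true_positive', default=torch.zeros(num_classes), dist_reduce_fx='sum')
        self.add_state('false_positive', default=torch.zeros(num_classes), dist_reduce_fx='sum')
        self.add_state('false_negative', default=torch.zeros(num_classes), dist_reduce_fx='sum')

    def update(self, preds, target):
        preds, target = preds.long(), target.long()
        self.true_positive += ((preds == 1) & (target == 1)).sum(dim=0)
        self.false_positive += ((preds == 1) & (target == 0)).sum(dim=0)
        self.false_negative += ((preds == 0) & (target == 1)).sum(dim=0)

    def compute(self):
        num = 2.0 * self.true_positive
        den = 2.0 * self.true_positive + self.false_positive + self.false_negative
        if self.average == 'micro':
            return num.sum() / den.sum() if den.sum() > 0 else torch.tensor(1.)
        # macro: per-class F1 (1.0 for classes with no positives), then the unweighted mean
        per_class = torch.where(den > 0, num / den.clamp(min=1), torch.ones_like(den))
        return per_class.mean()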
@@ -103,10 +104,12 @@ class RecurrentModel(pl.LightningModule):
         # Squashing logits through Sigmoid in order to get confidence score
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
-        custom = self.customMetrics(predictions, ly)
+        microF1 = self.customMicroF1(predictions, ly)
+        macroF1 = self.customMacroF1(predictions, ly)
         self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        self.log('custom', custom, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        self.log('microF1', microF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         return {'loss': loss}

     def validation_step(self, val_batch, batch_idx):
@@ -212,4 +215,17 @@ class CustomF1(Metric):
                 return (num / den).to(self.device)
             return torch.FloatTensor([1.]).to(self.device)
         if self.average == 'macro':
-            raise NotImplementedError
+            class_specific = []
+            for i in range(self.num_classes):
+                class_tp = self.true_positive[i]
+                # class_tn = self.true_negative[i]
+                class_fp = self.false_positive[i]
+                class_fn = self.false_negative[i]
+                num = 2.0 * class_tp
+                den = 2.0 * class_tp + class_fp + class_fn
+                if den > 0:
+                    class_specific.append(num / den)
+                else:
+                    class_specific.append(1.)
+            average = torch.sum(torch.Tensor(class_specific))/self.num_classes
+            return average.to(self.device)
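The loop above is a per-class F1 followed by an unweighted mean. A quick way to sanity-check that convention (illustrative script, not part of the repository; it assumes scikit-learn and NumPy are available) is to compare it against sklearn's f1_score with zero_division=1, which applies the same "class with no positives counts as 1.0" rule:

import numpy as np
from sklearn.metrics import f1_score

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(64, 10))   # hypothetical multilabel targets
y_pred = rng.integers(0, 2, size=(64, 10))   # hypothetical binary predictions

tp = ((y_pred == 1) & (y_true == 1)).sum(axis=0)
fp = ((y_pred == 1) & (y_true == 0)).sum(axis=0)
fn = ((y_pred == 0) & (y_true == 1)).sum(axis=0)

# same per-class computation as the macro branch in the hunk above
per_class = [2.0 * t / (2.0 * t + p + n) if (2.0 * t + p + n) > 0 else 1.0
             for t, p, n in zip(tp, fp, fn)]
print(np.mean(per_class))
print(f1_score(y_true, y_pred, average='macro', zero_division=1))  # should agree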