1xn2 cont_table output fixed

This commit is contained in:
Lorenzo Volpi 2024-04-08 17:58:34 +02:00
parent 5cfd5d87dd
commit 8a087e3e2f
1 changed files with 54 additions and 10 deletions

View File

@ -4,7 +4,7 @@ from copy import deepcopy
import numpy as np import numpy as np
import quapy.functional as F import quapy.functional as F
import scipy import scipy
from quapy.data.base import LabelledCollection from quapy.data.base import LabelledCollection as LC
from quapy.method.aggregative import AggregativeQuantifier from quapy.method.aggregative import AggregativeQuantifier
from quapy.method.base import BaseQuantifier from quapy.method.base import BaseQuantifier
from scipy.sparse import csr_matrix, issparse from scipy.sparse import csr_matrix, issparse
@ -15,6 +15,49 @@ from quacc.models.base import ClassifierAccuracyPrediction
from quacc.models.utils import get_posteriors_from_h, max_conf, neg_entropy from quacc.models.utils import get_posteriors_from_h, max_conf, neg_entropy
class LabelledCollection(LC):
def empty_classes(self):
"""
Returns a np.ndarray of empty classes (classes present in self.classes_ but with
no positive instance). In case there is none, then an empty np.ndarray is returned
:return: np.ndarray
"""
idx = np.argwhere(self.counts() == 0).flatten()
return self.classes_[idx]
def non_empty_classes(self):
"""
Returns a np.ndarray of non-empty classes (classes present in self.classes_ but with
at least one positive instance). In case there is none, then an empty np.ndarray is returned
:return: np.ndarray
"""
idx = np.argwhere(self.counts() > 0).flatten()
return self.classes_[idx]
def has_empty_classes(self):
"""
Checks whether the collection has empty classes
:return: boolean
"""
return len(self.empty_classes()) > 0
def compact_classes(self):
"""
Generates a new LabelledCollection object with no empty classes. It also returns a np.ndarray of
indexes that correspond to the old indexes of the new self.classes_.
:return: (LabelledCollection, np.ndarray,)
"""
non_empty = self.non_empty_classes()
all_classes = self.classes_
old_pos = np.searchsorted(all_classes, non_empty)
non_empty_collection = LabelledCollection(*self.Xy, classes=non_empty)
return non_empty_collection, old_pos
class CAPContingencyTable(ClassifierAccuracyPrediction): class CAPContingencyTable(ClassifierAccuracyPrediction):
def __init__(self, h: BaseEstimator, acc: callable): def __init__(self, h: BaseEstimator, acc: callable):
self.h = h self.h = h
@ -304,9 +347,9 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
pred_labels = self.h.predict(val.X) pred_labels = self.h.predict(val.X)
true_labels = val.y true_labels = val.y
n = val.n_classes self.ncl = val.n_classes
classes_dot = np.arange(n**2) classes_dot = np.arange(self.ncl**2)
ct_class_idx = classes_dot.reshape(n, n) ct_class_idx = classes_dot.reshape(self.ncl, self.ncl)
X_dot = self._get_X_dot(val.X) X_dot = self._get_X_dot(val.X)
y_dot = ct_class_idx[true_labels, pred_labels] y_dot = ct_class_idx[true_labels, pred_labels]
@ -315,7 +358,8 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
def predict_ct(self, X, oracle_prev=None): def predict_ct(self, X, oracle_prev=None):
X_dot = self._get_X_dot(X) X_dot = self._get_X_dot(X)
return self.q.quantify(X_dot) flat_ct = self.q.quantify(X_dot)
return flat_ct.reshape(self.ncl, self.ncl)
class QuAcc1xNp1(CAPContingencyTableQ, QuAcc): class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
@ -343,11 +387,11 @@ class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
pred_labels = self.h.predict(val.X) pred_labels = self.h.predict(val.X)
true_labels = val.y true_labels = val.y
n = val.n_classes self.ncl = val.n_classes
classes_dot = np.arange(n + 1) classes_dot = np.arange(self.ncl + 1)
# ct_class_idx = classes_dot.reshape(n, n) # ct_class_idx = classes_dot.reshape(n, n)
ct_class_idx = np.full((n, n), n) ct_class_idx = np.full((self.ncl, self.ncl), self.ncl)
ct_class_idx[np.diag_indices(n)] = np.arange(n) ct_class_idx[np.diag_indices(self.ncl)] = np.arange(self.ncl)
X_dot = self._get_X_dot(val.X) X_dot = self._get_X_dot(val.X)
y_dot = ct_class_idx[true_labels, pred_labels] y_dot = ct_class_idx[true_labels, pred_labels]
@ -364,7 +408,7 @@ class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
def predict_ct(self, X: LabelledCollection, oracle_prev=None): def predict_ct(self, X: LabelledCollection, oracle_prev=None):
X_dot = self._get_X_dot(X) X_dot = self._get_X_dot(X)
ct_compressed = self.q.quantify(X_dot) ct_compressed = self.q.quantify(X_dot)
return self._get_ct_hat(X.n_classes, ct_compressed) return self._get_ct_hat(self.ncl, ct_compressed)
class QuAccNxN(CAPContingencyTableQ, QuAcc): class QuAccNxN(CAPContingencyTableQ, QuAcc):