1xn2 cont_table output fixed
This commit is contained in:
parent
5cfd5d87dd
commit
8a087e3e2f
|
|
@ -4,7 +4,7 @@ from copy import deepcopy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import quapy.functional as F
|
import quapy.functional as F
|
||||||
import scipy
|
import scipy
|
||||||
from quapy.data.base import LabelledCollection
|
from quapy.data.base import LabelledCollection as LC
|
||||||
from quapy.method.aggregative import AggregativeQuantifier
|
from quapy.method.aggregative import AggregativeQuantifier
|
||||||
from quapy.method.base import BaseQuantifier
|
from quapy.method.base import BaseQuantifier
|
||||||
from scipy.sparse import csr_matrix, issparse
|
from scipy.sparse import csr_matrix, issparse
|
||||||
|
|
@ -15,6 +15,49 @@ from quacc.models.base import ClassifierAccuracyPrediction
|
||||||
from quacc.models.utils import get_posteriors_from_h, max_conf, neg_entropy
|
from quacc.models.utils import get_posteriors_from_h, max_conf, neg_entropy
|
||||||
|
|
||||||
|
|
||||||
|
class LabelledCollection(LC):
|
||||||
|
def empty_classes(self):
|
||||||
|
"""
|
||||||
|
Returns a np.ndarray of empty classes (classes present in self.classes_ but with
|
||||||
|
no positive instance). In case there is none, then an empty np.ndarray is returned
|
||||||
|
|
||||||
|
:return: np.ndarray
|
||||||
|
"""
|
||||||
|
idx = np.argwhere(self.counts() == 0).flatten()
|
||||||
|
return self.classes_[idx]
|
||||||
|
|
||||||
|
def non_empty_classes(self):
|
||||||
|
"""
|
||||||
|
Returns a np.ndarray of non-empty classes (classes present in self.classes_ but with
|
||||||
|
at least one positive instance). In case there is none, then an empty np.ndarray is returned
|
||||||
|
|
||||||
|
:return: np.ndarray
|
||||||
|
"""
|
||||||
|
idx = np.argwhere(self.counts() > 0).flatten()
|
||||||
|
return self.classes_[idx]
|
||||||
|
|
||||||
|
def has_empty_classes(self):
|
||||||
|
"""
|
||||||
|
Checks whether the collection has empty classes
|
||||||
|
|
||||||
|
:return: boolean
|
||||||
|
"""
|
||||||
|
return len(self.empty_classes()) > 0
|
||||||
|
|
||||||
|
def compact_classes(self):
|
||||||
|
"""
|
||||||
|
Generates a new LabelledCollection object with no empty classes. It also returns a np.ndarray of
|
||||||
|
indexes that correspond to the old indexes of the new self.classes_.
|
||||||
|
|
||||||
|
:return: (LabelledCollection, np.ndarray,)
|
||||||
|
"""
|
||||||
|
non_empty = self.non_empty_classes()
|
||||||
|
all_classes = self.classes_
|
||||||
|
old_pos = np.searchsorted(all_classes, non_empty)
|
||||||
|
non_empty_collection = LabelledCollection(*self.Xy, classes=non_empty)
|
||||||
|
return non_empty_collection, old_pos
|
||||||
|
|
||||||
|
|
||||||
class CAPContingencyTable(ClassifierAccuracyPrediction):
|
class CAPContingencyTable(ClassifierAccuracyPrediction):
|
||||||
def __init__(self, h: BaseEstimator, acc: callable):
|
def __init__(self, h: BaseEstimator, acc: callable):
|
||||||
self.h = h
|
self.h = h
|
||||||
|
|
@ -304,9 +347,9 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
|
||||||
pred_labels = self.h.predict(val.X)
|
pred_labels = self.h.predict(val.X)
|
||||||
true_labels = val.y
|
true_labels = val.y
|
||||||
|
|
||||||
n = val.n_classes
|
self.ncl = val.n_classes
|
||||||
classes_dot = np.arange(n**2)
|
classes_dot = np.arange(self.ncl**2)
|
||||||
ct_class_idx = classes_dot.reshape(n, n)
|
ct_class_idx = classes_dot.reshape(self.ncl, self.ncl)
|
||||||
|
|
||||||
X_dot = self._get_X_dot(val.X)
|
X_dot = self._get_X_dot(val.X)
|
||||||
y_dot = ct_class_idx[true_labels, pred_labels]
|
y_dot = ct_class_idx[true_labels, pred_labels]
|
||||||
|
|
@ -315,7 +358,8 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
|
||||||
|
|
||||||
def predict_ct(self, X, oracle_prev=None):
|
def predict_ct(self, X, oracle_prev=None):
|
||||||
X_dot = self._get_X_dot(X)
|
X_dot = self._get_X_dot(X)
|
||||||
return self.q.quantify(X_dot)
|
flat_ct = self.q.quantify(X_dot)
|
||||||
|
return flat_ct.reshape(self.ncl, self.ncl)
|
||||||
|
|
||||||
|
|
||||||
class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
|
class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
|
||||||
|
|
@ -343,11 +387,11 @@ class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
|
||||||
pred_labels = self.h.predict(val.X)
|
pred_labels = self.h.predict(val.X)
|
||||||
true_labels = val.y
|
true_labels = val.y
|
||||||
|
|
||||||
n = val.n_classes
|
self.ncl = val.n_classes
|
||||||
classes_dot = np.arange(n + 1)
|
classes_dot = np.arange(self.ncl + 1)
|
||||||
# ct_class_idx = classes_dot.reshape(n, n)
|
# ct_class_idx = classes_dot.reshape(n, n)
|
||||||
ct_class_idx = np.full((n, n), n)
|
ct_class_idx = np.full((self.ncl, self.ncl), self.ncl)
|
||||||
ct_class_idx[np.diag_indices(n)] = np.arange(n)
|
ct_class_idx[np.diag_indices(self.ncl)] = np.arange(self.ncl)
|
||||||
|
|
||||||
X_dot = self._get_X_dot(val.X)
|
X_dot = self._get_X_dot(val.X)
|
||||||
y_dot = ct_class_idx[true_labels, pred_labels]
|
y_dot = ct_class_idx[true_labels, pred_labels]
|
||||||
|
|
@ -364,7 +408,7 @@ class QuAcc1xNp1(CAPContingencyTableQ, QuAcc):
|
||||||
def predict_ct(self, X: LabelledCollection, oracle_prev=None):
|
def predict_ct(self, X: LabelledCollection, oracle_prev=None):
|
||||||
X_dot = self._get_X_dot(X)
|
X_dot = self._get_X_dot(X)
|
||||||
ct_compressed = self.q.quantify(X_dot)
|
ct_compressed = self.q.quantify(X_dot)
|
||||||
return self._get_ct_hat(X.n_classes, ct_compressed)
|
return self._get_ct_hat(self.ncl, ct_compressed)
|
||||||
|
|
||||||
|
|
||||||
class QuAccNxN(CAPContingencyTableQ, QuAcc):
|
class QuAccNxN(CAPContingencyTableQ, QuAcc):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue