forked from moreo/QuaPy
bugfix
This commit is contained in:
parent
d197167cfd
commit
8ef9e6a633
@@ -33,6 +33,9 @@ def quantification_models():
     yield 'svmmae', OneVsAll(qp.method.aggregative.SVMAE(args.svmperfpath)), svmperf_params
     yield 'svmmrae', OneVsAll(qp.method.aggregative.SVMRAE(args.svmperfpath)), svmperf_params
 
+    #sld = qp.method.aggregative.EMQ(newLR())
+    #yield 'paccsld', qp.method.aggregative.PACC(sld), lr_params
+
     # 'mlpe': lambda learner: MaximumLikelihoodPrevalenceEstimation(),
 
 
@@ -136,8 +139,9 @@ if __name__ == '__main__':
     print(f'Result folder: {args.results}')
     np.random.seed(0)
 
-    optim_losses = ['mae', 'mrae']
-    datasets = qp.datasets.TWITTER_SENTIMENT_DATASETS_TRAIN
+    #optim_losses = ['mae', 'mrae']
+    optim_losses = ['mae']
+    datasets = ['hcr'] # qp.datasets.TWITTER_SENTIMENT_DATASETS_TRAIN
     models = quantification_models()
 
     results = Parallel(n_jobs=settings.N_JOBS)(
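Note: 'hcr' is one of the Twitter sentiment datasets, so this hunk narrows the experiment to a single dataset and a single optimization loss, presumably for debugging. A minimal sketch of loading that dataset directly, assuming QuaPy's fetch_twitter loader (the exact keyword arguments are assumptions, not part of this diff):

    import quapy as qp

    # load the 'hcr' Twitter sentiment dataset with a train/dev split
    data = qp.datasets.fetch_twitter('hcr', for_model_selection=True, pickle=True)
    print(data.training.prevalence())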
@@ -119,7 +119,7 @@ for i, eval_func in enumerate(evaluation_measures):
     # ----------------------------------------------------
 
     eval_name = eval_func.__name__
-    added_methods = ['svm' + eval_name] + new_methods
+    added_methods = ['svmm' + eval_name] + new_methods
     methods = gao_seb_methods + added_methods
     nold_methods = len(gao_seb_methods)
     nnew_methods = len(added_methods)
@@ -9,7 +9,7 @@ from sklearn.calibration import CalibratedClassifierCV
 from sklearn.metrics import confusion_matrix
 from sklearn.model_selection import StratifiedKFold
 from tqdm import tqdm
-
+import quapy as qp
 import quapy.functional as F
 from quapy.classification.svmperf import SVMperf
 from quapy.data import LabelledCollection
@@ -69,8 +69,11 @@ class AggregativeProbabilisticQuantifier(AggregativeQuantifier):
     probabilities.
     """
 
-    def posterior_probabilities(self, data):
-        return self.learner.predict_proba(data)
+    def posterior_probabilities(self, instances):
+        return self.learner.predict_proba(instances)
+
+    def predict_proba(self, instances):
+        return self.posterior_probabilities(instances)
 
     def quantify(self, instances):
         classif_posteriors = self.posterior_probabilities(instances)
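Note: predict_proba is a thin sklearn-style alias for posterior_probabilities. A usage sketch, where PCC and LogisticRegression stand in for any probabilistic aggregative quantifier and base learner, and train/test are placeholder data:

    from sklearn.linear_model import LogisticRegression

    pcc = PCC(LogisticRegression())
    pcc.fit(train)  # train: a LabelledCollection
    # both calls return the same matrix of per-instance class posteriors
    probs = pcc.posterior_probabilities(test.instances)
    probs = pcc.predict_proba(test.instances)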
@@ -122,7 +125,11 @@ def training_helper(learner,
                 'proportion, or a LabelledCollection indicating the validation split')
         else:
             train, unused = data, None
-        learner.fit(train.instances, train.labels)
+
+        if isinstance(learner, BaseQuantifier):
+            learner.fit(train)
+        else:
+            learner.fit(train.instances, train.labels)
     else:
         if ensure_probabilistic:
             if not hasattr(learner, 'predict_proba'):
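Note: the isinstance dispatch lets a quantifier itself serve as the learner: a BaseQuantifier is fitted on the LabelledCollection directly, while a plain classifier keeps receiving instance and label arrays. This is what the commented-out 'paccsld' configuration above relies on; a hedged sketch of such nesting (whether PACC fully supports a quantifier learner is not shown in this diff):

    from sklearn.linear_model import LogisticRegression
    import quapy as qp

    sld = qp.method.aggregative.EMQ(LogisticRegression())   # a quantifier...
    paccsld = qp.method.aggregative.PACC(sld)               # ...used as PACC's learner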
@@ -229,10 +236,10 @@ class ACC(AggregativeQuantifier):
 
 
 class PCC(AggregativeProbabilisticQuantifier):
-    def __init__(self, learner:BaseEstimator):
+    def __init__(self, learner: BaseEstimator):
         self.learner = learner
 
-    def fit(self, data : LabelledCollection, fit_learner=True):
+    def fit(self, data: LabelledCollection, fit_learner=True):
         self.learner, _ = training_helper(self.learner, data, fit_learner, ensure_probabilistic=True)
         return self
 
@@ -301,9 +308,6 @@ class PACC(AggregativeProbabilisticQuantifier):
     def classify(self, data):
         return self.pcc.classify(data)
 
-    def soft_classify(self, data):
-        return self.pcc.posterior_probabilities(data)
-
 
 class EMQ(AggregativeProbabilisticQuantifier):
 
@@ -319,7 +323,13 @@ class EMQ(AggregativeProbabilisticQuantifier):
         return self
 
     def aggregate(self, classif_posteriors, epsilon=EPSILON):
-        return self.EM(self.train_prevalence, classif_posteriors, epsilon)
+        priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
+        return priors
+
+    def predict_proba(self, instances, epsilon=EPSILON):
+        classif_posteriors = self.learner.predict_proba(instances)
+        priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
+        return posteriors
 
     @classmethod
     def EM(cls, tr_prev, posterior_probabilities, epsilon=EPSILON):
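Note: EM now returns both the adjusted priors and the rescaled posteriors (see the return qs, ps hunk below), so aggregate preserves its original contract by unpacking and returning only the priors, while the new predict_proba exposes the EM-corrected per-instance posteriors. A sketch of the two entry points (learner and data names are placeholders):

    from sklearn.linear_model import LogisticRegression

    emq = EMQ(LogisticRegression())
    emq.fit(train)                                   # train: a LabelledCollection
    prevalence = emq.quantify(test.instances)        # class priors, via aggregate()
    posteriors = emq.predict_proba(test.instances)   # EM-adjusted posteriors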
|
@ -337,7 +347,7 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
||||||
# M-step: qs_pos is Ps+1(y=+1)
|
# M-step: qs_pos is Ps+1(y=+1)
|
||||||
qs = ps.mean(axis=0)
|
qs = ps.mean(axis=0)
|
||||||
|
|
||||||
if qs_prev_ is not None and error.mae(qs, qs_prev_) < epsilon and s>10:
|
if qs_prev_ is not None and qp.error.mae(qs, qs_prev_) < epsilon and s>10:
|
||||||
converged = True
|
converged = True
|
||||||
|
|
||||||
qs_prev_ = qs
|
qs_prev_ = qs
|
||||||
|
@@ -346,7 +356,7 @@ class EMQ(AggregativeProbabilisticQuantifier):
         if not converged:
             raise UserWarning('the method has reached the maximum number of iterations; it might have not converged')
 
-        return qs
+        return qs, ps
 
 
 class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
@@ -493,7 +503,7 @@ class OneVsAll(AggregativeQuantifier):
         return classif_predictions_bin.T
 
     def aggregate(self, classif_predictions_bin):
-        assert set(np.unique(classif_predictions_bin)) == {0,1}, \
+        assert set(np.unique(classif_predictions_bin)).issubset({0,1}), \
             'param classif_predictions_bin does not seem to be a valid matrix (ndarray) of binary ' \
             'predictions for each document (row) and class (columns)'
         prevalences = self.__parallel(self._delayed_binary_aggregate, classif_predictions_bin)
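Note: the strict equality was the bug here: if every binary member predicts only one class, np.unique yields a single value and set(...) == {0,1} rejects a perfectly valid prediction matrix; issubset accepts it. A small demonstration:

    import numpy as np

    preds = np.zeros((4, 3), dtype=int)       # all-negative, yet valid, predictions
    set(np.unique(preds)) == {0, 1}           # False: the old assert fails here
    set(np.unique(preds)).issubset({0, 1})    # True: accepted after the fix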