forked from moreo/QuaPy
refactoring; chain classifiers; speed-ups for aggregative methods; evaluation modularized
This commit is contained in:
parent a4fea89122
commit 7b8e6462ff
@@ -2,11 +2,11 @@ import os,sys
 from sklearn.datasets import get_data_home, fetch_20newsgroups
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.preprocessing import MultiLabelBinarizer
-from jrcacquis_reader import fetch_jrcacquis, JRCAcquis_Document
-from ohsumed_reader import fetch_ohsumed50k
-from reuters21578_reader import fetch_reuters21578
-from rcv_reader import fetch_RCV1
-from wipo_reader import fetch_WIPOgamma, WipoGammaDocument
+from MultiLabel.data.jrcacquis_reader import fetch_jrcacquis
+from MultiLabel.data.ohsumed_reader import fetch_ohsumed50k
+from MultiLabel.data.reuters21578_reader import fetch_reuters21578
+from MultiLabel.data.rcv_reader import fetch_RCV1
+from MultiLabel.data.wipo_reader import fetch_WIPOgamma, WipoGammaDocument
 import pickle
 import numpy as np
 from tqdm import tqdm

@@ -0,0 +1,34 @@
+from copy import deepcopy
+
+from sklearn.calibration import CalibratedClassifierCV
+from sklearn.linear_model import LogisticRegression
+from sklearn.multiclass import OneVsRestClassifier
+from sklearn.preprocessing import StandardScaler
+
+
+class MultilabelStackedClassifier:  # aka Funnelling Monolingual
+    def __init__(self, base_estimator=LogisticRegression()):
+        if not hasattr(base_estimator, 'predict_proba'):
+            print('the estimator does not seem to be probabilistic: calibrating')
+            base_estimator = CalibratedClassifierCV(base_estimator)
+        self.base = deepcopy(OneVsRestClassifier(base_estimator))
+        self.meta = deepcopy(OneVsRestClassifier(base_estimator))
+        self.norm = StandardScaler()
+
+    def fit(self, X, y):
+        assert y.ndim==2, 'the dataset does not seem to be multi-label'
+        self.base.fit(X, y)
+        P = self.base.predict_proba(X)
+        P = self.norm.fit_transform(P)
+        self.meta.fit(P, y)
+        return self
+
+    def predict(self, X):
+        P = self.base.predict_proba(X)
+        P = self.norm.transform(P)
+        return self.meta.predict(P)
+
+    def predict_proba(self, X):
+        P = self.base.predict_proba(X)
+        P = self.norm.transform(P)
+        return self.meta.predict_proba(P)
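
For orientation, a minimal usage sketch of the stacked classifier defined above (toy data and shapes invented for illustration):

    import numpy as np

    X = np.random.rand(100, 20)                     # 100 documents, 20 features
    y = (np.random.rand(100, 5) > 0.7).astype(int)  # 5 binary categories per document

    clf = MultilabelStackedClassifier()   # LogisticRegression base by default
    clf.fit(X, y)                         # tier 1: OvR on X; tier 2: OvR on standardized posteriors
    Y_hat = clf.predict(X)                # (100, 5) binary label matrix
    P_hat = clf.predict_proba(X)          # (100, 5) posterior probabilities
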
@@ -0,0 +1,96 @@
+import numpy as np
+from sklearn.model_selection import train_test_split
+
+from quapy.data import LabelledCollection
+from quapy.functional import artificial_prevalence_sampling
+
+
+class MultilabelledCollection:
+    def __init__(self, instances, labels):
+        assert labels.ndim==2, 'data does not seem to be multilabel'
+        self.instances = instances
+        self.labels = labels
+        self.classes_ = np.arange(labels.shape[1])
+
+    @classmethod
+    def load(cls, path: str, loader_func: callable):
+        return MultilabelledCollection(*loader_func(path))
+
+    def __len__(self):
+        return self.instances.shape[0]
+
+    def prevalence(self):
+        # return self.labels.mean(axis=0)
+        pos = self.labels.mean(axis=0)
+        neg = 1-pos
+        return np.asarray([neg, pos]).T
+
+    def counts(self):
+        return self.labels.sum(axis=0)
+
+    @property
+    def n_classes(self):
+        return len(self.classes_)
+
+    @property
+    def binary(self):
+        return False
+
+    def __gen_index(self):
+        return np.arange(len(self))
+
+    def sampling_multi_index(self, size, cat, prev=None):
+        if prev is None:  # no prevalence was indicated; returns an index for uniform sampling
+            return np.random.choice(len(self), size, replace=size>len(self))
+        aux = LabelledCollection(self.__gen_index(), self.labels[:,cat])
+        return aux.sampling_index(size, *[1-prev, prev])
+
+    def uniform_sampling_multi_index(self, size):
+        return np.random.choice(len(self), size, replace=size>len(self))
+
+    def uniform_sampling(self, size):
+        unif_index = self.uniform_sampling_multi_index(size)
+        return self.sampling_from_index(unif_index)
+
+    def sampling(self, size, category, prev=None):
+        prev_index = self.sampling_multi_index(size, category, prev)
+        return self.sampling_from_index(prev_index)
+
+    def sampling_from_index(self, index):
+        documents = self.instances[index]
+        labels = self.labels[index]
+        return MultilabelledCollection(documents, labels)
+
+    def train_test_split(self, train_prop=0.6, random_state=None):
+        tr_docs, te_docs, tr_labels, te_labels = \
+            train_test_split(self.instances, self.labels, train_size=train_prop, random_state=random_state)
+        return MultilabelledCollection(tr_docs, tr_labels), MultilabelledCollection(te_docs, te_labels)
+
+    def artificial_sampling_generator(self, sample_size, category, n_prevalences=101, repeats=1):
+        dimensions = 2
+        for prevs in artificial_prevalence_sampling(dimensions, n_prevalences, repeats).flatten():
+            yield self.sampling(sample_size, category, prevs)
+
+    def artificial_sampling_index_generator(self, sample_size, category, n_prevalences=101, repeats=1):
+        dimensions = 2
+        for prevs in artificial_prevalence_sampling(dimensions, n_prevalences, repeats).flatten():
+            yield self.sampling_multi_index(sample_size, category, prevs)
+
+    def natural_sampling_generator(self, sample_size, repeats=100):
+        for _ in range(repeats):
+            yield self.uniform_sampling(sample_size)
+
+    def natural_sampling_index_generator(self, sample_size, repeats=100):
+        for _ in range(repeats):
+            yield self.uniform_sampling_multi_index(sample_size)
+
+    def asLabelledCollection(self, category):
+        return LabelledCollection(self.instances, self.labels[:,category])
+
+    def genLabelledCollections(self):
+        for c in self.classes_:
+            yield self.asLabelledCollection(c)
+
+    @property
+    def Xy(self):
+        return self.instances, self.labels
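
A short sketch of the sampling API defined above (toy data invented for illustration):

    import numpy as np

    X = np.random.rand(1000, 50)                      # 1000 documents, 50 features
    y = (np.random.rand(1000, 10) > 0.8).astype(int)  # 10 binary categories

    data = MultilabelledCollection(X, y)
    print(data.prevalence().shape)                     # (10, 2): one [neg, pos] pair per category
    train, test = data.train_test_split(train_prop=0.6)
    sample = test.sampling(100, category=3, prev=0.5)  # ~50% positives for category 3
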
@@ -0,0 +1,85 @@
+from typing import Union, Callable
+
+import numpy as np
+import quapy as qp
+from MultiLabel.mlquantification import MLAggregativeQuantifier
+from mldata import MultilabelledCollection
+
+
+def ml_natural_prevalence_evaluation(model,
+                                     test:MultilabelledCollection,
+                                     sample_size,
+                                     repeats=100,
+                                     error_metric:Union[str,Callable]='mae',
+                                     random_seed=42):
+
+    if isinstance(error_metric, str):
+        error_metric = qp.error.from_name(error_metric)
+
+    assert hasattr(error_metric, '__call__'), 'invalid error function'
+
+    test_batch_fn = _test_quantification_batch
+    if isinstance(model, MLAggregativeQuantifier):
+        test = MultilabelledCollection(model.preclassify(test.instances), test.labels)
+        test_batch_fn = _test_aggregation_batch
+
+    with qp.util.temp_seed(random_seed):
+        test_indexes = list(test.natural_sampling_index_generator(sample_size=sample_size, repeats=repeats))
+
+    errs = test_batch_fn(tuple([model, test, test_indexes, error_metric]))
+    return np.mean(errs)
+
+
+def ml_artificial_prevalence_evaluation(model,
+                                        test:MultilabelledCollection,
+                                        sample_size,
+                                        n_prevalences=21,
+                                        repeats=10,
+                                        error_metric:Union[str,Callable]='mae',
+                                        random_seed=42):
+
+    if isinstance(error_metric, str):
+        error_metric = qp.error.from_name(error_metric)
+
+    assert hasattr(error_metric, '__call__'), 'invalid error function'
+
+    test_batch_fn = _test_quantification_batch
+    if isinstance(model, MLAggregativeQuantifier):
+        test = MultilabelledCollection(model.preclassify(test.instances), test.labels)
+        test_batch_fn = _test_aggregation_batch
+
+    test_indexes = []
+    with qp.util.temp_seed(random_seed):
+        for cat in test.classes_:
+            test_indexes.append(list(test.artificial_sampling_index_generator(sample_size=sample_size,
+                                                                              category=cat,
+                                                                              n_prevalences=n_prevalences,
+                                                                              repeats=repeats)))
+
+    args = [(model, test, indexes, error_metric) for indexes in test_indexes]
+    macro_errs = qp.util.parallel(test_batch_fn, args, n_jobs=-1)
+
+    return np.mean(macro_errs)
+
+
+def _test_quantification_batch(args):
+    model, test, indexes, error_metric = args
+    errs = []
+    for index in indexes:
+        sample = test.sampling_from_index(index)
+        estim_prevs = model.quantify(sample.instances)
+        true_prevs = sample.prevalence()
+        errs.append(error_metric(true_prevs, estim_prevs))
+    return errs
+
+
+def _test_aggregation_batch(args):
+    model, preclassified_test, indexes, error_metric = args
+    errs = []
+    for index in indexes:
+        sample = preclassified_test.sampling_from_index(index)
+        estim_prevs = model.aggregate(sample.instances)
+        true_prevs = sample.prevalence()
+        errs.append(error_metric(true_prevs, estim_prevs))
+    return errs
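
This module implements the commit's speed-up for aggregative methods: for an MLAggregativeQuantifier, the expensive classifier predictions are computed once over the whole test set, and each sampled batch then pays only for the cheap aggregate step. A sketch of the pattern (assuming `model` is a fitted MLAggregativeQuantifier and `test` a MultilabelledCollection):

    # classify the full test set once; samples then index the precomputed predictions
    pre_test = MultilabelledCollection(model.preclassify(test.instances), test.labels)
    for index in pre_test.natural_sampling_index_generator(sample_size=250, repeats=10):
        sample = pre_test.sampling_from_index(index)
        estim_prevs = model.aggregate(sample.instances)  # no classifier call per sample
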
@@ -0,0 +1,222 @@
+import numpy as np
+from copy import deepcopy
+
+from sklearn.metrics import confusion_matrix
+from sklearn.multioutput import MultiOutputRegressor
+from sklearn.preprocessing import StandardScaler
+from sklearn.svm import LinearSVC, LinearSVR
+from sklearn.linear_model import LogisticRegression, Ridge, Lasso, LassoCV, MultiTaskLassoCV, LassoLars, LassoLarsCV, \
+    ElasticNet, MultiTaskElasticNetCV, MultiTaskElasticNet, LinearRegression, ARDRegression, BayesianRidge, SGDRegressor
+
+import quapy as qp
+from MultiLabel.mlclassification import MultilabelStackedClassifier
+from MultiLabel.mldata import MultilabelledCollection
+from method.aggregative import CC, ACC, PACC, AggregativeQuantifier
+from method.base import BaseQuantifier
+
+from abc import abstractmethod
+
+
+class MLQuantifier:
+    @abstractmethod
+    def fit(self, data: MultilabelledCollection): ...
+
+    @abstractmethod
+    def quantify(self, instances): ...
+
+
+class MLAggregativeQuantifier(MLQuantifier):
+    def fit(self, data:MultilabelledCollection):
+        self.learner.fit(*data.Xy)
+        return self
+
+    @abstractmethod
+    def preclassify(self, instances): ...
+
+    @abstractmethod
+    def aggregate(self, predictions): ...
+
+    def quantify(self, instances):
+        predictions = self.preclassify(instances)
+        return self.aggregate(predictions)
+
+
+class MLCC(MLAggregativeQuantifier):
+    def __init__(self, mlcls):
+        self.learner = mlcls
+
+    def preclassify(self, instances):
+        return self.learner.predict(instances)
+
+    def aggregate(self, predictions):
+        pos_prev = predictions.mean(axis=0)
+        neg_prev = 1 - pos_prev
+        return np.asarray([neg_prev, pos_prev]).T
+
+
+class MLPCC(MLCC):
+    def __init__(self, mlcls):
+        self.learner = mlcls
+
+    def preclassify(self, instances):
+        return self.learner.predict_proba(instances)
+
+
+class MLACC(MLCC):
+    def __init__(self, mlcls):
+        self.learner = mlcls
+
+    def fit(self, data:MultilabelledCollection, train_prop=0.6):
+        self.classes_ = data.classes_
+        train, val = data.train_test_split(train_prop=train_prop)
+        self.learner.fit(*train.Xy)
+        val_predictions = self.preclassify(val.instances)
+        self.Pte_cond_estim_ = []
+        for c in data.classes_:
+            pos_c = val.labels[:,c].sum()
+            neg_c = len(val) - pos_c
+            self.Pte_cond_estim_.append(confusion_matrix(val.labels[:,c], val_predictions[:,c]).T / np.array([neg_c, pos_c]))
+        return self
+
+    def preclassify(self, instances):
+        return self.learner.predict(instances)
+
+    def aggregate(self, predictions):
+        cc_prevs = super(MLACC, self).aggregate(predictions)
+        acc_prevs = np.asarray([ACC.solve_adjustment(self.Pte_cond_estim_[c], cc_prevs[c]) for c in self.classes_])
+        return acc_prevs
+
+
+class MLPACC(MLPCC):
+    def __init__(self, mlcls):
+        self.learner = mlcls
+
+    def fit(self, data:MultilabelledCollection, train_prop=0.6):
+        self.classes_ = data.classes_
+        train, val = data.train_test_split(train_prop=train_prop)
+        self.learner.fit(*train.Xy)
+        val_posteriors = self.preclassify(val.instances)
+        self.Pte_cond_estim_ = []
+        for c in data.classes_:
+            pos_posteriors = val_posteriors[:,c]
+            c_posteriors = np.asarray([1-pos_posteriors, pos_posteriors]).T
+            self.Pte_cond_estim_.append(PACC.getPteCondEstim([0,1], val.labels[:,c], c_posteriors))
+        return self
+
+    def aggregate(self, posteriors):
+        pcc_prevs = super(MLPACC, self).aggregate(posteriors)
+        pacc_prevs = np.asarray([ACC.solve_adjustment(self.Pte_cond_estim_[c], pcc_prevs[c]) for c in self.classes_])
+        return pacc_prevs
+
+
+class MultilabelNaiveQuantifier(MLQuantifier):
+    def __init__(self, q:BaseQuantifier, n_jobs=-1):
+        self.q = q
+        self.estimators = None
+        self.n_jobs = n_jobs
+
+    def fit(self, data:MultilabelledCollection):
+        self.classes_ = data.classes_
+
+        def cat_job(lc):
+            return deepcopy(self.q).fit(lc)
+
+        self.estimators = qp.util.parallel(cat_job, data.genLabelledCollections(), n_jobs=self.n_jobs)
+        return self
+
+    def quantify(self, instances):
+        pos_prevs = np.zeros(len(self.classes_), dtype=float)
+        for c in self.classes_:
+            pos_prevs[c] = self.estimators[c].quantify(instances)[1]
+        neg_prevs = 1-pos_prevs
+        return np.asarray([neg_prevs, pos_prevs]).T
+
+
+class MultilabelNaiveAggregativeQuantifier(MultilabelNaiveQuantifier, MLAggregativeQuantifier):
+    def __init__(self, q:AggregativeQuantifier, n_jobs=-1):
+        assert isinstance(q, AggregativeQuantifier), 'the quantifier is not of type aggregative!'
+        self.q = q
+        self.estimators = None
+        self.n_jobs = n_jobs
+
+    def preclassify(self, instances):
+        return np.asarray([q.preclassify(instances) for q in self.estimators]).swapaxes(0,1)
+
+    def aggregate(self, predictions):
+        pos_prevs = np.zeros(len(self.classes_), dtype=float)
+        for c in self.classes_:
+            pos_prevs[c] = self.estimators[c].aggregate(predictions[:,c])[1]
+        neg_prevs = 1 - pos_prevs
+        return np.asarray([neg_prevs, pos_prevs]).T
+
+    def quantify(self, instances):
+        predictions = self.preclassify(instances)
+        return self.aggregate(predictions)
+
+
+class MultilabelRegressionQuantification:
+    def __init__(self, base_quantifier=CC(LinearSVC()), regression='ridge', n_samples=500, sample_size=500, norm=True,
+                 means=True, stds=True):
+        assert regression in ['ridge', 'svr'], 'unknown regression model'
+        self.estimator = MultilabelNaiveQuantifier(base_quantifier)
+        if regression == 'ridge':
+            self.reg = Ridge(normalize=norm)
+        elif regression == 'svr':
+            self.reg = MultiOutputRegressor(LinearSVR())
+        # self.reg = MultiTaskLassoCV(normalize=norm)
+        # self.reg = KernelRidge(kernel='rbf')
+        # self.reg = LassoLarsCV(normalize=norm)
+        # self.reg = MultiTaskElasticNetCV(normalize=norm) # <- good
+        # self.reg = LinearRegression(normalize=norm) # <- good
+        # self.reg = MultiOutputRegressor(ARDRegression(normalize=norm)) # <- quite good, even without norm
+        # self.reg = MultiOutputRegressor(BayesianRidge(normalize=False)) # <- quite good, even without norm
+        # self.reg = MultiOutputRegressor(SGDRegressor()) # slow, does not work
+        self.regression = regression
+        self.n_samples = n_samples
+        self.sample_size = sample_size
+        self.norm = StandardScaler()
+        self.means = means
+        self.stds = stds
+
+    def fit(self, data:MultilabelledCollection):
+        self.classes_ = data.classes_
+        tr, te = data.train_test_split()
+        self.estimator.fit(tr)
+        samples_mean = []
+        samples_std = []
+        Xs = []
+        ys = []
+        for sample in te.natural_sampling_generator(sample_size=self.sample_size, repeats=self.n_samples):
+            ys.append(sample.prevalence()[:,1])
+            Xs.append(self.estimator.quantify(sample.instances)[:,1])
+            if self.means:
+                samples_mean.append(sample.instances.mean(axis=0).getA().flatten())
+            if self.stds:
+                samples_std.append(sample.instances.todense().std(axis=0).getA().flatten())
+        Xs = np.asarray(Xs)
+        ys = np.asarray(ys)
+        if self.means:
+            samples_mean = np.asarray(samples_mean)
+            Xs = np.hstack([Xs, samples_mean])
+        if self.stds:
+            samples_std = np.asarray(samples_std)
+            Xs = np.hstack([Xs, samples_std])
+        Xs = self.norm.fit_transform(Xs)
+        self.reg.fit(Xs, ys)
+        return self
+
+    def quantify(self, instances):
+        Xs = self.estimator.quantify(instances)[:,1].reshape(1,-1)
+        if self.means:
+            sample_mean = instances.mean(axis=0).getA()
+            Xs = np.hstack([Xs, sample_mean])
+        if self.stds:
+            sample_std = instances.todense().std(axis=0).getA()
+            Xs = np.hstack([Xs, sample_std])
+        Xs = self.norm.transform(Xs)
+        Xs = self.reg.predict(Xs)
+        Xs = self.norm.inverse_transform(Xs)
+        adjusted = np.clip(Xs, 0, 1)
+        adjusted = adjusted.flatten()
+        neg_prevs = 1-adjusted
+        return np.asarray([neg_prevs, adjusted]).T
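
MLCC.aggregate above is plain classify-and-count per category; a worked micro-example of the [neg, pos] output convention (numbers invented):

    import numpy as np

    predictions = np.array([[1, 0],
                            [1, 1],
                            [0, 1],
                            [1, 0]])      # 4 documents, 2 categories
    pos_prev = predictions.mean(axis=0)   # [0.75, 0.5]
    prevs = np.asarray([1 - pos_prev, pos_prev]).T
    # prevs == [[0.25, 0.75],             # category 0: [neg, pos]
    #           [0.50, 0.50]]             # category 1: [neg, pos]
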
@@ -1,27 +1,18 @@
-from copy import deepcopy
 
 from sklearn.calibration import CalibratedClassifierCV
-from sklearn.feature_extraction.text import TfidfVectorizer
-from sklearn.kernel_ridge import KernelRidge
-from sklearn.linear_model import LogisticRegression, Ridge, Lasso, LassoCV, MultiTaskLassoCV, LassoLars, LassoLarsCV, \
-    ElasticNet, MultiTaskElasticNetCV, MultiTaskElasticNet, LinearRegression, ARDRegression, BayesianRidge, SGDRegressor
-from sklearn.metrics import f1_score
-from sklearn.multiclass import OneVsRestClassifier
-from sklearn.multioutput import MultiOutputRegressor
-from sklearn.svm import LinearSVC
+from sklearn.linear_model import LogisticRegression
+from sklearn.multioutput import ClassifierChain
 from tqdm import tqdm
 
 import quapy as qp
-from functional import artificial_prevalence_sampling
+from MultiLabel.mlclassification import MultilabelStackedClassifier
+from MultiLabel.mldata import MultilabelledCollection
+from MultiLabel.mlquantification import MultilabelNaiveQuantifier, MLCC, MLPCC, MultilabelRegressionQuantification, \
+    MLACC, \
+    MLPACC, MultilabelNaiveAggregativeQuantifier
 from method.aggregative import PACC, CC, EMQ, PCC, ACC, HDy
-from method.base import BaseQuantifier
-from quapy.data import from_rcv2_lang_file, LabelledCollection
-from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
 import numpy as np
 from data.dataset import Dataset
+from mlevaluation import ml_natural_prevalence_evaluation, ml_artificial_prevalence_evaluation
 
 
 def cls():
@@ -32,303 +23,68 @@ def cls():
 def calibratedCls():
     return CalibratedClassifierCV(cls())
 
-class MultilabelledCollection:
-    def __init__(self, instances, labels):
-        assert labels.ndim==2, 'data does not seem to be multilabel'
-        self.instances = instances
-        self.labels = labels
-        self.classes_ = np.arange(labels.shape[1])
-
-    @classmethod
-    def load(cls, path: str, loader_func: callable):
-        return MultilabelledCollection(*loader_func(path))
-
-    def __len__(self):
-        return self.instances.shape[0]
-
-    def prevalence(self):
-        # return self.labels.mean(axis=0)
-        pos = self.labels.mean(axis=0)
-        neg = 1-pos
-        return np.asarray([neg, pos]).T
-
-    def counts(self):
-        return self.labels.sum(axis=0)
-
-    @property
-    def n_classes(self):
-        return len(self.classes_)
-
-    @property
-    def binary(self):
-        return False
-
-    def __gen_index(self):
-        return np.arange(len(self))
-
-    def sampling_multi_index(self, size, cat, prev=None):
-        if prev is None:  # no prevalence was indicated; returns an index for uniform sampling
-            return np.random.choice(len(self), size, replace=size>len(self))
-        aux = LabelledCollection(self.__gen_index(), self.labels[:,cat])
-        return aux.sampling_index(size, *[1-prev, prev])
-
-    def uniform_sampling_multi_index(self, size):
-        return np.random.choice(len(self), size, replace=size>len(self))
-
-    def uniform_sampling(self, size):
-        unif_index = self.uniform_sampling_multi_index(size)
-        return self.sampling_from_index(unif_index)
-
-    def sampling(self, size, category, prev=None):
-        prev_index = self.sampling_multi_index(size, category, prev)
-        return self.sampling_from_index(prev_index)
-
-    def sampling_from_index(self, index):
-        documents = self.instances[index]
-        labels = self.labels[index, :]
-        return MultilabelledCollection(documents, labels)
-
-    def train_test_split(self, train_prop=0.6, random_state=None):
-        tr_docs, te_docs, tr_labels, te_labels = \
-            train_test_split(self.instances, self.labels, train_size=train_prop, random_state=random_state)
-        return MultilabelledCollection(tr_docs, tr_labels), MultilabelledCollection(te_docs, te_labels)
-
-    def artificial_sampling_generator(self, sample_size, category, n_prevalences=101, repeats=1):
-        dimensions = 2
-        for prevs in artificial_prevalence_sampling(dimensions, n_prevalences, repeats).flatten():
-            yield self.sampling(sample_size, category, prevs)
-
-    def artificial_sampling_index_generator(self, sample_size, category, n_prevalences=101, repeats=1):
-        dimensions = 2
-        for prevs in artificial_prevalence_sampling(dimensions, n_prevalences, repeats).flatten():
-            yield self.sampling_multi_index(sample_size, category, prevs)
-
-    def natural_sampling_generator(self, sample_size, repeats=100):
-        for _ in range(repeats):
-            yield self.uniform_sampling(sample_size)
-
-    def natural_sampling_index_generator(self, sample_size, repeats=100):
-        for _ in range(repeats):
-            yield self.uniform_sampling_multi_index(sample_size)
-
-    def asLabelledCollection(self, category):
-        return LabelledCollection(self.instances, self.labels[:,category])
-
-    def genLabelledCollections(self):
-        for c in self.classes_:
-            yield self.asLabelledCollection(c)
-
-    @property
-    def Xy(self):
-        return self.instances, self.labels
-
-
-class MultilabelClassifier:  # aka Funnelling Monolingual
-    def __init__(self, base_estimator=LogisticRegression()):
-        if not hasattr(base_estimator, 'predict_proba'):
-            print('the estimator does not seem to be probabilistic: calibrating')
-            base_estimator = CalibratedClassifierCV(base_estimator)
-        self.base = deepcopy(OneVsRestClassifier(base_estimator))
-        self.meta = deepcopy(OneVsRestClassifier(base_estimator))
-        self.norm = StandardScaler()
-
-    def fit(self, X, y):
-        assert y.ndim==2, 'the dataset does not seem to be multi-label'
-        self.base.fit(X, y)
-        P = self.base.predict_proba(X)
-        P = self.norm.fit_transform(P)
-        self.meta.fit(P, y)
-        return self
-
-    def predict(self, X):
-        P = self.base.predict_proba(X)
-        P = self.norm.transform(P)
-        return self.meta.predict(P)
-
-    def predict_proba(self, X):
-        P = self.base.predict_proba(X)
-        P = self.norm.transform(P)
-        return self.meta.predict_proba(P)
-
-
-class MLCC:
-    def __init__(self, mlcls:MultilabelClassifier):
-        self.mlcls = mlcls
-
-    def fit(self, data:MultilabelledCollection):
-        self.mlcls.fit(*data.Xy)
-
-    def quantify(self, instances):
-        pred = self.mlcls.predict(instances)
-        pos_prev = pred.mean(axis=0)
-        neg_prev = 1-pos_prev
-        return np.asarray([neg_prev, pos_prev]).T
-
-
-class MLPCC:
-    def __init__(self, mlcls: MultilabelClassifier):
-        self.mlcls = mlcls
-
-    def fit(self, data: MultilabelledCollection):
-        self.mlcls.fit(*data.Xy)
-
-    def quantify(self, instances):
-        pred = self.mlcls.predict_proba(instances)
-        pos_prev = pred.mean(axis=0)
-        neg_prev = 1 - pos_prev
-        return np.asarray([neg_prev, pos_prev]).T
-
-
-class MultilabelQuantifier:
-    def __init__(self, q:BaseQuantifier, n_jobs=-1):
-        self.q = q
-        self.estimators = None
-        self.n_jobs = n_jobs
-
-    def fit(self, data:MultilabelledCollection):
-        self.classes_ = data.classes_
-
-        def cat_job(lc):
-            return deepcopy(self.q).fit(lc)
-
-        self.estimators = qp.util.parallel(cat_job, data.genLabelledCollections(), n_jobs=self.n_jobs)
-        return self
-
-    def quantify(self, instances):
-        pos_prevs = np.zeros(len(self.classes_), dtype=float)
-        for c in self.classes_:
-            pos_prevs[c] = self.estimators[c].quantify(instances)[1]
-        neg_prevs = 1-pos_prevs
-        return np.asarray([neg_prevs, pos_prevs]).T
-
-
-class MultilabelRegressionQuantification:
-    def __init__(self, base_quantifier=CC(LinearSVC()), regression='ridge', n_samples=500, sample_size=500, norm=True,
-                 means=True, stds=True):
-        assert regression in ['ridge'], 'unknown regression model'
-        self.estimator = MultilabelQuantifier(base_quantifier)
-        if regression == 'ridge':
-            self.reg = Ridge(normalize=norm)
-        # self.reg = MultiTaskLassoCV(normalize=norm)
-        # self.reg = KernelRidge(kernel='rbf')
-        # self.reg = LassoLarsCV(normalize=norm)
-        # self.reg = MultiTaskElasticNetCV(normalize=norm) # <- good
-        # self.reg = LinearRegression(normalize=norm) # <- good
-        # self.reg = MultiOutputRegressor(ARDRegression(normalize=norm)) # <- quite good, even without norm
-        # self.reg = MultiOutputRegressor(BayesianRidge(normalize=False)) # <- quite good, even without norm
-        # self.reg = MultiOutputRegressor(SGDRegressor()) # slow, does not work
-        self.regression = regression
-        self.n_samples = n_samples
-        self.sample_size = sample_size
-        # self.norm = StandardScaler()
-        self.means = means
-        self.stds = stds
-
-    def fit(self, data:MultilabelledCollection):
-        self.classes_ = data.classes_
-        tr, te = data.train_test_split()
-        self.estimator.fit(tr)
-        samples_mean = []
-        samples_std = []
-        Xs = []
-        ys = []
-        for sample in te.natural_sampling_generator(sample_size=self.sample_size, repeats=self.n_samples):
-            ys.append(sample.prevalence()[:,1])
-            Xs.append(self.estimator.quantify(sample.instances)[:,1])
-            if self.means:
-                samples_mean.append(sample.instances.mean(axis=0).getA().flatten())
-            if self.stds:
-                samples_std.append(sample.instances.todense().std(axis=0).getA().flatten())
-        Xs = np.asarray(Xs)
-        ys = np.asarray(ys)
-        if self.means:
-            samples_mean = np.asarray(samples_mean)
-            Xs = np.hstack([Xs, samples_mean])
-        if self.stds:
-            samples_std = np.asarray(samples_std)
-            Xs = np.hstack([Xs, samples_std])
-        # Xs = self.norm.fit_transform(Xs)
-        self.reg.fit(Xs, ys)
-        return self
-
-    def quantify(self, instances):
-        Xs = self.estimator.quantify(instances)[:,1].reshape(1,-1)
-        if self.means:
-            sample_mean = instances.mean(axis=0).getA()
-            Xs = np.hstack([Xs, sample_mean])
-        if self.stds:
-            sample_std = instances.todense().std(axis=0).getA()
-            Xs = np.hstack([Xs, sample_std])
-        # Xs = self.norm.transform(Xs)
-        adjusted = self.reg.predict(Xs)
-        adjusted = np.clip(adjusted, 0, 1)
-        adjusted = adjusted.flatten()
-        neg_prevs = 1-adjusted
-        return np.asarray([neg_prevs, adjusted]).T
+# DEBUG=True
+
+# if DEBUG:
 
 sample_size = 250
-n_samples = 1000
+n_samples = 5000
 
 
 def models():
-    yield 'CC', MultilabelQuantifier(CC(cls()))
-    yield 'PCC', MultilabelQuantifier(PCC(cls()))
-    yield 'MLCC', MLCC(MultilabelClassifier(cls()))
-    yield 'MLPCC', MLPCC(MultilabelClassifier(cls()))
-    # yield 'PACC', MultilabelQuantifier(PACC(cls()))
+    # yield 'NaiveCC', MultilabelNaiveAggregativeQuantifier(CC(cls()))
+    # yield 'NaivePCC', MultilabelNaiveAggregativeQuantifier(PCC(cls()))
+    # yield 'NaiveACC', MultilabelNaiveAggregativeQuantifier(ACC(cls()))
+    # yield 'NaivePACC', MultilabelNaiveAggregativeQuantifier(PACC(cls()))
     # yield 'EMQ', MultilabelQuantifier(EMQ(calibratedCls()))
-    common={'sample_size':sample_size, 'n_samples': n_samples, 'norm': True}
-    # yield 'MRQ-CC', MultilabelRegressionQuantification(base_quantifier=CC(cls()), **common)
-    yield 'MRQ-PCC', MultilabelRegressionQuantification(base_quantifier=PCC(cls()), **common)
-    yield 'MRQ-PACC', MultilabelRegressionQuantification(base_quantifier=PACC(cls()), **common)
+    # yield 'StackCC', MLCC(MultilabelStackedClassifier(cls()))
+    # yield 'StackPCC', MLPCC(MultilabelStackedClassifier(cls()))
+    # yield 'StackACC', MLACC(MultilabelStackedClassifier(cls()))
+    # yield 'StackPACC', MLPACC(MultilabelStackedClassifier(cls()))
+    # yield 'ChainCC', MLCC(ClassifierChain(cls(), cv=None, order='random'))
+    # yield 'ChainPCC', MLPCC(ClassifierChain(cls(), cv=None, order='random'))
+    # yield 'ChainACC', MLACC(ClassifierChain(cls(), cv=None, order='random'))
+    # yield 'ChainPACC', MLPACC(ClassifierChain(cls(), cv=None, order='random'))
+    common={'sample_size':sample_size, 'n_samples': n_samples, 'norm': True, 'means':False, 'stds':False}
+    yield 'MRQ-CC', MultilabelRegressionQuantification(base_quantifier=CC(cls()), regression='svr', **common)
+    yield 'MRQ-PCC', MultilabelRegressionQuantification(base_quantifier=PCC(cls()), regression='svr', **common)
+    yield 'MRQ-ACC', MultilabelRegressionQuantification(base_quantifier=ACC(cls()), regression='svr', **common)
+    yield 'MRQ-PACC', MultilabelRegressionQuantification(base_quantifier=PACC(cls()), regression='svr', **common)
 
 
 dataset = 'reuters21578'
-data = Dataset.load(dataset, pickle_path=f'./pickles/{dataset}.pickle')
+picklepath = '/home/moreo/word-class-embeddings/pickles'
+data = Dataset.load(dataset, pickle_path=f'{picklepath}/{dataset}.pickle')
 
 Xtr, Xte = data.vectorize()
 ytr = data.devel_labelmatrix.todense().getA()
 yte = data.test_labelmatrix.todense().getA()
 
-most_populadted = np.argsort(ytr.sum(axis=0))[-25:]
-ytr = ytr[:, most_populadted]
-yte = yte[:, most_populadted]
+# remove categories with < 10 training documents
+to_keep = np.logical_and(ytr.sum(axis=0)>=50, yte.sum(axis=0)>=50)
+ytr = ytr[:, to_keep]
+yte = yte[:, to_keep]
+print(f'num categories = {ytr.shape[1]}')
 
 train = MultilabelledCollection(Xtr, ytr)
 test = MultilabelledCollection(Xte, yte)
 
-print(f'Train-prev: {train.prevalence()[:,1]}')
-print(f'Test-prev: {test.prevalence()[:,1]}')
+# print(f'Train-prev: {train.prevalence()[:,1]}')
+print(f'Train-counts: {train.counts()}')
+# print(f'Test-prev: {test.prevalence()[:,1]}')
+print(f'Test-counts: {test.counts()}')
 print(f'MLPE: {qp.error.mae(train.prevalence(), test.prevalence()):.5f}')
 
-# print('NPP:')
-# test_indexes = list(test.natural_sampling_index_generator(sample_size=sample_size, repeats=100))
-# for model_name, model in models():
-#     model.fit(train)
-#     errs = []
-#     for index in test_indexes:
-#         sample = test.sampling_from_index(index)
-#         estim_prevs = model.quantify(sample.instances)
-#         true_prevs = sample.prevalence()
-#         errs.append(qp.error.mae(true_prevs, estim_prevs))
-#     print(f'{model_name:10s}\tmae={np.mean(errs):.5f}')
+fit_models = {model_name:model.fit(train) for model_name,model in tqdm(models(), 'fitting', total=6)}
+
+print('NPP:')
+for model_name, model in fit_models.items():
+    err = ml_natural_prevalence_evaluation(model, test, sample_size, repeats=100)
+    print(f'{model_name:10s}\tmae={err:.5f}')
 
 print('APP:')
-test_indexes = []
-for cat in train.classes_:
-    test_indexes.append(list(test.artificial_sampling_index_generator(sample_size=sample_size, category=cat, n_prevalences=21, repeats=10)))
-
-for model_name, model in models():
-    model.fit(train)
-    macro_errs = []
-    for cat_indexes in test_indexes:
-        errs = []
-        for index in cat_indexes:
-            sample = test.sampling_from_index(index)
-            estim_prevs = model.quantify(sample.instances)
-            true_prevs = sample.prevalence()
-            errs.append(qp.error.mae(true_prevs, estim_prevs))
-        macro_errs.append(np.mean(errs))
-    print(f'{model_name:10s}\tmae={np.mean(macro_errs):.5f}')
+for model_name, model in fit_models.items():
+    err = ml_artificial_prevalence_evaluation(model, test, sample_size, n_prevalences=21, repeats=10)
+    print(f'{model_name:10s}\tmae={err:.5f}')
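
The MRQ-* models above rest on one idea: MultilabelRegressionQuantification draws many held-out samples, pairs the base quantifier's per-category estimates with the samples' true prevalences, and fits a multi-output regressor as a learned correction. A stripped-down sketch of that training signal (random placeholder data):

    import numpy as np
    from sklearn.linear_model import Ridge

    E = np.random.rand(500, 25)     # base quantifier's estimates, one row per sample
    T = np.random.rand(500, 25)     # true prevalences of the same samples
    reg = Ridge().fit(E, T)         # the learned correction (ridge or SVR in the class above)
    corrected = reg.predict(E[:1])  # applied at quantification time, then clipped to [0, 1]
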
@@ -37,6 +37,9 @@ class AggregativeQuantifier(BaseQuantifier):
     def learner(self, value):
         self.learner_ = value
 
+    def preclassify(self, instances):
+        return self.classify(instances)
+
     def classify(self, instances):
         return self.learner.predict(instances)
 
@@ -74,6 +77,9 @@ class AggregativeProbabilisticQuantifier(AggregativeQuantifier):
     probabilities.
     """
 
+    def preclassify(self, instances):
+        return self.predict_proba(instances)
+
     def posterior_probabilities(self, instances):
         return self.learner.predict_proba(instances)
 
@@ -316,6 +322,12 @@ class PACC(AggregativeProbabilisticQuantifier):
 
         self.pcc = PCC(self.learner)
 
+        self.Pte_cond_estim_ = self.getPteCondEstim(classes, y, y_)
+
+        return self
+
+    @classmethod
+    def getPteCondEstim(cls, classes, y, y_):
         # estimate the matrix with entry (i,j) being the estimate of P(yi|yj), that is, the probability that a
         # document that belongs to yj ends up being classified as belonging to yi
         n_classes = len(classes)
@@ -323,9 +335,7 @@ class PACC(AggregativeProbabilisticQuantifier):
         for i, class_ in enumerate(classes):
             confusion[i] = y_[y == class_].mean(axis=0)
 
-        self.Pte_cond_estim_ = confusion.T
-
-        return self
+        return confusion.T
 
     def aggregate(self, classif_posteriors):
         prevs_estim = self.pcc.aggregate(classif_posteriors)
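
For reference, the adjustment these estimates feed is a small linear system: with M[i, j] ~ P(predicted i | true j), the classify-and-count prevalences satisfy p_cc = M @ p_true, and ACC.solve_adjustment essentially inverts it. A toy instance (matrix values invented):

    import numpy as np

    M = np.array([[0.9, 0.2],
                  [0.1, 0.8]])        # columns sum to 1: estimated P(predicted | true)
    p_true = np.array([0.3, 0.7])
    p_cc = M @ p_true                 # what classify-and-count would observe: [0.41, 0.59]
    p_adj = np.linalg.solve(M, p_cc)  # recovers [0.3, 0.7]
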
@@ -785,7 +795,7 @@ class OneVsAll(AggregativeQuantifier):
         return self.binary_quantifier.get_params()
 
     def _delayed_binary_classification(self, c, X):
-        return self.dict_binary_quantifiers[c].classify(X)
+        return self.dict_binary_quantifiers[c].preclassify(X)
 
     def _delayed_binary_posteriors(self, c, X):
         return self.dict_binary_quantifiers[c].posterior_probabilities(X)
@@ -27,7 +27,7 @@ class BaseQuantifier(metaclass=ABCMeta):
     # based on class structure
     @property
     def binary(self):
-        return False
+        return len(self.classes_)==2
 
     @property
     def aggregative(self):