forked from moreo/QuaPy
more examples, one-vs-all fixed
This commit is contained in:
parent 2485117f05
commit e28abfc362
@@ -0,0 +1,53 @@
+import quapy as qp
+from quapy.method.aggregative import MS2, OneVsAllAggregative, OneVsAllGeneric
+from quapy.method.base import getOneVsAll
+from quapy.model_selection import GridSearchQ
+from quapy.protocol import USimplexPP
+from sklearn.linear_model import LogisticRegression
+import numpy as np
+
+"""
+In this example, we will create a quantifier for tweet sentiment analysis considering three classes: negative,
+neutral, and positive. We will use a one-vs-all approach using a binary quantifier for demonstration purposes.
+"""
+
+qp.environ['SAMPLE_SIZE'] = 100
+
+"""
+Any binary quantifier can be turned into a single-label quantifier by means of the getOneVsAll function. This
+function returns an instance of OneVsAll quantifier: the subclass OneVsAllGeneric when the quantifier is an
+instance of BaseQuantifier, and OneVsAllAggregative when the quantifier is an instance of AggregativeQuantifier.
+Although OneVsAllGeneric works in all cases, OneVsAllAggregative has some additional advantages (namely, all the
+advantages that AggregativeQuantifiers enjoy, i.e., faster predictions during evaluation).
+"""
+quantifier = getOneVsAll(MS2(LogisticRegression()), parallel_backend="loky")
+print(f'the quantifier is an instance of {quantifier.__class__.__name__}')
+
+# load a ternary dataset
+train_modsel, val = qp.datasets.fetch_twitter('hcr', for_model_selection=True, pickle=True).train_test
+
+"""
+Model selection: for this example, we rely on the USimplexPP protocol, i.e., a variant of the
+artificial-prevalence protocol that generates random samples (100 in this case) at priors drawn at random from
+the unit simplex. The priors are sampled using the Kraemer algorithm. Note this is in contrast to the standard
+APP protocol, which instead explores a prefixed grid of prevalence values.
+"""
+param_grid = {
+    'binary_quantifier__classifier__C': np.logspace(-2, 2, 5),  # classifier-dependent hyperparameter
+    'binary_quantifier__classifier__class_weight': ['balanced', None]  # classifier-dependent hyperparameter
+}
+print('starting model selection')
+gs = GridSearchQ(quantifier, param_grid, protocol=USimplexPP(val), n_jobs=-1, verbose=True, refit=False)
+quantifier = gs.fit(train_modsel).best_model()
+
+print('training on the whole training set')
+train, test = qp.datasets.fetch_twitter('hcr', for_model_selection=False, pickle=True).train_test
+quantifier.fit(train)
+
+# evaluation
+mae = qp.evaluation.evaluate(quantifier, protocol=USimplexPP(test), error_metric='mae')
+
+print(f'MAE = {mae:.4f}')
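As an aside, the Kraemer sampling mentioned in the comment above can be sketched in a few lines of numpy. This is an illustrative, stand-alone snippet (the function name kraemer_sample is ours, not part of QuaPy):

import numpy as np

def kraemer_sample(n_classes, random_state=None):
    # draw a prevalence vector uniformly at random from the unit simplex:
    # sort n_classes-1 uniform cut points in [0,1] and take the gaps between
    # consecutive cut points (including the 0 and 1 endpoints)
    rng = np.random.default_rng(random_state)
    cuts = np.sort(rng.uniform(size=n_classes - 1))
    return np.diff(np.concatenate(([0.], cuts, [1.])))

print(kraemer_sample(3, random_state=0))  # three non-negative values summing to 1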
@@ -83,7 +83,6 @@ Things to fix:
 - update unit tests
 - update Wikis...
-- Resolve the OneVsAll thing (it is in base.py and in aggregative.py)
 - Add a proper log?
 - improve plots
 - documentation of protocols is incomplete
@@ -65,7 +65,7 @@ class RecalibratedProbabilisticClassifierBase(BaseEstimator, RecalibratedProbabi
         """
         Fits the calibration in a cross-validation manner, i.e., it generates posterior probabilities for all
         training instances via cross-validation, and then retrains the classifier on all training instances.
-        The posterior probabilities thus generated are used for calibrating the outpus of the classifier.
+        The posterior probabilities thus generated are used for calibrating the outputs of the classifier.

         :param X: array-like of shape `(n_samples, n_features)` with the data instances
         :param y: array-like of shape `(n_samples,)` with the class labels
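The cross-validation scheme this docstring describes can be illustrated with sklearn primitives alone. The following is a hedged sketch of the idea, not QuaPy's actual implementation:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

def fit_cv_calibration(X, y, k=5):
    clf = LogisticRegression()
    # posteriors for every training instance, each produced by a fold-model
    # that did not see that instance during training
    posteriors = cross_val_predict(clf, X, y, cv=k, method='predict_proba')
    clf.fit(X, y)           # then retrain on the full training set
    return clf, posteriors  # the posteriors would feed the calibrator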
@@ -6,21 +6,22 @@ from scipy.sparse import vstack
 from sklearn.model_selection import train_test_split, RepeatedStratifiedKFold
 from numpy.random import RandomState
 from quapy.functional import strprev
+from quapy.util import temp_seed


 class LabelledCollection:
     """
-    A LabelledCollection is a set of objects each with a label associated to it. This class implements many sampling
-    routines.
+    A LabelledCollection is a set of objects, each with a label attached to it.
+    This class implements several sampling routines and other utilities.

     :param instances: array-like (np.ndarray, list, or csr_matrix are supported)
     :param labels: array-like with the same length of instances
-    :param classes_: optional, list of classes from which labels are taken. If not specified, the classes are inferred
+    :param classes: optional, list of classes from which labels are taken. If not specified, the classes are inferred
         from the labels. The classes must be indicated in cases in which some of the labels might have no examples
         (i.e., a prevalence of 0)
     """

-    def __init__(self, instances, labels, classes_=None):
+    def __init__(self, instances, labels, classes=None):
         if issparse(instances):
             self.instances = instances
         elif isinstance(instances, list) and len(instances) > 0 and isinstance(instances[0], str):
@@ -30,14 +31,14 @@ class LabelledCollection:
         self.instances = np.asarray(instances)
         self.labels = np.asarray(labels)
         n_docs = len(self)
-        if classes_ is None:
+        if classes is None:
             self.classes_ = np.unique(self.labels)
             self.classes_.sort()
         else:
-            self.classes_ = np.unique(np.asarray(classes_))
+            self.classes_ = np.unique(np.asarray(classes))
             self.classes_.sort()
-            if len(set(self.labels).difference(set(classes_))) > 0:
-                raise ValueError(f'labels ({set(self.labels)}) contain values not included in classes_ ({set(classes_)})')
+            if len(set(self.labels).difference(set(classes))) > 0:
+                raise ValueError(f'labels ({set(self.labels)}) contain values not included in classes_ ({set(classes)})')
         self.index = {class_: np.arange(n_docs)[self.labels == class_] for class_ in self.classes_}

     @classmethod
@@ -101,7 +102,7 @@ class LabelledCollection:
         """
         return self.n_classes == 2

-    def sampling_index(self, size, *prevs, shuffle=True):
+    def sampling_index(self, size, *prevs, shuffle=True, random_state=None):
         """
         Returns an index to be used to extract a random sample of desired size and desired prevalence values. If the
         prevalence values are not specified, then returns the index of a uniform sampling.
@@ -113,10 +114,11 @@ class LabelledCollection:
         it is constrained. E.g., for binary collections, only the prevalence `p` for the first class (as listed in
         `self.classes_`) can be specified, while the other class takes prevalence value `1-p`
         :param shuffle: if set to True (default), shuffles the index before returning it
+        :param random_state: seed for reproducing sampling
         :return: a np.ndarray of shape `(size)` with the indexes
         """
         if len(prevs) == 0:  # no prevalence was indicated; returns an index for uniform sampling
-            return self.uniform_sampling_index(size)
+            return self.uniform_sampling_index(size, random_state=random_state)
         if len(prevs) == self.n_classes - 1:
             prevs = prevs + (1 - sum(prevs),)
         assert len(prevs) == self.n_classes, 'unexpected number of prevalences'
@@ -129,22 +131,23 @@ class LabelledCollection:
         # (This aims at avoiding the remainder to be placed in a class for which the prevalence requested is 0.)
         n_requests = {class_: int(size * prevs[i]) for i, class_ in enumerate(self.classes_)}
         remainder = size - sum(n_requests.values())
-        for rand_class in np.random.choice(self.classes_, size=remainder, p=prevs):
-            n_requests[rand_class] += 1
+        with temp_seed(random_state):
+            for rand_class in np.random.choice(self.classes_, size=remainder, p=prevs):
+                n_requests[rand_class] += 1

-        indexes_sample = []
-        for class_, n_requested in n_requests.items():
-            n_candidates = len(self.index[class_])
-            index_sample = self.index[class_][
-                np.random.choice(n_candidates, size=n_requested, replace=(n_requested > n_candidates))
-            ] if n_requested > 0 else []
+            indexes_sample = []
+            for class_, n_requested in n_requests.items():
+                n_candidates = len(self.index[class_])
+                index_sample = self.index[class_][
+                    np.random.choice(n_candidates, size=n_requested, replace=(n_requested > n_candidates))
+                ] if n_requested > 0 else []

-            indexes_sample.append(index_sample)
+                indexes_sample.append(index_sample)

-        indexes_sample = np.concatenate(indexes_sample).astype(int)
+            indexes_sample = np.concatenate(indexes_sample).astype(int)

-        if shuffle:
-            indexes_sample = np.random.permutation(indexes_sample)
+            if shuffle:
+                indexes_sample = np.random.permutation(indexes_sample)

         return indexes_sample
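A quick sketch of the new seeded behavior on toy data; only the random_state argument is new, the rest of the API is unchanged:

import numpy as np
from quapy.data import LabelledCollection

X = [f'doc-{i}' for i in range(100)]
y = np.random.randint(0, 2, 100)
data = LabelledCollection(X, y)

# same seed and same requested prevalence -> identical samples
s1 = data.sampling(20, 0.75, random_state=42)
s2 = data.sampling(20, 0.75, random_state=42)
assert np.all(np.asarray(s1.X) == np.asarray(s2.X)) and np.all(s1.y == s2.y)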
@@ -164,7 +167,7 @@ class LabelledCollection:
         ng = np.random
         return ng.choice(len(self), size, replace=size > len(self))

-    def sampling(self, size, *prevs, shuffle=True):
+    def sampling(self, size, *prevs, shuffle=True, random_state=None):
         """
         Return a random sample (an instance of :class:`LabelledCollection`) of desired size and desired prevalence
         values. For each class, the sampling is drawn without replacement if the requested prevalence is larger than
@@ -175,10 +178,11 @@ class LabelledCollection:
         it is constrained. E.g., for binary collections, only the prevalence `p` for the first class (as listed in
         `self.classes_`) can be specified, while the other class takes prevalence value `1-p`
         :param shuffle: if set to True (default), shuffles the index before returning it
+        :param random_state: seed for reproducing sampling
         :return: an instance of :class:`LabelledCollection` with length == `size` and prevalence close to `prevs` (or
             prevalence == `prevs` if the exact prevalence values can be met as proportions of instances)
         """
-        prev_index = self.sampling_index(size, *prevs, shuffle=shuffle)
+        prev_index = self.sampling_index(size, *prevs, shuffle=shuffle, random_state=random_state)
         return self.sampling_from_index(prev_index)

     def uniform_sampling(self, size, random_state=None):
@@ -204,7 +208,7 @@ class LabelledCollection:
         """
         documents = self.instances[index]
         labels = self.labels[index]
-        return LabelledCollection(documents, labels, classes_=self.classes_)
+        return LabelledCollection(documents, labels, classes=self.classes_)

     def split_stratified(self, train_prop=0.6, random_state=None):
         """
@@ -221,11 +225,10 @@ class LabelledCollection:
         tr_docs, te_docs, tr_labels, te_labels = train_test_split(
             self.instances, self.labels, train_size=train_prop, stratify=self.labels, random_state=random_state
         )
-        training = LabelledCollection(tr_docs, tr_labels, classes_=self.classes_)
-        test = LabelledCollection(te_docs, te_labels, classes_=self.classes_)
+        training = LabelledCollection(tr_docs, tr_labels, classes=self.classes_)
+        test = LabelledCollection(te_docs, te_labels, classes=self.classes_)
         return training, test

-
     def split_random(self, train_prop=0.6, random_state=None):
         """
         Returns two instances of :class:`LabelledCollection` split randomly from this collection, at desired
@@ -261,20 +264,33 @@ class LabelledCollection:
         :return: a :class:`LabelledCollection` representing the union of both collections
         """
         if not all(np.sort(self.classes_)==np.sort(other.classes_)):
-            raise NotImplementedError('unsupported operation for collections on different classes')
+            raise NotImplementedError(f'unsupported operation for collections on different classes; '
+                                      f'expected {self.classes_}, found {other.classes_}')
+        return LabelledCollection.mix(self, other)

-        if other is None:
-            return self
-        elif issparse(self.instances) and issparse(other.instances):
-            join_instances = vstack([self.instances, other.instances])
-        elif isinstance(self.instances, list) and isinstance(other.instances, list):
-            join_instances = self.instances + other.instances
-        elif isinstance(self.instances, np.ndarray) and isinstance(other.instances, np.ndarray):
-            join_instances = np.concatenate([self.instances, other.instances])
+    @classmethod
+    def mix(cls, a: 'LabelledCollection', b: 'LabelledCollection'):
+        """
+        Returns a new :class:`LabelledCollection` as the union of the two given collections.
+
+        :param a: instance of :class:`LabelledCollection`
+        :param b: instance of :class:`LabelledCollection`
+        :return: a :class:`LabelledCollection` representing the union of both collections
+        """
+        if a is None: return b
+        if b is None: return a
+        elif issparse(a.instances) and issparse(b.instances):
+            join_instances = vstack([a.instances, b.instances])
+        elif isinstance(a.instances, list) and isinstance(b.instances, list):
+            join_instances = a.instances + b.instances
+        elif isinstance(a.instances, np.ndarray) and isinstance(b.instances, np.ndarray):
+            join_instances = np.concatenate([a.instances, b.instances])
         else:
             raise NotImplementedError('unsupported operation for collection types')
-        labels = np.concatenate([self.labels, other.labels])
-        return LabelledCollection(join_instances, labels, classes_=self.classes_)
+        labels = np.concatenate([a.labels, b.labels])
+        classes = np.unique(np.concatenate([a.classes_, b.classes_]))  # np.unique already returns a sorted array
+        return LabelledCollection(join_instances, labels, classes=classes)


     @property
     def Xy(self):
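A minimal sketch of the new classmethod on toy data: unlike `+`, which still requires identical class sets, `mix` merges the classes of both operands:

from quapy.data import LabelledCollection

a = LabelledCollection(['a0', 'a1'], [0, 1])
b = LabelledCollection(['b0', 'b1', 'b2'], [1, 2, 2])
ab = LabelledCollection.mix(a, b)
print(len(ab), ab.classes_)  # 5 instances over classes [0, 1, 2]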
@@ -291,7 +307,7 @@ class LabelledCollection:
     def Xp(self):
         """
         Gets the instances and the true prevalence. This is useful when implementing evaluation protocols from
-        a `LabelledCollection` object.
+        a :class:`LabelledCollection` object.

         :return: a tuple `(instances, prevalence)` from this collection
         """
@@ -357,7 +373,7 @@ class LabelledCollection:
               f'#classes={stats_["classes"]}, prevs={stats_["prevs"]}')
         return stats_

-    def kFCV(self, nfolds=5, nrepeats=1, random_state=0):
+    def kFCV(self, nfolds=5, nrepeats=1, random_state=None):
         """
         Generator of stratified folds to be used in k-fold cross validation.

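Assuming, as the docstring suggests, that kFCV yields (training, test) pairs of LabelledCollection, the seeded generator can be used as in this sketch:

import numpy as np
from quapy.data import LabelledCollection

data = LabelledCollection(list(map(str, range(100))), np.random.randint(0, 3, 100))
for i, (train, test) in enumerate(data.kFCV(nfolds=5, random_state=0)):
    print(f'fold {i}: {len(train)} training / {len(test)} test documents')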
@@ -6,6 +6,7 @@ import os
 import zipfile
 from os.path import join
 import pandas as pd
+import scipy

 from quapy.data.base import Dataset, LabelledCollection
 from quapy.data.preprocessing import text2tfidf, reduce_columns
@@ -14,7 +14,7 @@ import quapy.functional as F
 from classification.calibration import NBVSCalibration, BCTSCalibration, TSCalibration, VSCalibration
 from quapy.classification.svmperf import SVMperf
 from quapy.data import LabelledCollection
-from quapy.method.base import BaseQuantifier, BinaryQuantifier
+from quapy.method.base import BaseQuantifier, BinaryQuantifier, OneVsAllGeneric


 # Abstract classes
@@ -1246,7 +1246,7 @@ MedianSweep = MS
 MedianSweep2 = MS2


-class OneVsAll(AggregativeQuantifier):
+class OneVsAllAggregative(OneVsAllGeneric, AggregativeQuantifier):
     """
     Allows any binary quantifier to perform quantification on single-label datasets.
     The method maintains one binary quantifier for each class, and then l1-normalizes the outputs so that the
@@ -1257,25 +1257,19 @@ class OneVsAll(AggregativeQuantifier):
     :param binary_quantifier: a quantifier (binary) that will be employed to work on multiclass model in a
         one-vs-all manner
     :param n_jobs: number of parallel workers
+    :param parallel_backend: the parallel backend for joblib (default "loky"); this is helpful for some quantifiers
+        (e.g., ELM-based ones) that cannot be run with multiprocessing, since the temp dir they create during fit
+        will be removed and no longer be available at predict time.
     """

-    def __init__(self, binary_quantifier, n_jobs=None):
-        assert isinstance(self.binary_quantifier, BaseQuantifier), \
+    def __init__(self, binary_quantifier, n_jobs=None, parallel_backend='loky'):
+        assert isinstance(binary_quantifier, BaseQuantifier), \
             f'{self.binary_quantifier} does not seem to be a Quantifier'
-        assert isinstance(self.binary_quantifier, AggregativeQuantifier), \
+        assert isinstance(binary_quantifier, AggregativeQuantifier), \
             f'{self.binary_quantifier} does not seem to be of type Aggregative'
         self.binary_quantifier = binary_quantifier
         self.n_jobs = qp._get_njobs(n_jobs)
-
-    def fit(self, data: LabelledCollection, fit_classifier=True):
-        assert not data.binary, \
-            f'{self.__class__.__name__} expect non-binary data'
-        assert fit_classifier == True, \
-            'fit_classifier must be True'
-
-        self.dict_binary_quantifiers = {c: deepcopy(self.binary_quantifier) for c in data.classes_}
-        self.__parallel(self._delayed_binary_fit, data)
-        return self
+        self.parallel_backend = parallel_backend

     def classify(self, instances):
         """
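For instance, the wrapper might be built directly with a specific backend. A sketch; 'threading' is one joblib option that sidesteps the temp-dir issue described in the docstring above:

from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import MS2, OneVsAllAggregative

ova = OneVsAllAggregative(MS2(LogisticRegression()), n_jobs=-1, parallel_backend='threading')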
@@ -1292,35 +1286,16 @@ class OneVsAll(AggregativeQuantifier):
         :return: `np.ndarray`
         """

-        classif_predictions = self.__parallel(self._delayed_binary_classification, instances)
+        classif_predictions = self._parallel(self._delayed_binary_classification, instances)
         if isinstance(self.binary_quantifier, AggregativeProbabilisticQuantifier):
             return np.swapaxes(classif_predictions, 0, 1)
         else:
             return classif_predictions.T

     def aggregate(self, classif_predictions):
-        prevalences = self.__parallel(self._delayed_binary_aggregate, classif_predictions)
+        prevalences = self._parallel(self._delayed_binary_aggregate, classif_predictions)
         return F.normalize_prevalence(prevalences)

-    def __parallel(self, func, *args, **kwargs):
-        return np.asarray(
-            # some quantifiers (in particular, ELM-based ones) cannot be run with multiprocess, since the temp dir they
-            # create during the fit will be removed and be no longer available for the predict...
-            Parallel(n_jobs=self.n_jobs, backend='threading')(
-                delayed(func)(c, *args, **kwargs) for c in self.classes_
-            )
-        )
-
-    @property
-    def classes_(self):
-        return sorted(self.dict_binary_quantifiers.keys())
-
-    def set_params(self, **parameters):
-        self.binary_quantifier.set_params(**parameters)
-
-    def get_params(self, deep=True):
-        return self.binary_quantifier.get_params()
-
     def _delayed_binary_classification(self, c, X):
         return self.dict_binary_quantifiers[c].classify(X)
@@ -1328,7 +1303,3 @@ class OneVsAll(AggregativeQuantifier):
         # the estimation for the positive class prevalence
         return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:, c])[1]
-
-    def _delayed_binary_fit(self, c, data):
-        bindata = LabelledCollection(data.instances, data.labels == c, classes_=[False, True])
-        self.dict_binary_quantifiers[c].fit(bindata)
@@ -1,10 +1,12 @@
 from abc import ABCMeta, abstractmethod
 from copy import deepcopy

 from joblib import Parallel, delayed
+from sklearn.base import BaseEstimator

+import quapy as qp
 from quapy.data import LabelledCollection
 import numpy as np


 # Base Quantifier abstract class
@@ -48,50 +50,61 @@ class BinaryQuantifier(BaseQuantifier):
             f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'


-class OneVsAllGeneric:
+class OneVsAll:
+    pass
+
+
+def getOneVsAll(binary_quantifier, n_jobs=None, parallel_backend='loky'):
+    assert isinstance(binary_quantifier, BaseQuantifier), \
+        f'{binary_quantifier} does not seem to be a Quantifier'
+    if isinstance(binary_quantifier, qp.method.aggregative.AggregativeQuantifier):
+        return qp.method.aggregative.OneVsAllAggregative(binary_quantifier, n_jobs, parallel_backend)
+    else:
+        return OneVsAllGeneric(binary_quantifier, n_jobs, parallel_backend)
+
+
+class OneVsAllGeneric(OneVsAll, BaseQuantifier):
     """
     Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
     quantifier for each class, and then l1-normalizes the outputs so that the class prevalence values sum up to 1.
     """

-    def __init__(self, binary_quantifier, n_jobs=None):
+    def __init__(self, binary_quantifier, n_jobs=None, parallel_backend='loky'):
         assert isinstance(binary_quantifier, BaseQuantifier), \
             f'{binary_quantifier} does not seem to be a Quantifier'
+        if isinstance(binary_quantifier, qp.method.aggregative.AggregativeQuantifier):
+            print('[warning] the quantifier seems to be an instance of qp.method.aggregative.AggregativeQuantifier; '
+                  f'you might prefer instantiating {qp.method.aggregative.OneVsAllAggregative.__name__}')
         self.binary_quantifier = binary_quantifier
         self.n_jobs = qp._get_njobs(n_jobs)
+        self.parallel_backend = parallel_backend

-    def fit(self, data: LabelledCollection, **kwargs):
-        assert not data.binary, \
-            f'{self.__class__.__name__} expect non-binary data'
-        self.class_quatifier = {c: deepcopy(self.binary_quantifier) for c in data.classes_}
-        Parallel(n_jobs=self.n_jobs, backend='threading')(
-            delayed(self._delayed_binary_fit)(c, self.class_quatifier, data, **kwargs) for c in data.classes_
-        )
+    def fit(self, data: LabelledCollection, fit_classifier=True):
+        assert not data.binary, f'{self.__class__.__name__} expects non-binary data'
+        assert fit_classifier == True, 'fit_classifier must be True'
+
+        self.dict_binary_quantifiers = {c: deepcopy(self.binary_quantifier) for c in data.classes_}
+        self._parallel(self._delayed_binary_fit, data)
+        return self

-    def quantify(self, X, *args):
-        prevalences = np.asarray(
-            Parallel(n_jobs=self.n_jobs, backend='threading')(
-                delayed(self._delayed_binary_predict)(c, self.class_quatifier, X) for c in self.classes
+    def _parallel(self, func, *args, **kwargs):
+        return np.asarray(
+            Parallel(n_jobs=self.n_jobs, backend=self.parallel_backend)(
+                delayed(func)(c, *args, **kwargs) for c in self.classes_
             )
         )
-        return F.normalize_prevalence(prevalences)
+
+    def quantify(self, instances):
+        prevalences = self._parallel(self._delayed_binary_predict, instances)
+        return qp.functional.normalize_prevalence(prevalences)

     @property
-    def classes(self):
-        return sorted(self.class_quatifier.keys())
+    def classes_(self):
+        return sorted(self.dict_binary_quantifiers.keys())

     def set_params(self, **parameters):
         self.binary_quantifier.set_params(**parameters)

     def get_params(self, deep=True):
         return self.binary_quantifier.get_params()

-    def _delayed_binary_predict(self, c, quantifiers, X):
-        return quantifiers[c].quantify(X)[:, 1]  # the mean is the estimation for the positive class prevalence
+    def _delayed_binary_predict(self, c, X):
+        return self.dict_binary_quantifiers[c].quantify(X)[1]

-    def _delayed_binary_fit(self, c, quantifiers, data, **kwargs):
-        bindata = LabelledCollection(data.instances, data.labels == c, n_classes=2)
-        quantifiers[c].fit(bindata, **kwargs)
+    def _delayed_binary_fit(self, c, data):
+        bindata = LabelledCollection(data.instances, data.labels == c, classes=[False, True])
+        self.dict_binary_quantifiers[c].fit(bindata)
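The dispatch performed by getOneVsAll can be observed directly. A small sketch (PACC is aggregative, so the aggregative wrapper is returned):

from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import PACC
from quapy.method.base import getOneVsAll

ova = getOneVsAll(PACC(LogisticRegression()))
print(type(ova).__name__)  # OneVsAllAggregative, since PACC is an AggregativeQuantifier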
@@ -7,7 +7,7 @@ from quapy.protocol import APP, NPP, USimplexPP, DomainMixer, AbstractStochastic
 def mock_labelled_collection(prefix=''):
     y = [0] * 250 + [1] * 250 + [2] * 250 + [3] * 250
     X = [prefix + str(i) + '-' + str(yi) for i, yi in enumerate(y)]
-    return LabelledCollection(X, y, classes_=sorted(np.unique(y)))
+    return LabelledCollection(X, y, classes=sorted(np.unique(y)))


 def samples_to_str(protocol):
@@ -1,13 +1,14 @@
 import unittest
 import quapy as qp
 from quapy.data import LabelledCollection
+from quapy.functional import strprev
 from sklearn.linear_model import LogisticRegression

-from method.aggregative import PACC
+from quapy.method.aggregative import PACC


 class MyTestCase(unittest.TestCase):
-    def test_replicability(self):
+    def test_prediction_replicability(self):

         dataset = qp.datasets.fetch_UCIDataset('yeast')
@@ -25,6 +26,53 @@ class MyTestCase(unittest.TestCase):

         self.assertEqual(str_prev1, str_prev2)  # add assertion here

+    def test_samping_replicability(self):
+        import numpy as np
+
+        def equal_collections(c1, c2, value=True):
+            self.assertEqual(np.all(c1.X == c2.X), value)
+            self.assertEqual(np.all(c1.y == c2.y), value)
+            if value:
+                self.assertEqual(np.all(c1.classes_ == c2.classes_), value)
+
+        X = list(map(str, range(100)))
+        y = np.random.randint(0, 2, 100)
+        data = LabelledCollection(instances=X, labels=y)
+
+        sample1 = data.sampling(50)
+        sample2 = data.sampling(50)
+        equal_collections(sample1, sample2, False)
+
+        sample1 = data.sampling(50, random_state=0)
+        sample2 = data.sampling(50, random_state=0)
+        equal_collections(sample1, sample2, True)
+
+        sample1 = data.sampling(50, *[0.7, 0.3], random_state=0)
+        sample2 = data.sampling(50, *[0.7, 0.3], random_state=0)
+        equal_collections(sample1, sample2, True)
+
+        with qp.util.temp_seed(0):
+            sample1 = data.sampling(50, *[0.7, 0.3])
+        with qp.util.temp_seed(0):
+            sample2 = data.sampling(50, *[0.7, 0.3])
+        equal_collections(sample1, sample2, True)
+
+        sample1 = data.sampling(50, *[0.7, 0.3], random_state=0)
+        sample2 = data.sampling(50, *[0.7, 0.3], random_state=0)
+        equal_collections(sample1, sample2, True)
+
+        sample1_tr, sample1_te = data.split_stratified(train_prop=0.7, random_state=0)
+        sample2_tr, sample2_te = data.split_stratified(train_prop=0.7, random_state=0)
+        equal_collections(sample1_tr, sample2_tr, True)
+        equal_collections(sample1_te, sample2_te, True)
+
+        with qp.util.temp_seed(0):
+            sample1_tr, sample1_te = data.split_stratified(train_prop=0.7)
+        with qp.util.temp_seed(0):
+            sample2_tr, sample2_te = data.split_stratified(train_prop=0.7)
+        equal_collections(sample1_tr, sample2_tr, True)
+        equal_collections(sample1_te, sample2_te, True)
+

 if __name__ == '__main__':
     unittest.main()
@@ -73,14 +73,16 @@ def temp_seed(random_state):

     :param random_state: the seed to set within the "with" context
     """
-    state = np.random.get_state()
-    #save the seed just in case is needed (for instance for setting the seed to child processes)
-    qp.environ['_R_SEED'] = random_state
-    np.random.seed(random_state)
+    if random_state is not None:
+        state = np.random.get_state()
+        # save the seed just in case it is needed (for instance, for setting the seed of child processes)
+        qp.environ['_R_SEED'] = random_state
+        np.random.seed(random_state)
     try:
         yield
     finally:
-        np.random.set_state(state)
+        if random_state is not None:
+            np.random.set_state(state)


 def download_file(url, archive_filename):
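Usage stays the same as before; the only behavioral change is that a None seed now leaves the global numpy state untouched. A sketch:

import numpy as np
import quapy as qp

with qp.util.temp_seed(0):
    a = np.random.rand(3)
with qp.util.temp_seed(0):
    b = np.random.rand(3)
assert np.all(a == b)  # same seed, same draws

with qp.util.temp_seed(None):
    c = np.random.rand(3)  # no seeding; the global state is not saved/restored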