diff --git a/examples/one_vs_all_example.py b/examples/one_vs_all_example.py new file mode 100644 index 0000000..7488199 --- /dev/null +++ b/examples/one_vs_all_example.py @@ -0,0 +1,53 @@ +import quapy as qp +from quapy.method.aggregative import MS2, OneVsAllAggregative, OneVsAllGeneric +from quapy.method.base import getOneVsAll +from quapy.model_selection import GridSearchQ +from quapy.protocol import USimplexPP +from sklearn.linear_model import LogisticRegression +import numpy as np + +""" +In this example, we will create a quantifier for tweet sentiment analysis considering three classes: negative, neutral, +and positive. We will use a one-vs-all approach using a binary quantifier for demonstration purposes. +""" + +qp.environ['SAMPLE_SIZE'] = 100 + +""" +Any binary quantifier can be turned into a single-label quantifier by means of getOneVsAll function. +This function returns an instance of OneVsAll quantifier. Actually, it either returns the subclass OneVsAllGeneric +when the quantifier is an instance of BaseQuantifier, and it returns OneVsAllAggregative when the quantifier is +an instance of AggregativeQuantifier. Although OneVsAllGeneric works in all cases, using OneVsAllAggregative has +some additional advantages (namely, all the advantages that AggregativeQuantifiers enjoy, i.e., faster predictions +during evaluation). +""" +quantifier = getOneVsAll(MS2(LogisticRegression()), parallel_backend="loky") +print(f'the quantifier is an instance of {quantifier.__class__.__name__}') + +# load a ternary dataset +train_modsel, val = qp.datasets.fetch_twitter('hcr', for_model_selection=True, pickle=True).train_test + +""" +model selection: for this example, we are relying on the USimplexPP protocol, i.e., a variant of the +artificial-prevalence protocol that generates random samples (100 in this case) for randomly picked priors +from the unit simplex. The priors are sampled using the Kraemer algorithm. Note this is in contrast to the +standard APP protocol, that instead explores a prefixed grid of prevalence values. +""" +param_grid = { + 'binary_quantifier__classifier__C': np.logspace(-2,2,5), # classifier-dependent hyperparameter + 'binary_quantifier__classifier__class_weight': ['balanced', None] # classifier-dependent hyperparameter +} +print('starting model selection') +gs = GridSearchQ(quantifier, param_grid, protocol=USimplexPP(val), n_jobs=-1, verbose=True, refit=False) +quantifier = gs.fit(train_modsel).best_model() + +print('training on the whole training set') +train, test = qp.datasets.fetch_twitter('hcr', for_model_selection=False, pickle=True).train_test +quantifier.fit(train) + +# evaluation +mae = qp.evaluation.evaluate(quantifier, protocol=USimplexPP(test), error_metric='mae') + +print(f'MAE = {mae:.4f}') + + diff --git a/quapy/CHANGE_LOG.txt b/quapy/CHANGE_LOG.txt index 3fb21f6..98c939f 100644 --- a/quapy/CHANGE_LOG.txt +++ b/quapy/CHANGE_LOG.txt @@ -83,7 +83,6 @@ Things to fix: - update unit tests - update Wikis... - Resolve the OneVsAll thing (it is in base.py and in aggregative.py) -- Add a proper log? - improve plots - documentation of protocols is incomplete diff --git a/quapy/classification/calibration.py b/quapy/classification/calibration.py index f35bb97..a3f1543 100644 --- a/quapy/classification/calibration.py +++ b/quapy/classification/calibration.py @@ -65,7 +65,7 @@ class RecalibratedProbabilisticClassifierBase(BaseEstimator, RecalibratedProbabi """ Fits the calibration in a cross-validation manner, i.e., it generates posterior probabilities for all training instances via cross-validation, and then retrains the classifier on all training instances. - The posterior probabilities thus generated are used for calibrating the outpus of the classifier. + The posterior probabilities thus generated are used for calibrating the outputs of the classifier. :param X: array-like of shape `(n_samples, n_features)` with the data instances :param y: array-like of shape `(n_samples,)` with the class labels diff --git a/quapy/data/base.py b/quapy/data/base.py index 62f871d..7093821 100644 --- a/quapy/data/base.py +++ b/quapy/data/base.py @@ -6,21 +6,22 @@ from scipy.sparse import vstack from sklearn.model_selection import train_test_split, RepeatedStratifiedKFold from numpy.random import RandomState from quapy.functional import strprev +from quapy.util import temp_seed class LabelledCollection: """ - A LabelledCollection is a set of objects each with a label associated to it. This class implements many sampling - routines. - + A LabelledCollection is a set of objects each with a label attached to each of them. + This class implements several sampling routines and other utilities. + :param instances: array-like (np.ndarray, list, or csr_matrix are supported) :param labels: array-like with the same length of instances - :param classes_: optional, list of classes from which labels are taken. If not specified, the classes are inferred + :param classes: optional, list of classes from which labels are taken. If not specified, the classes are inferred from the labels. The classes must be indicated in cases in which some of the labels might have no examples (i.e., a prevalence of 0) """ - def __init__(self, instances, labels, classes_=None): + def __init__(self, instances, labels, classes=None): if issparse(instances): self.instances = instances elif isinstance(instances, list) and len(instances) > 0 and isinstance(instances[0], str): @@ -30,14 +31,14 @@ class LabelledCollection: self.instances = np.asarray(instances) self.labels = np.asarray(labels) n_docs = len(self) - if classes_ is None: + if classes is None: self.classes_ = np.unique(self.labels) self.classes_.sort() else: - self.classes_ = np.unique(np.asarray(classes_)) + self.classes_ = np.unique(np.asarray(classes)) self.classes_.sort() - if len(set(self.labels).difference(set(classes_))) > 0: - raise ValueError(f'labels ({set(self.labels)}) contain values not included in classes_ ({set(classes_)})') + if len(set(self.labels).difference(set(classes))) > 0: + raise ValueError(f'labels ({set(self.labels)}) contain values not included in classes_ ({set(classes)})') self.index = {class_: np.arange(n_docs)[self.labels == class_] for class_ in self.classes_} @classmethod @@ -101,7 +102,7 @@ class LabelledCollection: """ return self.n_classes == 2 - def sampling_index(self, size, *prevs, shuffle=True): + def sampling_index(self, size, *prevs, shuffle=True, random_state=None): """ Returns an index to be used to extract a random sample of desired size and desired prevalence values. If the prevalence values are not specified, then returns the index of a uniform sampling. @@ -113,10 +114,11 @@ class LabelledCollection: it is constrained. E.g., for binary collections, only the prevalence `p` for the first class (as listed in `self.classes_` can be specified, while the other class takes prevalence value `1-p` :param shuffle: if set to True (default), shuffles the index before returning it + :param random_state: seed for reproducing sampling :return: a np.ndarray of shape `(size)` with the indexes """ if len(prevs) == 0: # no prevalence was indicated; returns an index for uniform sampling - return self.uniform_sampling_index(size) + return self.uniform_sampling_index(size, random_state=random_state) if len(prevs) == self.n_classes - 1: prevs = prevs + (1 - sum(prevs),) assert len(prevs) == self.n_classes, 'unexpected number of prevalences' @@ -129,22 +131,23 @@ class LabelledCollection: # (This aims at avoiding the remainder to be placed in a class for which the prevalence requested is 0.) n_requests = {class_: int(size * prevs[i]) for i, class_ in enumerate(self.classes_)} remainder = size - sum(n_requests.values()) - for rand_class in np.random.choice(self.classes_, size=remainder, p=prevs): - n_requests[rand_class] += 1 + with temp_seed(random_state): + for rand_class in np.random.choice(self.classes_, size=remainder, p=prevs): + n_requests[rand_class] += 1 - indexes_sample = [] - for class_, n_requested in n_requests.items(): - n_candidates = len(self.index[class_]) - index_sample = self.index[class_][ - np.random.choice(n_candidates, size=n_requested, replace=(n_requested > n_candidates)) - ] if n_requested > 0 else [] + indexes_sample = [] + for class_, n_requested in n_requests.items(): + n_candidates = len(self.index[class_]) + index_sample = self.index[class_][ + np.random.choice(n_candidates, size=n_requested, replace=(n_requested > n_candidates)) + ] if n_requested > 0 else [] - indexes_sample.append(index_sample) + indexes_sample.append(index_sample) - indexes_sample = np.concatenate(indexes_sample).astype(int) + indexes_sample = np.concatenate(indexes_sample).astype(int) - if shuffle: - indexes_sample = np.random.permutation(indexes_sample) + if shuffle: + indexes_sample = np.random.permutation(indexes_sample) return indexes_sample @@ -164,7 +167,7 @@ class LabelledCollection: ng = np.random return ng.choice(len(self), size, replace=size > len(self)) - def sampling(self, size, *prevs, shuffle=True): + def sampling(self, size, *prevs, shuffle=True, random_state=None): """ Return a random sample (an instance of :class:`LabelledCollection`) of desired size and desired prevalence values. For each class, the sampling is drawn without replacement if the requested prevalence is larger than @@ -175,10 +178,11 @@ class LabelledCollection: it is constrained. E.g., for binary collections, only the prevalence `p` for the first class (as listed in `self.classes_` can be specified, while the other class takes prevalence value `1-p` :param shuffle: if set to True (default), shuffles the index before returning it + :param random_state: seed for reproducing sampling :return: an instance of :class:`LabelledCollection` with length == `size` and prevalence close to `prevs` (or prevalence == `prevs` if the exact prevalence values can be met as proportions of instances) """ - prev_index = self.sampling_index(size, *prevs, shuffle=shuffle) + prev_index = self.sampling_index(size, *prevs, shuffle=shuffle, random_state=random_state) return self.sampling_from_index(prev_index) def uniform_sampling(self, size, random_state=None): @@ -204,7 +208,7 @@ class LabelledCollection: """ documents = self.instances[index] labels = self.labels[index] - return LabelledCollection(documents, labels, classes_=self.classes_) + return LabelledCollection(documents, labels, classes=self.classes_) def split_stratified(self, train_prop=0.6, random_state=None): """ @@ -221,11 +225,10 @@ class LabelledCollection: tr_docs, te_docs, tr_labels, te_labels = train_test_split( self.instances, self.labels, train_size=train_prop, stratify=self.labels, random_state=random_state ) - training = LabelledCollection(tr_docs, tr_labels, classes_=self.classes_) - test = LabelledCollection(te_docs, te_labels, classes_=self.classes_) + training = LabelledCollection(tr_docs, tr_labels, classes=self.classes_) + test = LabelledCollection(te_docs, te_labels, classes=self.classes_) return training, test - def split_random(self, train_prop=0.6, random_state=None): """ Returns two instances of :class:`LabelledCollection` split randomly from this collection, at desired @@ -261,20 +264,33 @@ class LabelledCollection: :return: a :class:`LabelledCollection` representing the union of both collections """ if not all(np.sort(self.classes_)==np.sort(other.classes_)): - raise NotImplementedError('unsupported operation for collections on different classes') + raise NotImplementedError(f'unsupported operation for collections on different classes; ' + f'expected {self.classes_}, found {other.classes_}') + return LabelledCollection.mix(self, other) - if other is None: - return self - elif issparse(self.instances) and issparse(other.instances): - join_instances = vstack([self.instances, other.instances]) - elif isinstance(self.instances, list) and isinstance(other.instances, list): - join_instances = self.instances + other.instances - elif isinstance(self.instances, np.ndarray) and isinstance(other.instances, np.ndarray): - join_instances = np.concatenate([self.instances, other.instances]) + @classmethod + def mix(cls, a:'LabelledCollection', b:'LabelledCollection'): + """ + Returns a new :class:`LabelledCollection` as the union of this collection with another collection. + + :param a: instance of :class:`LabelledCollection` + :param b: instance of :class:`LabelledCollection` + :return: a :class:`LabelledCollection` representing the union of both collections + """ + if a is None: return b + if b is None: return a + elif issparse(a.instances) and issparse(b.instances): + join_instances = vstack([a.instances, b.instances]) + elif isinstance(a.instances, list) and isinstance(b.instances, list): + join_instances = a.instances + b.instances + elif isinstance(a.instances, np.ndarray) and isinstance(b.instances, np.ndarray): + join_instances = np.concatenate([a.instances, b.instances]) else: raise NotImplementedError('unsupported operation for collection types') - labels = np.concatenate([self.labels, other.labels]) - return LabelledCollection(join_instances, labels, classes_=self.classes_) + labels = np.concatenate([a.labels, b.labels]) + classes = np.unique(np.concatenate([a.classes_, b.classes_])).sort() + return LabelledCollection(join_instances, labels, classes=classes) + @property def Xy(self): @@ -291,7 +307,7 @@ class LabelledCollection: def Xp(self): """ Gets the instances and the true prevalence. This is useful when implementing evaluation protocols from - a `LabelledCollection` object. + a :class:`LabelledCollection` object. :return: a tuple `(instances, prevalence)` from this collection """ @@ -357,7 +373,7 @@ class LabelledCollection: f'#classes={stats_["classes"]}, prevs={stats_["prevs"]}') return stats_ - def kFCV(self, nfolds=5, nrepeats=1, random_state=0): + def kFCV(self, nfolds=5, nrepeats=1, random_state=None): """ Generator of stratified folds to be used in k-fold cross validation. diff --git a/quapy/data/datasets.py b/quapy/data/datasets.py index 241cd04..5c5eb99 100644 --- a/quapy/data/datasets.py +++ b/quapy/data/datasets.py @@ -6,6 +6,7 @@ import os import zipfile from os.path import join import pandas as pd +import scipy from quapy.data.base import Dataset, LabelledCollection from quapy.data.preprocessing import text2tfidf, reduce_columns diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index a9a93cb..e07f665 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -14,7 +14,7 @@ import quapy.functional as F from classification.calibration import NBVSCalibration, BCTSCalibration, TSCalibration, VSCalibration from quapy.classification.svmperf import SVMperf from quapy.data import LabelledCollection -from quapy.method.base import BaseQuantifier, BinaryQuantifier +from quapy.method.base import BaseQuantifier, BinaryQuantifier, OneVsAllGeneric # Abstract classes @@ -1246,7 +1246,7 @@ MedianSweep = MS MedianSweep2 = MS2 -class OneVsAll(AggregativeQuantifier): +class OneVsAllAggregative(OneVsAllGeneric, AggregativeQuantifier): """ Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary quantifier for each class, and then l1-normalizes the outputs so that the @@ -1257,25 +1257,19 @@ class OneVsAll(AggregativeQuantifier): :param binary_quantifier: a quantifier (binary) that will be employed to work on multiclass model in a one-vs-all manner :param n_jobs: number of parallel workers + :param parallel_backend: the parallel backend for joblib (default "loky"); this is helpful for some quantifiers + (e.g., ELM-based ones) that cannot be run with multiprocessing, since the temp dir they create during fit will + is removed and no longer available at predict time. """ - def __init__(self, binary_quantifier, n_jobs=None): - assert isinstance(self.binary_quantifier, BaseQuantifier), \ + def __init__(self, binary_quantifier, n_jobs=None, parallel_backend='loky'): + assert isinstance(binary_quantifier, BaseQuantifier), \ f'{self.binary_quantifier} does not seem to be a Quantifier' - assert isinstance(self.binary_quantifier, AggregativeQuantifier), \ + assert isinstance(binary_quantifier, AggregativeQuantifier), \ f'{self.binary_quantifier} does not seem to be of type Aggregative' self.binary_quantifier = binary_quantifier self.n_jobs = qp._get_njobs(n_jobs) - - def fit(self, data: LabelledCollection, fit_classifier=True): - assert not data.binary, \ - f'{self.__class__.__name__} expect non-binary data' - assert fit_classifier == True, \ - 'fit_classifier must be True' - - self.dict_binary_quantifiers = {c: deepcopy(self.binary_quantifier) for c in data.classes_} - self.__parallel(self._delayed_binary_fit, data) - return self + self.parallel_backend = parallel_backend def classify(self, instances): """ @@ -1292,35 +1286,16 @@ class OneVsAll(AggregativeQuantifier): :return: `np.ndarray` """ - classif_predictions = self.__parallel(self._delayed_binary_classification, instances) + classif_predictions = self._parallel(self._delayed_binary_classification, instances) if isinstance(self.binary_quantifier, AggregativeProbabilisticQuantifier): return np.swapaxes(classif_predictions, 0, 1) else: return classif_predictions.T def aggregate(self, classif_predictions): - prevalences = self.__parallel(self._delayed_binary_aggregate, classif_predictions) + prevalences = self._parallel(self._delayed_binary_aggregate, classif_predictions) return F.normalize_prevalence(prevalences) - def __parallel(self, func, *args, **kwargs): - return np.asarray( - # some quantifiers (in particular, ELM-based ones) cannot be run with multiprocess, since the temp dir they - # create during the fit will be removed and be no longer available for the predict... - Parallel(n_jobs=self.n_jobs, backend='threading')( - delayed(func)(c, *args, **kwargs) for c in self.classes_ - ) - ) - - @property - def classes_(self): - return sorted(self.dict_binary_quantifiers.keys()) - - def set_params(self, **parameters): - self.binary_quantifier.set_params(**parameters) - - def get_params(self, deep=True): - return self.binary_quantifier.get_params() - def _delayed_binary_classification(self, c, X): return self.dict_binary_quantifiers[c].classify(X) @@ -1328,7 +1303,3 @@ class OneVsAll(AggregativeQuantifier): # the estimation for the positive class prevalence return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:, c])[1] - def _delayed_binary_fit(self, c, data): - bindata = LabelledCollection(data.instances, data.labels == c, classes_=[False, True]) - self.dict_binary_quantifiers[c].fit(bindata) - diff --git a/quapy/method/base.py b/quapy/method/base.py index a80f7b7..1803085 100644 --- a/quapy/method/base.py +++ b/quapy/method/base.py @@ -1,10 +1,12 @@ from abc import ABCMeta, abstractmethod from copy import deepcopy +from joblib import Parallel, delayed from sklearn.base import BaseEstimator import quapy as qp from quapy.data import LabelledCollection +import numpy as np # Base Quantifier abstract class @@ -48,50 +50,61 @@ class BinaryQuantifier(BaseQuantifier): f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.' -class OneVsAllGeneric: +class OneVsAll: + pass + + +def getOneVsAll(binary_quantifier, n_jobs=None, parallel_backend='loky'): + assert isinstance(binary_quantifier, BaseQuantifier), \ + f'{binary_quantifier} does not seem to be a Quantifier' + if isinstance(binary_quantifier, qp.method.aggregative.AggregativeQuantifier): + return qp.method.aggregative.OneVsAllAggregative(binary_quantifier, n_jobs, parallel_backend) + else: + return OneVsAllGeneric(binary_quantifier, n_jobs, parallel_backend) + + +class OneVsAllGeneric(OneVsAll,BaseQuantifier): """ Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary quantifier for each class, and then l1-normalizes the outputs so that the class prevelence values sum up to 1. """ - def __init__(self, binary_quantifier, n_jobs=None): + def __init__(self, binary_quantifier, n_jobs=None, parallel_backend='loky'): assert isinstance(binary_quantifier, BaseQuantifier), \ f'{binary_quantifier} does not seem to be a Quantifier' + if isinstance(binary_quantifier, qp.method.aggregative.AggregativeQuantifier): + print('[warning] the quantifier seems to be an instance of qp.method.aggregative.AggregativeQuantifier; ' + f'you might prefer instantiating {qp.method.aggregative.OneVsAllAggregative.__name__}') self.binary_quantifier = binary_quantifier self.n_jobs = qp._get_njobs(n_jobs) + self.parallel_backend = parallel_backend - def fit(self, data: LabelledCollection, **kwargs): - assert not data.binary, \ - f'{self.__class__.__name__} expect non-binary data' - self.class_quatifier = {c: deepcopy(self.binary_quantifier) for c in data.classes_} - Parallel(n_jobs=self.n_jobs, backend='threading')( - delayed(self._delayed_binary_fit)(c, self.class_quatifier, data, **kwargs) for c in data.classes_ - ) + def fit(self, data: LabelledCollection, fit_classifier=True): + assert not data.binary, f'{self.__class__.__name__} expect non-binary data' + assert fit_classifier == True, 'fit_classifier must be True' + + self.dict_binary_quantifiers = {c: deepcopy(self.binary_quantifier) for c in data.classes_} + self._parallel(self._delayed_binary_fit, data) return self - def quantify(self, X, *args): - prevalences = np.asarray( - Parallel(n_jobs=self.n_jobs, backend='threading')( - delayed(self._delayed_binary_predict)(c, self.class_quatifier, X) for c in self.classes + def _parallel(self, func, *args, **kwargs): + return np.asarray( + Parallel(n_jobs=self.n_jobs, backend=self.parallel_backend)( + delayed(func)(c, *args, **kwargs) for c in self.classes_ ) ) - return F.normalize_prevalence(prevalences) + + def quantify(self, instances): + prevalences = self._parallel(self._delayed_binary_predict, instances) + return qp.functional.normalize_prevalence(prevalences) @property - def classes(self): - return sorted(self.class_quatifier.keys()) - - def set_params(self, **parameters): - self.binary_quantifier.set_params(**parameters) - - def get_params(self, deep=True): - return self.binary_quantifier.get_params() - - def _delayed_binary_predict(self, c, quantifiers, X): - return quantifiers[c].quantify(X)[:, 1] # the mean is the estimation for the positive class prevalence - - def _delayed_binary_fit(self, c, quantifiers, data, **kwargs): - bindata = LabelledCollection(data.instances, data.labels == c, n_classes=2) - quantifiers[c].fit(bindata, **kwargs) + def classes_(self): + return sorted(self.dict_binary_quantifiers.keys()) + def _delayed_binary_predict(self, c, X): + return self.dict_binary_quantifiers[c].quantify(X)[1] + def _delayed_binary_fit(self, c, data): + bindata = LabelledCollection(data.instances, data.labels == c, classes=[False, True]) + self.dict_binary_quantifiers[c].fit(bindata) diff --git a/quapy/tests/test_protocols.py b/quapy/tests/test_protocols.py index 1510fee..e5d446e 100644 --- a/quapy/tests/test_protocols.py +++ b/quapy/tests/test_protocols.py @@ -7,7 +7,7 @@ from quapy.protocol import APP, NPP, USimplexPP, DomainMixer, AbstractStochastic def mock_labelled_collection(prefix=''): y = [0] * 250 + [1] * 250 + [2] * 250 + [3] * 250 X = [prefix + str(i) + '-' + str(yi) for i, yi in enumerate(y)] - return LabelledCollection(X, y, classes_=sorted(np.unique(y))) + return LabelledCollection(X, y, classes=sorted(np.unique(y))) def samples_to_str(protocol): diff --git a/quapy/tests/test_replicability.py b/quapy/tests/test_replicability.py index 329ac32..e89531a 100644 --- a/quapy/tests/test_replicability.py +++ b/quapy/tests/test_replicability.py @@ -1,13 +1,14 @@ import unittest import quapy as qp +from quapy.data import LabelledCollection from quapy.functional import strprev from sklearn.linear_model import LogisticRegression -from method.aggregative import PACC +from quapy.method.aggregative import PACC class MyTestCase(unittest.TestCase): - def test_replicability(self): + def test_prediction_replicability(self): dataset = qp.datasets.fetch_UCIDataset('yeast') @@ -25,6 +26,53 @@ class MyTestCase(unittest.TestCase): self.assertEqual(str_prev1, str_prev2) # add assertion here + def test_samping_replicability(self): + import numpy as np + + def equal_collections(c1, c2, value=True): + self.assertEqual(np.all(c1.X == c2.X), value) + self.assertEqual(np.all(c1.y == c2.y), value) + if value: + self.assertEqual(np.all(c1.classes_ == c2.classes_), value) + + X = list(map(str, range(100))) + y = np.random.randint(0, 2, 100) + data = LabelledCollection(instances=X, labels=y) + + sample1 = data.sampling(50) + sample2 = data.sampling(50) + equal_collections(sample1, sample2, False) + + sample1 = data.sampling(50, random_state=0) + sample2 = data.sampling(50, random_state=0) + equal_collections(sample1, sample2, True) + + sample1 = data.sampling(50, *[0.7, 0.3], random_state=0) + sample2 = data.sampling(50, *[0.7, 0.3], random_state=0) + equal_collections(sample1, sample2, True) + + with qp.util.temp_seed(0): + sample1 = data.sampling(50, *[0.7, 0.3]) + with qp.util.temp_seed(0): + sample2 = data.sampling(50, *[0.7, 0.3]) + equal_collections(sample1, sample2, True) + + sample1 = data.sampling(50, *[0.7, 0.3], random_state=0) + sample2 = data.sampling(50, *[0.7, 0.3], random_state=0) + equal_collections(sample1, sample2, True) + + sample1_tr, sample1_te = data.split_stratified(train_prop=0.7, random_state=0) + sample2_tr, sample2_te = data.split_stratified(train_prop=0.7, random_state=0) + equal_collections(sample1_tr, sample2_tr, True) + equal_collections(sample1_te, sample2_te, True) + + with qp.util.temp_seed(0): + sample1_tr, sample1_te = data.split_stratified(train_prop=0.7) + with qp.util.temp_seed(0): + sample2_tr, sample2_te = data.split_stratified(train_prop=0.7) + equal_collections(sample1_tr, sample2_tr, True) + equal_collections(sample1_te, sample2_te, True) + if __name__ == '__main__': unittest.main() diff --git a/quapy/util.py b/quapy/util.py index 6f8543d..298f02a 100644 --- a/quapy/util.py +++ b/quapy/util.py @@ -73,14 +73,16 @@ def temp_seed(random_state): :param random_state: the seed to set within the "with" context """ - state = np.random.get_state() - #save the seed just in case is needed (for instance for setting the seed to child processes) - qp.environ['_R_SEED'] = random_state - np.random.seed(random_state) + if random_state is not None: + state = np.random.get_state() + #save the seed just in case is needed (for instance for setting the seed to child processes) + qp.environ['_R_SEED'] = random_state + np.random.seed(random_state) try: yield finally: - np.random.set_state(state) + if random_state is not None: + np.random.set_state(state) def download_file(url, archive_filename):