forked from moreo/QuaPy
evaluation updated
This commit is contained in:
parent
c608647475
commit
25a829996e
|
@ -2,7 +2,7 @@ import quapy as qp
|
||||||
from quapy.method.aggregative import newELM
|
from quapy.method.aggregative import newELM
|
||||||
from quapy.method.base import newOneVsAll
|
from quapy.method.base import newOneVsAll
|
||||||
from quapy.model_selection import GridSearchQ
|
from quapy.model_selection import GridSearchQ
|
||||||
from quapy.protocol import USimplexPP
|
from quapy.protocol import UPP
|
||||||
|
|
||||||
"""
|
"""
|
||||||
In this example, we will show hoy to define a quantifier based on explicit loss minimization (ELM).
|
In this example, we will show hoy to define a quantifier based on explicit loss minimization (ELM).
|
||||||
|
@ -57,7 +57,7 @@ param_grid = {
|
||||||
'binary_quantifier__classifier__C': [0.01, 1, 100], # classifier-dependent hyperparameter
|
'binary_quantifier__classifier__C': [0.01, 1, 100], # classifier-dependent hyperparameter
|
||||||
}
|
}
|
||||||
print('starting model selection')
|
print('starting model selection')
|
||||||
model_selection = GridSearchQ(quantifier, param_grid, protocol=USimplexPP(val), verbose=True, refit=False)
|
model_selection = GridSearchQ(quantifier, param_grid, protocol=UPP(val), verbose=True, refit=False)
|
||||||
quantifier = model_selection.fit(train_modsel).best_model()
|
quantifier = model_selection.fit(train_modsel).best_model()
|
||||||
|
|
||||||
print('training on the whole training set')
|
print('training on the whole training set')
|
||||||
|
@ -65,7 +65,7 @@ train, test = qp.datasets.fetch_twitter('hcr', for_model_selection=False, pickle
|
||||||
quantifier.fit(train)
|
quantifier.fit(train)
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
mae = qp.evaluation.evaluate(quantifier, protocol=USimplexPP(test), error_metric='mae')
|
mae = qp.evaluation.evaluate(quantifier, protocol=UPP(test), error_metric='mae')
|
||||||
|
|
||||||
print(f'MAE = {mae:.4f}')
|
print(f'MAE = {mae:.4f}')
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import quapy as qp
|
||||||
from quapy.method.aggregative import MS2
|
from quapy.method.aggregative import MS2
|
||||||
from quapy.method.base import newOneVsAll
|
from quapy.method.base import newOneVsAll
|
||||||
from quapy.model_selection import GridSearchQ
|
from quapy.model_selection import GridSearchQ
|
||||||
from quapy.protocol import USimplexPP
|
from quapy.protocol import UPP
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ print(f'the quantifier is an instance of {quantifier.__class__.__name__}')
|
||||||
train_modsel, val = qp.datasets.fetch_twitter('hcr', for_model_selection=True, pickle=True).train_test
|
train_modsel, val = qp.datasets.fetch_twitter('hcr', for_model_selection=True, pickle=True).train_test
|
||||||
|
|
||||||
"""
|
"""
|
||||||
model selection: for this example, we are relying on the USimplexPP protocol, i.e., a variant of the
|
model selection: for this example, we are relying on the UPP protocol, i.e., a variant of the
|
||||||
artificial-prevalence protocol that generates random samples (100 in this case) for randomly picked priors
|
artificial-prevalence protocol that generates random samples (100 in this case) for randomly picked priors
|
||||||
from the unit simplex. The priors are sampled using the Kraemer algorithm. Note this is in contrast to the
|
from the unit simplex. The priors are sampled using the Kraemer algorithm. Note this is in contrast to the
|
||||||
standard APP protocol, that instead explores a prefixed grid of prevalence values.
|
standard APP protocol, that instead explores a prefixed grid of prevalence values.
|
||||||
|
@ -39,7 +39,7 @@ param_grid = {
|
||||||
'binary_quantifier__classifier__class_weight': ['balanced', None] # classifier-dependent hyperparameter
|
'binary_quantifier__classifier__class_weight': ['balanced', None] # classifier-dependent hyperparameter
|
||||||
}
|
}
|
||||||
print('starting model selection')
|
print('starting model selection')
|
||||||
model_selection = GridSearchQ(quantifier, param_grid, protocol=USimplexPP(val), verbose=True, refit=False)
|
model_selection = GridSearchQ(quantifier, param_grid, protocol=UPP(val), verbose=True, refit=False)
|
||||||
quantifier = model_selection.fit(train_modsel).best_model()
|
quantifier = model_selection.fit(train_modsel).best_model()
|
||||||
|
|
||||||
print('training on the whole training set')
|
print('training on the whole training set')
|
||||||
|
@ -47,7 +47,7 @@ train, test = qp.datasets.fetch_twitter('hcr', for_model_selection=False, pickle
|
||||||
quantifier.fit(train)
|
quantifier.fit(train)
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
mae = qp.evaluation.evaluate(quantifier, protocol=USimplexPP(test), error_metric='mae')
|
mae = qp.evaluation.evaluate(quantifier, protocol=UPP(test), error_metric='mae')
|
||||||
|
|
||||||
print(f'MAE = {mae:.4f}')
|
print(f'MAE = {mae:.4f}')
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ Change Log 0.1.7
|
||||||
|
|
||||||
- Protocols are now abstracted as instances of AbstractProtocol. There is a new class extending AbstractProtocol called
|
- Protocols are now abstracted as instances of AbstractProtocol. There is a new class extending AbstractProtocol called
|
||||||
AbstractStochasticSeededProtocol, which implements a seeding policy to allow replicate the series of samplings.
|
AbstractStochasticSeededProtocol, which implements a seeding policy to allow replicate the series of samplings.
|
||||||
There are some examples of protocols, APP, NPP, USimplexPP, DomainMixer (experimental).
|
There are some examples of protocols, APP, NPP, UPP, DomainMixer (experimental).
|
||||||
The idea is to start the sampling by simply calling the __call__ method.
|
The idea is to start the sampling by simply calling the __call__ method.
|
||||||
This change has a great impact in the framework, since many functions in qp.evaluation, qp.model_selection,
|
This change has a great impact in the framework, since many functions in qp.evaluation, qp.model_selection,
|
||||||
and sampling functions in LabelledCollection relied of the old functions. E.g., the functionality of
|
and sampling functions in LabelledCollection relied of the old functions. E.g., the functionality of
|
||||||
|
|
|
@ -211,11 +211,13 @@ def __check_eps(eps=None):
|
||||||
|
|
||||||
CLASSIFICATION_ERROR = {f1e, acce}
|
CLASSIFICATION_ERROR = {f1e, acce}
|
||||||
QUANTIFICATION_ERROR = {mae, mrae, mse, mkld, mnkld}
|
QUANTIFICATION_ERROR = {mae, mrae, mse, mkld, mnkld}
|
||||||
|
QUANTIFICATION_ERROR_SINGLE = {ae, rae, se, kld, nkld}
|
||||||
QUANTIFICATION_ERROR_SMOOTH = {kld, nkld, rae, mkld, mnkld, mrae}
|
QUANTIFICATION_ERROR_SMOOTH = {kld, nkld, rae, mkld, mnkld, mrae}
|
||||||
CLASSIFICATION_ERROR_NAMES = {func.__name__ for func in CLASSIFICATION_ERROR}
|
CLASSIFICATION_ERROR_NAMES = {func.__name__ for func in CLASSIFICATION_ERROR}
|
||||||
QUANTIFICATION_ERROR_NAMES = {func.__name__ for func in QUANTIFICATION_ERROR}
|
QUANTIFICATION_ERROR_NAMES = {func.__name__ for func in QUANTIFICATION_ERROR}
|
||||||
|
QUANTIFICATION_ERROR_SINGLE_NAMES = {func.__name__ for func in QUANTIFICATION_ERROR_SINGLE}
|
||||||
QUANTIFICATION_ERROR_SMOOTH_NAMES = {func.__name__ for func in QUANTIFICATION_ERROR_SMOOTH}
|
QUANTIFICATION_ERROR_SMOOTH_NAMES = {func.__name__ for func in QUANTIFICATION_ERROR_SMOOTH}
|
||||||
ERROR_NAMES = CLASSIFICATION_ERROR_NAMES | QUANTIFICATION_ERROR_NAMES
|
ERROR_NAMES = CLASSIFICATION_ERROR_NAMES | QUANTIFICATION_ERROR_NAMES | QUANTIFICATION_ERROR_SINGLE_NAMES
|
||||||
|
|
||||||
f1_error = f1e
|
f1_error = f1e
|
||||||
acc_error = acce
|
acc_error = acce
|
||||||
|
|
|
@ -7,7 +7,34 @@ from quapy.method.base import BaseQuantifier
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
def prediction(model: BaseQuantifier, protocol: AbstractProtocol, aggr_speedup='auto', verbose=False):
|
def prediction(
|
||||||
|
model: BaseQuantifier,
|
||||||
|
protocol: AbstractProtocol,
|
||||||
|
aggr_speedup: Union[str, bool] = 'auto',
|
||||||
|
verbose=False):
|
||||||
|
"""
|
||||||
|
Uses a quantification model to generate predictions for the samples generated via a specific protocol.
|
||||||
|
This function is central to all evaluation processes, and is endowed with an optimization to speed-up the
|
||||||
|
prediction of protocols that generate samples from a large collection. The optimization applies to aggregative
|
||||||
|
quantifiers only, and to OnLabelledCollection protocols, and comes down to generating the classification
|
||||||
|
predictions once and for all, and then generating samples over the classification predictions (instead of over
|
||||||
|
the raw instances), so that the classifier prediction is never called again. This behaviour is obtained by
|
||||||
|
setting `aggr_speedup` to 'auto' or True, and is only carried out if the overall process is convenient in terms
|
||||||
|
of computations (e.g., if the number of classification predictions needed for the original collection exceed the
|
||||||
|
number of classification predictions needed for all samples, then the optimization is not undertaken).
|
||||||
|
|
||||||
|
:param model: a quantifier, instance of :class:`quapy.method.base.BaseQuantifier`
|
||||||
|
:param protocol: :class:`quapy.protocol.AbstractProtocol`; if this object is also instance of
|
||||||
|
:class:`quapy.protocol.OnLabelledCollection`, then the aggregation speed-up can be run. This is the protocol
|
||||||
|
in charge of generating the samples for which the model has to issue class prevalence predictions.
|
||||||
|
:param aggr_speedup: whether or not to apply the speed-up. Set to "force" for applying it even if the number of
|
||||||
|
instances in the original collection on which the protocol acts is larger than the number of instances
|
||||||
|
in the samples to be generated. Set to True or "auto" (default) for letting QuaPy decide whether it is
|
||||||
|
convenient or not. Set to False to deactivate.
|
||||||
|
:param verbose: boolean, show or not information in stdout
|
||||||
|
:return: a tuple `(true_prevs, estim_prevs)` in which each element in the tuple is an array of shape
|
||||||
|
`(n_samples, n_classes)` containing the true, or predicted, prevalence values for each sample
|
||||||
|
"""
|
||||||
assert aggr_speedup in [False, True, 'auto', 'force'], 'invalid value for aggr_speedup'
|
assert aggr_speedup in [False, True, 'auto', 'force'], 'invalid value for aggr_speedup'
|
||||||
|
|
||||||
sout = lambda x: print(x) if verbose else None
|
sout = lambda x: print(x) if verbose else None
|
||||||
|
@ -54,8 +81,29 @@ def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=F
|
||||||
def evaluation_report(model: BaseQuantifier,
|
def evaluation_report(model: BaseQuantifier,
|
||||||
protocol: AbstractProtocol,
|
protocol: AbstractProtocol,
|
||||||
error_metrics: Iterable[Union[str,Callable]] = 'mae',
|
error_metrics: Iterable[Union[str,Callable]] = 'mae',
|
||||||
aggr_speedup='auto',
|
aggr_speedup: Union[str, bool] = 'auto',
|
||||||
verbose=False):
|
verbose=False):
|
||||||
|
"""
|
||||||
|
Generates a report (a pandas' DataFrame) containing information of the evaluation of the model as according
|
||||||
|
to a specific protocol and in terms of one or more evaluation metrics (errors).
|
||||||
|
|
||||||
|
|
||||||
|
:param model: a quantifier, instance of :class:`quapy.method.base.BaseQuantifier`
|
||||||
|
:param protocol: :class:`quapy.protocol.AbstractProtocol`; if this object is also instance of
|
||||||
|
:class:`quapy.protocol.OnLabelledCollection`, then the aggregation speed-up can be run. This is the protocol
|
||||||
|
in charge of generating the samples in which the model is evaluated.
|
||||||
|
:param error_metrics: a string, or list of strings, representing the name(s) of an error function in `qp.error`
|
||||||
|
(e.g., 'mae', the default value), or a callable function, or a list of callable functions, implementing
|
||||||
|
the error function itself.
|
||||||
|
:param aggr_speedup: whether or not to apply the speed-up. Set to "force" for applying it even if the number of
|
||||||
|
instances in the original collection on which the protocol acts is larger than the number of instances
|
||||||
|
in the samples to be generated. Set to True or "auto" (default) for letting QuaPy decide whether it is
|
||||||
|
convenient or not. Set to False to deactivate.
|
||||||
|
:param verbose: boolean, show or not information in stdout
|
||||||
|
:return: a pandas' DataFrame containing the columns 'true-prev' (the true prevalence of each sample),
|
||||||
|
'estim-prev' (the prevalence estimated by the model for each sample), and as many columns as error metrics
|
||||||
|
have been indicated, each displaying the score in terms of that metric for every sample.
|
||||||
|
"""
|
||||||
|
|
||||||
true_prevs, estim_prevs = prediction(model, protocol, aggr_speedup=aggr_speedup, verbose=verbose)
|
true_prevs, estim_prevs = prediction(model, protocol, aggr_speedup=aggr_speedup, verbose=verbose)
|
||||||
return _prevalence_report(true_prevs, estim_prevs, error_metrics)
|
return _prevalence_report(true_prevs, estim_prevs, error_metrics)
|
||||||
|
@ -85,8 +133,27 @@ def evaluate(
|
||||||
model: BaseQuantifier,
|
model: BaseQuantifier,
|
||||||
protocol: AbstractProtocol,
|
protocol: AbstractProtocol,
|
||||||
error_metric: Union[str, Callable],
|
error_metric: Union[str, Callable],
|
||||||
aggr_speedup='auto',
|
aggr_speedup: Union[str, bool] = 'auto',
|
||||||
verbose=False):
|
verbose=False):
|
||||||
|
"""
|
||||||
|
Evaluates a quantification model according to a specific sample generation protocol and in terms of one
|
||||||
|
evaluation metric (error).
|
||||||
|
|
||||||
|
:param model: a quantifier, instance of :class:`quapy.method.base.BaseQuantifier`
|
||||||
|
:param protocol: :class:`quapy.protocol.AbstractProtocol`; if this object is also instance of
|
||||||
|
:class:`quapy.protocol.OnLabelledCollection`, then the aggregation speed-up can be run. This is the protocol
|
||||||
|
in charge of generating the samples in which the model is evaluated.
|
||||||
|
:param error_metric: a string representing the name(s) of an error function in `qp.error`
|
||||||
|
(e.g., 'mae'), or a callable function implementing the error function itself.
|
||||||
|
:param aggr_speedup: whether or not to apply the speed-up. Set to "force" for applying it even if the number of
|
||||||
|
instances in the original collection on which the protocol acts is larger than the number of instances
|
||||||
|
in the samples to be generated. Set to True or "auto" (default) for letting QuaPy decide whether it is
|
||||||
|
convenient or not. Set to False to deactivate.
|
||||||
|
:param verbose: boolean, show or not information in stdout
|
||||||
|
:return: if the error metric is not averaged (e.g., 'ae', 'rae'), returns an array of shape `(n_samples,)` with
|
||||||
|
the error scores for each sample; if the error metric is averaged (e.g., 'mae', 'mrae') then returns
|
||||||
|
a single float
|
||||||
|
"""
|
||||||
|
|
||||||
if isinstance(error_metric, str):
|
if isinstance(error_metric, str):
|
||||||
error_metric = qp.error.from_name(error_metric)
|
error_metric = qp.error.from_name(error_metric)
|
||||||
|
@ -96,9 +163,21 @@ def evaluate(
|
||||||
|
|
||||||
def evaluate_on_samples(
|
def evaluate_on_samples(
|
||||||
model: BaseQuantifier,
|
model: BaseQuantifier,
|
||||||
samples: [qp.data.LabelledCollection],
|
samples: Iterable[qp.data.LabelledCollection],
|
||||||
error_metric: Union[str, Callable],
|
error_metric: Union[str, Callable],
|
||||||
verbose=False):
|
verbose=False):
|
||||||
|
"""
|
||||||
|
Evaluates a quantification model on a given set of samples and in terms of one evaluation metric (error).
|
||||||
|
|
||||||
|
:param model: a quantifier, instance of :class:`quapy.method.base.BaseQuantifier`
|
||||||
|
:param samples: a list of samples on which the quantifier is to be evaluated
|
||||||
|
:param error_metric: a string representing the name(s) of an error function in `qp.error`
|
||||||
|
(e.g., 'mae'), or a callable function implementing the error function itself.
|
||||||
|
:param verbose: boolean, show or not information in stdout
|
||||||
|
:return: if the error metric is not averaged (e.g., 'ae', 'rae'), returns an array of shape `(n_samples,)` with
|
||||||
|
the error scores for each sample; if the error metric is averaged (e.g., 'mae', 'mrae') then returns
|
||||||
|
a single float
|
||||||
|
"""
|
||||||
|
|
||||||
return evaluate(model, IterateProtocol(samples), error_metric, aggr_speedup=False, verbose=verbose)
|
return evaluate(model, IterateProtocol(samples), error_metric, aggr_speedup=False, verbose=verbose)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
from torch.nn import MSELoss
|
from torch.nn import MSELoss
|
||||||
from torch.nn.functional import relu
|
from torch.nn.functional import relu
|
||||||
|
|
||||||
from protocol import USimplexPP
|
from protocol import UPP
|
||||||
from quapy.method.aggregative import *
|
from quapy.method.aggregative import *
|
||||||
from quapy.util import EarlyStop
|
from quapy.util import EarlyStop
|
||||||
|
|
||||||
|
@ -218,7 +218,7 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
self.quanet.train(mode=train)
|
self.quanet.train(mode=train)
|
||||||
losses = []
|
losses = []
|
||||||
mae_errors = []
|
mae_errors = []
|
||||||
sampler = USimplexPP(
|
sampler = UPP(
|
||||||
data,
|
data,
|
||||||
sample_size=self.sample_size,
|
sample_size=self.sample_size,
|
||||||
repeats=iterations,
|
repeats=iterations,
|
||||||
|
|
|
@ -327,7 +327,7 @@ class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
||||||
return self.repeats
|
return self.repeats
|
||||||
|
|
||||||
|
|
||||||
class USimplexPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
class UPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
||||||
"""
|
"""
|
||||||
A variant of :class:`APP` that, instead of using a grid of equidistant prevalence values,
|
A variant of :class:`APP` that, instead of using a grid of equidistant prevalence values,
|
||||||
relies on the Kraemer algorithm for sampling unit (k-1)-simplex uniformly at random, with
|
relies on the Kraemer algorithm for sampling unit (k-1)-simplex uniformly at random, with
|
||||||
|
@ -348,7 +348,7 @@ class USimplexPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol)
|
||||||
|
|
||||||
def __init__(self, data: LabelledCollection, sample_size=None, repeats=100, random_state=0,
|
def __init__(self, data: LabelledCollection, sample_size=None, repeats=100, random_state=0,
|
||||||
return_type='sample_prev'):
|
return_type='sample_prev'):
|
||||||
super(USimplexPP, self).__init__(random_state)
|
super(UPP, self).__init__(random_state)
|
||||||
self.data = data
|
self.data = data
|
||||||
self.sample_size = qp._get_sample_size(sample_size)
|
self.sample_size = qp._get_sample_size(sample_size)
|
||||||
self.repeats = repeats
|
self.repeats = repeats
|
||||||
|
@ -357,9 +357,9 @@ class USimplexPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol)
|
||||||
|
|
||||||
def samples_parameters(self):
|
def samples_parameters(self):
|
||||||
"""
|
"""
|
||||||
Return all the necessary parameters to replicate the samples as according to the USimplexPP protocol.
|
Return all the necessary parameters to replicate the samples as according to the UPP protocol.
|
||||||
|
|
||||||
:return: a list of indexes that realize the USimplexPP sampling
|
:return: a list of indexes that realize the UPP sampling
|
||||||
"""
|
"""
|
||||||
indexes = []
|
indexes = []
|
||||||
for prevs in F.uniform_simplex_sampling(n_classes=self.data.n_classes, size=self.repeats):
|
for prevs in F.uniform_simplex_sampling(n_classes=self.data.n_classes, size=self.repeats):
|
||||||
|
@ -474,3 +474,8 @@ class DomainMixer(AbstractStochasticSeededProtocol):
|
||||||
return self.repeats * len(self.mixture_points)
|
return self.repeats * len(self.mixture_points)
|
||||||
|
|
||||||
|
|
||||||
|
# aliases
|
||||||
|
|
||||||
|
ArtificialPrevalenceProtocol = APP
|
||||||
|
NaturalPrevalenceProtocol = NPP
|
||||||
|
UniformPrevalenceProtocol = UPP
|
|
@ -1,8 +1,14 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
from time import time
|
from time import time
|
||||||
from quapy.method.aggregative import EMQ
|
|
||||||
|
from error import QUANTIFICATION_ERROR_SINGLE, QUANTIFICATION_ERROR, QUANTIFICATION_ERROR_NAMES, \
|
||||||
|
QUANTIFICATION_ERROR_SINGLE_NAMES
|
||||||
|
from quapy.method.aggregative import EMQ, PCC
|
||||||
from quapy.method.base import BaseQuantifier
|
from quapy.method.base import BaseQuantifier
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,6 +54,31 @@ class EvalTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(tend_no_optim>(tend_optim/2), True)
|
self.assertEqual(tend_no_optim>(tend_optim/2), True)
|
||||||
|
|
||||||
|
def test_evaluation_output(self):
|
||||||
|
|
||||||
|
data = qp.datasets.fetch_reviews('hp', tfidf=True, min_df=10, pickle=True)
|
||||||
|
train, test = data.training, data.test
|
||||||
|
|
||||||
|
qp.environ['SAMPLE_SIZE']=100
|
||||||
|
|
||||||
|
protocol = qp.protocol.APP(test, random_state=0)
|
||||||
|
|
||||||
|
q = PCC(LogisticRegression()).fit(train)
|
||||||
|
|
||||||
|
single_errors = list(QUANTIFICATION_ERROR_SINGLE_NAMES)
|
||||||
|
averaged_errors = ['m'+e for e in single_errors]
|
||||||
|
single_errors = single_errors + [qp.error.from_name(e) for e in single_errors]
|
||||||
|
averaged_errors = averaged_errors + [qp.error.from_name(e) for e in averaged_errors]
|
||||||
|
for error_metric, averaged_error_metric in zip(single_errors, averaged_errors):
|
||||||
|
score = qp.evaluation.evaluate(q, protocol, error_metric=averaged_error_metric)
|
||||||
|
self.assertTrue(isinstance(score, float))
|
||||||
|
|
||||||
|
scores = qp.evaluation.evaluate(q, protocol, error_metric=error_metric)
|
||||||
|
self.assertTrue(isinstance(scores, np.ndarray))
|
||||||
|
|
||||||
|
self.assertEqual(scores.mean(), score)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.protocol import APP, NPP, USimplexPP, DomainMixer, AbstractStochasticSeededProtocol
|
from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol
|
||||||
|
|
||||||
|
|
||||||
def mock_labelled_collection(prefix=''):
|
def mock_labelled_collection(prefix=''):
|
||||||
|
@ -102,14 +102,14 @@ class TestProtocols(unittest.TestCase):
|
||||||
|
|
||||||
def test_kraemer_replicate(self):
|
def test_kraemer_replicate(self):
|
||||||
data = mock_labelled_collection()
|
data = mock_labelled_collection()
|
||||||
p = USimplexPP(data, sample_size=5, repeats=10, random_state=42)
|
p = UPP(data, sample_size=5, repeats=10, random_state=42)
|
||||||
|
|
||||||
samples1 = samples_to_str(p)
|
samples1 = samples_to_str(p)
|
||||||
samples2 = samples_to_str(p)
|
samples2 = samples_to_str(p)
|
||||||
|
|
||||||
self.assertEqual(samples1, samples2)
|
self.assertEqual(samples1, samples2)
|
||||||
|
|
||||||
p = USimplexPP(data, sample_size=5, repeats=10) # <- random_state is by default set to 0
|
p = UPP(data, sample_size=5, repeats=10) # <- random_state is by default set to 0
|
||||||
|
|
||||||
samples1 = samples_to_str(p)
|
samples1 = samples_to_str(p)
|
||||||
samples2 = samples_to_str(p)
|
samples2 = samples_to_str(p)
|
||||||
|
@ -118,7 +118,7 @@ class TestProtocols(unittest.TestCase):
|
||||||
|
|
||||||
def test_kraemer_not_replicate(self):
|
def test_kraemer_not_replicate(self):
|
||||||
data = mock_labelled_collection()
|
data = mock_labelled_collection()
|
||||||
p = USimplexPP(data, sample_size=5, repeats=10, random_state=None)
|
p = UPP(data, sample_size=5, repeats=10, random_state=None)
|
||||||
|
|
||||||
samples1 = samples_to_str(p)
|
samples1 = samples_to_str(p)
|
||||||
samples2 = samples_to_str(p)
|
samples2 = samples_to_str(p)
|
||||||
|
|
Loading…
Reference in New Issue