forked from moreo/QuaPy
bugfix in APP
This commit is contained in:
parent
b882c23477
commit
2d12ce12b9
|
@ -46,8 +46,8 @@ with qp.util.temp_seed(0):
|
||||||
|
|
||||||
tinit = time()
|
tinit = time()
|
||||||
|
|
||||||
model = OLD_GridSearchQ(
|
# model = OLD_GridSearchQ(
|
||||||
# model = qp.model_selection.GridSearchQ(
|
model = qp.model_selection.GridSearchQ(
|
||||||
model=model,
|
model=model,
|
||||||
param_grid=param_grid,
|
param_grid=param_grid,
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
|
|
|
@ -257,8 +257,9 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
||||||
"""
|
"""
|
||||||
dimensions = self.data.n_classes
|
dimensions = self.data.n_classes
|
||||||
s = F.prevalence_linspace(self.n_prevalences, repeats=1, smooth_limits_epsilon=self.smooth_limits_epsilon)
|
s = F.prevalence_linspace(self.n_prevalences, repeats=1, smooth_limits_epsilon=self.smooth_limits_epsilon)
|
||||||
|
eps = (s[1]-s[0])/2 # handling floating rounding
|
||||||
s = [s] * (dimensions - 1)
|
s = [s] * (dimensions - 1)
|
||||||
prevs = [p for p in itertools.product(*s, repeat=1) if (sum(p) <= 1.0)]
|
prevs = [p for p in itertools.product(*s, repeat=1) if (sum(p) < (1.+eps))]
|
||||||
prevs = np.asarray(prevs).reshape(len(prevs), -1)
|
prevs = np.asarray(prevs).reshape(len(prevs), -1)
|
||||||
if self.repeats > 1:
|
if self.repeats > 1:
|
||||||
prevs = np.repeat(prevs, self.repeats, axis=0)
|
prevs = np.repeat(prevs, self.repeats, axis=0)
|
||||||
|
|
Loading…
Reference in New Issue