pytests before release
This commit is contained in:
parent
7705c92c8c
commit
40cb8f78fe
|
@ -118,14 +118,15 @@ def _prevalence_report(true_prevs, estim_prevs, error_metrics: Iterable[Union[st
|
|||
assert all(hasattr(e, '__call__') for e in error_funcs), 'invalid error functions'
|
||||
error_names = [e.__name__ for e in error_funcs]
|
||||
|
||||
df = pd.DataFrame(columns=['true-prev', 'estim-prev'] + error_names)
|
||||
row_entries = []
|
||||
for true_prev, estim_prev in zip(true_prevs, estim_prevs):
|
||||
series = {'true-prev': true_prev, 'estim-prev': estim_prev}
|
||||
for error_name, error_metric in zip(error_names, error_funcs):
|
||||
score = error_metric(true_prev, estim_prev)
|
||||
series[error_name] = score
|
||||
df = df.append(series, ignore_index=True)
|
||||
row_entries.append(series)
|
||||
|
||||
df = pd.DataFrame.from_records(row_entries)
|
||||
return df
|
||||
|
||||
|
||||
|
|
|
@ -122,7 +122,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
|||
raise ValueError(f'proportion {predict_on=} out of range, must be in (0,1)')
|
||||
train, val = data.split_stratified(train_prop=(1 - predict_on))
|
||||
self.classifier.fit(*train.Xy)
|
||||
predictions = LabelledCollection(self.classify(val.Xtr), val.ytr, classes=data.classes_)
|
||||
predictions = LabelledCollection(self.classify(val.X), val.y, classes=data.classes_)
|
||||
else:
|
||||
raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
|
||||
f'the set on which predictions have to be issued must be '
|
||||
|
@ -604,6 +604,13 @@ class EMQ(AggregativeSoftQuantifier):
|
|||
raise RuntimeWarning(f'The parameter {self.val_split=} was specified for EMQ, while the parameters '
|
||||
f'{self.exact_train_prev=} and {self.recalib=}. This has no effect and causes an unnecessary '
|
||||
f'overload.')
|
||||
else:
|
||||
if self.recalib is not None:
|
||||
print(f'[warning] The parameter {self.recalib=} requires the val_split be different from None. '
|
||||
f'This parameter will be set to 5. To avoid this warning, set this value to a float value '
|
||||
f'indicating the proportion of training data to be used as validation, or to an integer '
|
||||
f'indicating the number of folds for kFCV.')
|
||||
self.val_split=5
|
||||
|
||||
def classify(self, instances):
|
||||
"""
|
||||
|
|
|
@ -327,7 +327,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
if self.raise_errors:
|
||||
raise exception
|
||||
else:
|
||||
return ConfigStatus(params, status, str(e))
|
||||
return ConfigStatus(params, status)
|
||||
|
||||
try:
|
||||
with timeout(self.timeout):
|
||||
|
@ -336,13 +336,13 @@ class GridSearchQ(BaseQuantifier):
|
|||
status = ConfigStatus(params, Status.SUCCESS)
|
||||
|
||||
except TimeoutError as e:
|
||||
status = _handle(Status.TIMEOUT, str(e))
|
||||
status = _handle(Status.TIMEOUT, e)
|
||||
|
||||
except ValueError as e:
|
||||
status = _handle(Status.INVALID, str(e))
|
||||
status = _handle(Status.INVALID, e)
|
||||
|
||||
except Exception as e:
|
||||
status = _handle(Status.ERROR, str(e))
|
||||
status = _handle(Status.ERROR, e)
|
||||
|
||||
took = time() - tinit
|
||||
return output, status, took
|
||||
|
@ -364,7 +364,7 @@ def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfol
|
|||
|
||||
for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
|
||||
quantifier.fit(train)
|
||||
fold_prev = quantifier.quantify(test.Xtr)
|
||||
fold_prev = quantifier.quantify(test.X)
|
||||
rel_size = 1. * len(test) / len(data)
|
||||
total_prev += fold_prev*rel_size
|
||||
|
||||
|
|
|
@ -97,11 +97,22 @@ class ModselTestCase(unittest.TestCase):
|
|||
|
||||
param_grid = {'classifier__C': np.logspace(-3,3,7)}
|
||||
app = APP(validation, sample_size=100, random_state=1)
|
||||
q = GridSearchQ(
|
||||
q, param_grid, protocol=app, error='mae', refit=True, timeout=3, n_jobs=-1, verbose=True
|
||||
|
||||
print('Expecting TimeoutError to be raised')
|
||||
modsel = GridSearchQ(
|
||||
q, param_grid, protocol=app, timeout=3, n_jobs=-1, verbose=True, raise_errors=True
|
||||
)
|
||||
with self.assertRaises(TimeoutError):
|
||||
q.fit(training)
|
||||
modsel.fit(training)
|
||||
|
||||
print('Expecting ValueError to be raised')
|
||||
modsel = GridSearchQ(
|
||||
q, param_grid, protocol=app, timeout=3, n_jobs=-1, verbose=True, raise_errors=False
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
# this exception is not raised because of the timeout, but because no combination of hyperparams
|
||||
# succedded (in this case, a ValueError is raised, regardless of "raise_errors"
|
||||
modsel.fit(training)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -32,8 +32,8 @@ class MyTestCase(unittest.TestCase):
|
|||
def test_samping_replicability(self):
|
||||
|
||||
def equal_collections(c1, c2, value=True):
|
||||
self.assertEqual(np.all(c1.Xtr == c2.Xtr), value)
|
||||
self.assertEqual(np.all(c1.ytr == c2.ytr), value)
|
||||
self.assertEqual(np.all(c1.X == c2.X), value)
|
||||
self.assertEqual(np.all(c1.y == c2.y), value)
|
||||
if value:
|
||||
self.assertEqual(np.all(c1.classes_ == c2.classes_), value)
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -113,7 +113,7 @@ setup(
|
|||
|
||||
python_requires='>=3.8, <4',
|
||||
|
||||
install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention'],
|
||||
install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo'],
|
||||
|
||||
# List additional groups of dependencies here (e.g. development
|
||||
# dependencies). Users will be able to install these using the "extras"
|
||||
|
|
Loading…
Reference in New Issue