
pytests before release

Alejandro Moreo Fernandez 2024-02-14 12:27:19 +01:00
parent 7705c92c8c
commit 40cb8f78fe
6 changed files with 33 additions and 14 deletions

View File

@@ -118,14 +118,15 @@ def _prevalence_report(true_prevs, estim_prevs, error_metrics: Iterable[Union[st
    assert all(hasattr(e, '__call__') for e in error_funcs), 'invalid error functions'
    error_names = [e.__name__ for e in error_funcs]
-   df = pd.DataFrame(columns=['true-prev', 'estim-prev'] + error_names)
+   row_entries = []
    for true_prev, estim_prev in zip(true_prevs, estim_prevs):
        series = {'true-prev': true_prev, 'estim-prev': estim_prev}
        for error_name, error_metric in zip(error_names, error_funcs):
            score = error_metric(true_prev, estim_prev)
            series[error_name] = score
-       df = df.append(series, ignore_index=True)
+       row_entries.append(series)
+   df = pd.DataFrame.from_records(row_entries)
    return df
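
Note on the hunk above: pandas deprecated DataFrame.append in 1.4 and removed it in 2.0, and appending one row at a time re-copies the whole frame on every call. Accumulating plain dicts and building the frame once is both forward-compatible and linear-time. A minimal self-contained sketch of the pattern, with toy values standing in for QuaPy's real report data:

import pandas as pd

rows = []
for true_prev, estim_prev in [(0.3, 0.28), (0.7, 0.74)]:
    # one dict per evaluated sample; the dict keys become the columns
    rows.append({'true-prev': true_prev,
                 'estim-prev': estim_prev,
                 'mae': abs(true_prev - estim_prev)})

df = pd.DataFrame.from_records(rows)  # single construction, no per-row copying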

View File

@@ -122,7 +122,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
                    raise ValueError(f'proportion {predict_on=} out of range, must be in (0,1)')
                train, val = data.split_stratified(train_prop=(1 - predict_on))
                self.classifier.fit(*train.Xy)
-               predictions = LabelledCollection(self.classify(val.Xtr), val.ytr, classes=data.classes_)
+               predictions = LabelledCollection(self.classify(val.X), val.y, classes=data.classes_)
            else:
                raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
                                 f'the set on which predictions have to be issued must be '
@@ -604,6 +604,13 @@ class EMQ(AggregativeSoftQuantifier):
                raise RuntimeWarning(f'The parameter {self.val_split=} was specified for EMQ, while the parameters '
                                     f'{self.exact_train_prev=} and {self.recalib=}. This has no effect and causes an unnecessary '
                                     f'overload.')
+       else:
+           if self.recalib is not None:
+               print(f'[warning] The parameter {self.recalib=} requires the val_split be different from None. '
+                     f'This parameter will be set to 5. To avoid this warning, set this value to a float value '
+                     f'indicating the proportion of training data to be used as validation, or to an integer '
+                     f'indicating the number of folds for kFCV.')
+               self.val_split = 5

    def classify(self, instances):
        """

View File

@@ -327,7 +327,7 @@ class GridSearchQ(BaseQuantifier):
            if self.raise_errors:
                raise exception
            else:
-               return ConfigStatus(params, status, str(e))
+               return ConfigStatus(params, status)

        try:
            with timeout(self.timeout):
@@ -336,13 +336,13 @@ class GridSearchQ(BaseQuantifier):
                status = ConfigStatus(params, Status.SUCCESS)
        except TimeoutError as e:
-           status = _handle(Status.TIMEOUT, str(e))
+           status = _handle(Status.TIMEOUT, e)
        except ValueError as e:
-           status = _handle(Status.INVALID, str(e))
+           status = _handle(Status.INVALID, e)
        except Exception as e:
-           status = _handle(Status.ERROR, str(e))
+           status = _handle(Status.ERROR, e)

        took = time() - tinit
        return output, status, took
@@ -364,7 +364,7 @@ def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfol
    for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
        quantifier.fit(train)
-       fold_prev = quantifier.quantify(test.Xtr)
+       fold_prev = quantifier.quantify(test.X)
        rel_size = 1. * len(test) / len(data)
        total_prev += fold_prev*rel_size
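
The _handle calls above now pass the exception object itself instead of str(e): when raise_errors=True the handler re-raises it, and raise needs an exception instance, not a string. A self-contained sketch of the pattern, with simplified stand-ins for quapy.model_selection's Status and ConfigStatus:

from dataclasses import dataclass
from enum import Enum

class Status(Enum):            # simplified stand-in
    SUCCESS = 'success'
    TIMEOUT = 'timeout'
    INVALID = 'invalid'
    ERROR = 'error'

@dataclass
class ConfigStatus:            # simplified stand-in (no message field anymore)
    params: dict
    status: Status

def evaluate_config(job, params, raise_errors=False):
    def _handle(status, exception):
        if raise_errors:
            raise exception    # requires the exception object, not str(e)
        return ConfigStatus(params, status)

    try:
        job(params)
        return ConfigStatus(params, Status.SUCCESS)
    except TimeoutError as e:
        return _handle(Status.TIMEOUT, e)
    except ValueError as e:
        return _handle(Status.INVALID, e)
    except Exception as e:
        return _handle(Status.ERROR, e)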

View File

@@ -97,11 +97,22 @@ class ModselTestCase(unittest.TestCase):
        param_grid = {'classifier__C': np.logspace(-3,3,7)}
        app = APP(validation, sample_size=100, random_state=1)
-       q = GridSearchQ(
-           q, param_grid, protocol=app, error='mae', refit=True, timeout=3, n_jobs=-1, verbose=True
+       print('Expecting TimeoutError to be raised')
+       modsel = GridSearchQ(
+           q, param_grid, protocol=app, timeout=3, n_jobs=-1, verbose=True, raise_errors=True
        )
        with self.assertRaises(TimeoutError):
-           q.fit(training)
+           modsel.fit(training)
+
+       print('Expecting ValueError to be raised')
+       modsel = GridSearchQ(
+           q, param_grid, protocol=app, timeout=3, n_jobs=-1, verbose=True, raise_errors=False
+       )
+       with self.assertRaises(ValueError):
+           # this exception is not raised because of the timeout, but because no combination of hyperparams
+           # succeeded (in this case, a ValueError is raised, regardless of "raise_errors")
+           modsel.fit(training)

if __name__ == '__main__':
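
The two assertions pin down the raise_errors semantics: with raise_errors=True the first failing configuration propagates its exception (here the TimeoutError), while with raise_errors=False individual failures are only recorded, and GridSearchQ raises a ValueError at the end solely because no configuration succeeded. A hedged sketch of the tolerant mode in ordinary use; names reuse the test above, and best_params_ / best_score_ are assumed sklearn-style attributes:

# tolerate per-config failures; keep the best configuration that finished
modsel = GridSearchQ(q, param_grid, protocol=app, timeout=3, raise_errors=False)
modsel.fit(training)
print(modsel.best_params_, modsel.best_score_)  # assumed attributes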

View File

@@ -32,8 +32,8 @@ class MyTestCase(unittest.TestCase):
    def test_samping_replicability(self):
        def equal_collections(c1, c2, value=True):
-           self.assertEqual(np.all(c1.Xtr == c2.Xtr), value)
-           self.assertEqual(np.all(c1.ytr == c2.ytr), value)
+           self.assertEqual(np.all(c1.X == c2.X), value)
+           self.assertEqual(np.all(c1.y == c2.y), value)
            if value:
                self.assertEqual(np.all(c1.classes_ == c2.classes_), value)
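
This is the same accessor rename applied in aggregative.py and model_selection.py above: LabelledCollection now exposes its instances and labels as X and y rather than Xtr and ytr. A minimal sketch with toy data:

import numpy as np
from quapy.data import LabelledCollection

data = LabelledCollection(np.random.rand(10, 2), np.array([0, 1] * 5))
X, y = data.X, data.y   # renamed accessors (formerly Xtr, ytr)
X2, y2 = data.Xy        # (X, y) pair, as used in classifier.fit(*train.Xy) above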

View File

@@ -113,7 +113,7 @@ setup(
    python_requires='>=3.8, <4',
-   install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention'],
+   install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo'],

    # List additional groups of dependencies here (e.g. development
    # dependencies). Users will be able to install these using the "extras"
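
The new ucimlrepo requirement is the UCI Machine Learning Repository's own fetching package, presumably backing QuaPy's UCI dataset loaders. A hedged sketch of its basic API; dataset id 53 (Iris) is used purely as an example:

from ucimlrepo import fetch_ucirepo

iris = fetch_ucirepo(id=53)   # fetch a dataset by its UCI id
X = iris.data.features        # pandas DataFrame of covariates
y = iris.data.targets         # pandas DataFrame of labels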