From e9d62f1f2a308e077bb9506cb9f5ae0617b3933b Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Wed, 17 May 2023 14:05:27 +0200 Subject: [PATCH] hp fix --- quacc/main.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/quacc/main.py b/quacc/main.py index 83fd6ee..557e093 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -87,8 +87,12 @@ def extend_and_quantify( pred_prob_test = model.predict_proba(test.X) _test = extend_collection(test, pred_prob_test) _estim_prev = q_model.quantify(_test.instances) - # TODO: check that _estim_prev has all the classes and eventually fill the - # missing ones with 0 + # check that _estim_prev has all the classes and eventually fill the missing + # ones with 0 + for _cls in _test.classes_: + if _cls not in q_model.classes_: + _estim_prev = np.insert(_estim_prev, _cls, [0.0], axis=0) + print(_estim_prev) return _test.prevalence(), _estim_prev if isinstance(test, LabelledCollection):