diff --git a/quacc/main.py b/quacc/main.py index 83fd6ee..557e093 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -87,8 +87,12 @@ def extend_and_quantify( pred_prob_test = model.predict_proba(test.X) _test = extend_collection(test, pred_prob_test) _estim_prev = q_model.quantify(_test.instances) - # TODO: check that _estim_prev has all the classes and eventually fill the - # missing ones with 0 + # check that _estim_prev has all the classes and eventually fill the missing + # ones with 0 + for _cls in _test.classes_: + if _cls not in q_model.classes_: + _estim_prev = np.insert(_estim_prev, _cls, [0.0], axis=0) + print(_estim_prev) return _test.prevalence(), _estim_prev if isinstance(test, LabelledCollection):