diff --git a/quacc/evaluation/baseline.py b/quacc/evaluation/baseline.py index 9a5fc5d..c51351e 100644 --- a/quacc/evaluation/baseline.py +++ b/quacc/evaluation/baseline.py @@ -65,11 +65,10 @@ def ref( validation: LabelledCollection, protocol: AbstractStochasticSeededProtocol, ): - c_model_predict = getattr(c_model, "predict_proba") + c_model_predict = getattr(c_model, "predict") report = EvaluationReport(name="ref") for test in protocol(): - test_probs = c_model_predict(test.X) - test_preds = np.argmax(test_probs, axis=-1) + test_preds = c_model_predict(test.X) report.append_row( test.prevalence(), acc_score=metrics.accuracy_score(test.y, test_preds),