From 14326b2122e7720bbd3dd33d07c0f98cbd2a4196 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Sun, 5 Nov 2023 14:15:15 +0100 Subject: [PATCH] ref baseline fixed --- quacc/evaluation/baseline.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/quacc/evaluation/baseline.py b/quacc/evaluation/baseline.py index 9a5fc5d..c51351e 100644 --- a/quacc/evaluation/baseline.py +++ b/quacc/evaluation/baseline.py @@ -65,11 +65,10 @@ def ref( validation: LabelledCollection, protocol: AbstractStochasticSeededProtocol, ): - c_model_predict = getattr(c_model, "predict_proba") + c_model_predict = getattr(c_model, "predict") report = EvaluationReport(name="ref") for test in protocol(): - test_probs = c_model_predict(test.X) - test_preds = np.argmax(test_probs, axis=-1) + test_preds = c_model_predict(test.X) report.append_row( test.prevalence(), acc_score=metrics.accuracy_score(test.y, test_preds),