ref baseline fixed

This commit is contained in:
Lorenzo Volpi 2023-11-05 14:15:15 +01:00
parent b96432f87b
commit 14326b2122
1 changed files with 2 additions and 3 deletions

View File

@ -65,11 +65,10 @@ def ref(
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
):
c_model_predict = getattr(c_model, "predict_proba")
c_model_predict = getattr(c_model, "predict")
report = EvaluationReport(name="ref")
for test in protocol():
test_probs = c_model_predict(test.X)
test_preds = np.argmax(test_probs, axis=-1)
test_preds = c_model_predict(test.X)
report.append_row(
test.prevalence(),
acc_score=metrics.accuracy_score(test.y, test_preds),