ref baseline fixed

This commit is contained in:
Lorenzo Volpi 2023-11-05 14:15:15 +01:00
parent b96432f87b
commit 14326b2122
1 changed files with 2 additions and 3 deletions

View File

@ -65,11 +65,10 @@ def ref(
validation: LabelledCollection, validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol, protocol: AbstractStochasticSeededProtocol,
): ):
c_model_predict = getattr(c_model, "predict_proba") c_model_predict = getattr(c_model, "predict")
report = EvaluationReport(name="ref") report = EvaluationReport(name="ref")
for test in protocol(): for test in protocol():
test_probs = c_model_predict(test.X) test_preds = c_model_predict(test.X)
test_preds = np.argmax(test_probs, axis=-1)
report.append_row( report.append_row(
test.prevalence(), test.prevalence(),
acc_score=metrics.accuracy_score(test.y, test_preds), acc_score=metrics.accuracy_score(test.y, test_preds),