ref baseline fixed
This commit is contained in:
parent
b96432f87b
commit
14326b2122
|
@ -65,11 +65,10 @@ def ref(
|
||||||
validation: LabelledCollection,
|
validation: LabelledCollection,
|
||||||
protocol: AbstractStochasticSeededProtocol,
|
protocol: AbstractStochasticSeededProtocol,
|
||||||
):
|
):
|
||||||
c_model_predict = getattr(c_model, "predict_proba")
|
c_model_predict = getattr(c_model, "predict")
|
||||||
report = EvaluationReport(name="ref")
|
report = EvaluationReport(name="ref")
|
||||||
for test in protocol():
|
for test in protocol():
|
||||||
test_probs = c_model_predict(test.X)
|
test_preds = c_model_predict(test.X)
|
||||||
test_preds = np.argmax(test_probs, axis=-1)
|
|
||||||
report.append_row(
|
report.append_row(
|
||||||
test.prevalence(),
|
test.prevalence(),
|
||||||
acc_score=metrics.accuracy_score(test.y, test_preds),
|
acc_score=metrics.accuracy_score(test.y, test_preds),
|
||||||
|
|
Loading…
Reference in New Issue