
predict_proba in classifier estimators is doing needless assertions #1529

@eddiebergman

Description


```python
pred_proba = super().predict_proba(X, batch_size=batch_size, n_jobs=n_jobs)

# Check if all probabilities sum up to 1.
# Assert only if target type is not multilabel-indicator.
if self.target_type not in ["multilabel-indicator"]:
    assert np.allclose(
        np.sum(pred_proba, axis=1), np.ones_like(pred_proba[:, 0])
    ), "prediction probability does not sum up to 1!"

# Check that all probability values lie between 0 and 1.
assert (pred_proba >= 0).all() and (
    pred_proba <= 1
).all(), "found prediction probability value outside of [0, 1]!"
```

There's a lot of assertion checking going on here, which can really eat into inference time. While the checks are helpful, they seem like something that should be enforced in tests rather than on every prediction.
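A minimal sketch of how the same invariants could be checked in a test instead of inside `predict_proba` itself (the fixture names `fitted_classifier` and `X_test` are hypothetical and not part of auto-sklearn's actual test suite):

```python
import numpy as np


def test_predict_proba_is_valid_distribution(fitted_classifier, X_test):
    pred_proba = fitted_classifier.predict_proba(X_test)

    # Every value must be a valid probability in [0, 1].
    assert ((pred_proba >= 0) & (pred_proba <= 1)).all()

    # Rows must sum to 1 unless the target is multilabel-indicator,
    # where each column is an independent per-label probability.
    if fitted_classifier.target_type not in ["multilabel-indicator"]:
        np.testing.assert_allclose(pred_proba.sum(axis=1), 1.0, rtol=1e-5)
```

Running these checks once over representative test data would keep the guarantee while avoiding the extra full passes over `pred_proba` on every call to `predict_proba`.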
