diff --git a/jax/experimental/optimizers.py b/jax/experimental/optimizers.py index c4164f4a71d5..20348d9269bb 100644 --- a/jax/experimental/optimizers.py +++ b/jax/experimental/optimizers.py @@ -499,7 +499,7 @@ def piecewise_constant(boundaries, values): if not boundaries.ndim == values.ndim == 1: raise ValueError("boundaries and values must be sequences") if not boundaries.shape[0] == values.shape[0] - 1: - raise ValueError("boundaries length must be one longer than values length") + raise ValueError("boundaries length must be one shorter than values length") def schedule(i): return values[jnp.sum(i > boundaries)]