.. currentmodule:: jax.ops

.. automodule:: jax.ops

JAX is intended to be used with a functional style of programming, and
hence does not support NumPy-style indexed assignment directly. Instead, JAX
provides pure alternatives, namely :func:`jax.ops.index_update` and its relatives.

.. autosummary::
  :toctree: _autosummary

  index
  index_update
  index_add
  index_mul
  index_min
  index_max
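
The following minimal sketch, which assumes NumPy is installed and that JAX arrays
provide the ``at`` property, contrasts NumPy's in-place assignment with JAX's
pure update. It uses the ``x.at[idx].set(y)`` spelling of
``jax.ops.index_update(x, jax.ops.index[idx], y)``.

.. code-block:: python

   import numpy as np
   import jax.numpy as jnp

   # NumPy arrays are mutable: indexed assignment changes the array in place.
   np_x = np.zeros(5)
   np_x[1] = 10.0

   # JAX arrays are immutable, so the same statement raises a TypeError.
   jax_x = jnp.zeros(5)
   try:
       jax_x[1] = 10.0
       assignment_failed = False
   except TypeError:
       assignment_failed = True

   # The pure alternative returns a new array and leaves jax_x untouched.
   jax_y = jax_x.at[1].set(10.0)

The variable names here are illustrative only.
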

JAX also provides an alternate syntax for these indexed update operators.
Specifically, JAX ndarray types have a property ``at``, which can be used as
follows (where ``idx`` can be an arbitrary index expression).
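
====================  ==================================================
Alternate syntax      Equivalent expression
====================  ==================================================
``x.at[idx].set(y)``  ``jax.ops.index_update(x, jax.ops.index[idx], y)``
``x.at[idx].add(y)``  ``jax.ops.index_add(x, jax.ops.index[idx], y)``
``x.at[idx].mul(y)``  ``jax.ops.index_mul(x, jax.ops.index[idx], y)``
``x.at[idx].min(y)``  ``jax.ops.index_min(x, jax.ops.index[idx], y)``
``x.at[idx].max(y)``  ``jax.ops.index_max(x, jax.ops.index[idx], y)``
====================  ==================================================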
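
As a rough illustration of the table, the sketch below applies ``set``, ``add``,
``min`` and ``max`` through the ``at`` property with a slice, a scalar index and
an integer-array index. The input values are arbitrary examples. Note that
``add`` applies every update when an index repeats, so repeated positions
accumulate.

.. code-block:: python

   import jax.numpy as jnp

   x = jnp.array([5.0, 1.0, 4.0, 2.0])

   # Overwrite a slice: like x[1:3] = 0.0 in NumPy, but returns a new array.
   set_y = x.at[1:3].set(0.0)

   # Index 0 appears twice, so 1.0 is added to x[0] twice.
   add_y = x.at[jnp.array([0, 0, 3])].add(1.0)

   # Elementwise minimum/maximum with the value at the indexed position.
   min_y = x.at[0].min(3.0)
   max_y = x.at[1].max(3.0)

Here ``set_y`` is ``[5., 0., 0., 2.]``, ``add_y`` is ``[7., 1., 4., 3.]``,
``min_y`` is ``[3., 1., 4., 2.]`` and ``max_y`` is ``[5., 3., 4., 2.]``.
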

Note that none of these expressions modifies the original ``x``; instead, they
return a modified copy of ``x``.
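
A minimal sketch of this copy semantics, assuming arrays that support the ``at``
property: the original array keeps its values, and "updating" a variable means
rebinding its name to the returned array.

.. code-block:: python

   import jax.numpy as jnp

   x = jnp.arange(4.0)          # [0., 1., 2., 3.]
   y = x.at[2].add(10.0)        # a new array: [0., 1., 12., 3.]

   # x still holds its original values.
   x_unchanged = x.tolist() == [0.0, 1.0, 2.0, 3.0]

   # To "update" x, rebind the name to the returned copy.
   x = x.at[0].set(-1.0)        # [-1., 1., 2., 3.]
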

.. autosummary::
  :toctree: _autosummary

  segment_sum
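
:func:`jax.ops.segment_sum` sums the entries of ``data`` that share the same
segment id. The minimal sketch below uses made-up example data. It also
compares the result with a scatter-add written in the ``at`` syntax, on the
assumption that the two agree for in-range segment ids.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
   segment_ids = jnp.array([0, 0, 1, 1, 1])

   # Segment 0 sums 1 + 2, segment 1 sums 3 + 4 + 5.
   sums = jax.ops.segment_sum(data, segment_ids, num_segments=2)

   # The same reduction as an indexed add into a zero array of length 2.
   via_at = jnp.zeros(2).at[segment_ids].add(data)

Both ``sums`` and ``via_at`` are ``[3., 12.]``.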