Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ScatterNdFunctor to update operands of all valid indices and contin…
Browse files Browse the repository at this point in the history
…ue on bad indices.

This is to support the new attribute "bad_indices_policy". Passing downs the behavior also works, but it makes `ScatterNdFunctor` unnecessarily complicated while the only gain is the performance with out-of-bound error.

PiperOrigin-RevId: 637646021
TF2JAXDev authored and TF2JAXDev committed May 28, 2024
1 parent ad9afbb commit b31dab2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
@@ -1727,7 +1727,7 @@ def _func(
@register_operation("ScatterNd")
def _scatter_nd(proto):
"""Parse a ScatterNd op."""
_check_attrs(proto, {"T", "Tindices"})
_check_attrs(proto, {"T", "Tindices", "bad_indices_policy"})

def _func(
indices: jnp.ndarray,

0 comments on commit b31dab2

Please sign in to comment.