From da02fc61c792431e2ecd5720cc5c93e1439a79cc Mon Sep 17 00:00:00 2001 From: TF2JAXDev Date: Mon, 27 May 2024 08:15:25 -0700 Subject: [PATCH] `ScatterNdFunctor` to update operands of all valid indices and continue on bad indices. This is to support the new attribute "bad_indices_policy". Passing downs the behavior also works, but it makes `ScatterNdFunctor` unnecessarily complicated while the only gain is the performance with out-of-bound error. For testing, the existing "Error_OutOfBound" test is modified so it can also ensure that the bad index in the middle still correctly raise error and the caller still handles it. PiperOrigin-RevId: 637646021 --- tf2jax/_src/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index a41839a..d511f7e 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -1727,7 +1727,7 @@ def _func( @register_operation("ScatterNd") def _scatter_nd(proto): """Parse a ScatterNd op.""" - _check_attrs(proto, {"T", "Tindices"}) + _check_attrs(proto, {"T", "Tindices", "bad_indices_policy"}) def _func( indices: jnp.ndarray,