From 920380c452e7b83674908f2b6fbfd408a0f042cd Mon Sep 17 00:00:00 2001 From: TF2JAXDev Date: Sun, 26 May 2024 23:33:17 -0700 Subject: [PATCH] Add "bad_indcies_policy" to ScatterNd PiperOrigin-RevId: 637540132 --- tf2jax/_src/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index a41839a..d511f7e 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -1727,7 +1727,7 @@ def _func( @register_operation("ScatterNd") def _scatter_nd(proto): """Parse a ScatterNd op.""" - _check_attrs(proto, {"T", "Tindices"}) + _check_attrs(proto, {"T", "Tindices", "bad_indices_policy"}) def _func( indices: jnp.ndarray,