diff --git a/python/tests/jax_test.py b/python/tests/jax_test.py index abff0a84e..dca5676dd 100644 --- a/python/tests/jax_test.py +++ b/python/tests/jax_test.py @@ -52,12 +52,12 @@ def test_jit_scale(self): def test_vmap(self): add_two = primitive(dex.eval(r'\x:((Fin 2)=>Float). for i. x.i + 2.0')) - x = jnp.linspace([0, 3], [5, 8], num=4, dtype=jnp.float32) + x = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4, dtype=jnp.float32) np.testing.assert_allclose(jax.vmap(add_two)(x), x + 2.0) def test_vmap_nonzero_index(self): add_two = primitive(dex.eval(r'\x:((Fin 4)=>Float). for i. x.i + 2.0')) - x = jnp.linspace([0, 3], [5, 8], num=4, dtype=jnp.float32) + x = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4, dtype=jnp.float32) np.testing.assert_allclose( jax.vmap(add_two, in_axes=1, out_axes=1)(x), x + 2.0) @@ -72,12 +72,12 @@ def test_vmap_unbatched_array(self): def test_vmap_jit(self): add_two = primitive(dex.eval(r'\x:((Fin 2)=>Float). for i. x.i + 2.0')) - x = jnp.linspace([0, 3], [5, 8], num=4, dtype=jnp.float32) + x = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4, dtype=jnp.float32) np.testing.assert_allclose(jax.jit(jax.vmap(add_two))(x), x + 2.0) def test_vmap_jit_nonzero_index(self): add_two = primitive(dex.eval(r'\x:((Fin 4)=>Float). for i. x.i + 2.0')) - x = jnp.linspace([0, 3], [5, 8], num=4, dtype=jnp.float32) + x = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4, dtype=jnp.float32) np.testing.assert_allclose( jax.jit(jax.vmap(add_two, in_axes=1, out_axes=1))(x), x + 2.0)