Skip to content

Commit

Permalink
Update Python tests to reflect JAX changes
Browse files Browse the repository at this point in the history
JAX no longer accepts lists in many of its NumPy API functions, but
casting them to arrays explicitly fixes the issue.
  • Loading branch information
apaszke committed Sep 29, 2021
1 parent d3ca6fc commit b80d02d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/tests/jax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ def test_jit_scale(self):

def test_vmap(self):
add_two = primitive(dex.eval(r'\x:((Fin 2)=>Float). for i. x.i + 2.0'))
x = jnp.linspace([0, 3], [5, 8], num=4, dtype=jnp.float32)
x = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4, dtype=jnp.float32)
np.testing.assert_allclose(jax.vmap(add_two)(x), x + 2.0)

def test_vmap_nonzero_index(self):
add_two = primitive(dex.eval(r'\x:((Fin 4)=>Float). for i. x.i + 2.0'))
x = jnp.linspace([0, 3], [5, 8], num=4, dtype=jnp.float32)
x = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4, dtype=jnp.float32)
np.testing.assert_allclose(
jax.vmap(add_two, in_axes=1, out_axes=1)(x), x + 2.0)

Expand All @@ -72,12 +72,12 @@ def test_vmap_unbatched_array(self):

def test_vmap_jit(self):
add_two = primitive(dex.eval(r'\x:((Fin 2)=>Float). for i. x.i + 2.0'))
x = jnp.linspace([0, 3], [5, 8], num=4, dtype=jnp.float32)
x = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4, dtype=jnp.float32)
np.testing.assert_allclose(jax.jit(jax.vmap(add_two))(x), x + 2.0)

def test_vmap_jit_nonzero_index(self):
add_two = primitive(dex.eval(r'\x:((Fin 4)=>Float). for i. x.i + 2.0'))
x = jnp.linspace([0, 3], [5, 8], num=4, dtype=jnp.float32)
x = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4, dtype=jnp.float32)
np.testing.assert_allclose(
jax.jit(jax.vmap(add_two, in_axes=1, out_axes=1))(x), x + 2.0)

Expand Down

0 comments on commit b80d02d

Please sign in to comment.