Skip to content

Commit

Permalink
Fix typo in tests; caught on GPU and TPU (jax-ml#2902)
Browse files Browse the repository at this point in the history
  • Loading branch information
gnecula authored Apr 30, 2020
1 parent b39da1f commit 8d4b685
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/multibackend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,11 @@ def my_sin(x): return np.sin(x)

# jit with `device` spec places the data on the specified device
result3 = api.jit(my_sin, device=cpus[0])(2)
self.assertEqual(result1.device_buffer.device(), cpus[0])
self.assertEqual(result3.device_buffer.device(), cpus[0])

# jit with `backend` spec places the data on the specified backend
result4 = api.jit(my_sin, backend="cpu")(2)
self.assertEqual(result1.device_buffer.device(), cpus[0])
self.assertEqual(result4.device_buffer.device(), cpus[0])


if __name__ == "__main__":
Expand Down

0 comments on commit 8d4b685

Please sign in to comment.