11import pytest
22from modules .jax .testcontainers .jax_cuda import JAXContainer
33
4- def test_jax_container ():
5- with JAXContainer () as jax_container :
6- jax_container .connect ()
7-
8- # Test running a simple JAX computation
9- result = jax_container .run_jax_command ("import jax; print(jax.numpy.add(1, 1))" )
10- assert "2" in result .output .decode ()
4+ @pytest .fixture (scope = "module" )
5+ def jax_container ():
6+ with JAXContainer () as container :
7+ container .connect ()
8+ yield container
119
12- def test_jax_container_gpu_support ():
13- with JAXContainer () as jax_container :
14- jax_container .connect ()
15-
16- # Test GPU availability
17- result = jax_container .run_jax_command (
18- "import jax; print(jax.devices())"
19- )
20- assert "gpu" in result .output .decode ().lower ()
10+ def test_jax_container_basic_computation (jax_container ):
11+ result = jax_container .run_jax_command ("import jax; print(jax.numpy.add(1, 1))" )
12+ assert "2" in result .output .decode (), "Basic JAX computation failed"
2113
22- def test_jax_container_jupyter ():
23- with JAXContainer () as jax_container :
24- jax_container .connect ()
25-
26- jupyter_url = jax_container .get_jupyter_url ()
27- assert jupyter_url .startswith ("http://" )
28- assert ":8888" in jupyter_url
14+ def test_jax_container_version (jax_container ):
15+ result = jax_container .run_jax_command ("import jax; print(jax.__version__)" )
16+ assert result .exit_code == 0 , "Failed to get JAX version"
17+ assert result .output .decode ().strip (), "JAX version is empty"
18+
19+ def test_jax_container_gpu_support (jax_container ):
20+ result = jax_container .run_jax_command (
21+ "import jax; devices = jax.devices(); "
22+ "print(any(dev.platform == 'gpu' for dev in devices))"
23+ )
24+ assert "True" in result .output .decode (), "No GPU device found"
25+
26+ def test_jax_container_matrix_multiplication (jax_container ):
27+ command = """
28+ import jax
29+ import jax.numpy as jnp
30+ x = jnp.array([[1, 2], [3, 4]])
31+ y = jnp.array([[5, 6], [7, 8]])
32+ result = jnp.dot(x, y)
33+ print(result)
34+ """
35+ result = jax_container .run_jax_command (command )
36+ assert "[[19 22]\n [43 50]]" in result .output .decode (), "Matrix multiplication failed"
37+
38+ def test_jax_container_custom_image ():
39+ custom_image = "nvcr.io/nvidia/jax:23.09-py3"
40+ with JAXContainer (image = custom_image ) as container :
41+ container .connect ()
42+ result = container .run_jax_command ("import jax; print(jax.__version__)" )
43+ assert result .exit_code == 0 , f"Failed to run JAX with custom image { custom_image } "
0 commit comments