11import logging
2- import urllib . request
2+ import time
33from urllib .error import URLError
44
55from core .testcontainers .core .container import DockerContainer
6- from core .testcontainers .core .waiting_utils import wait_container_is_ready
7- from core .testcontainers .core .config import testcontainers_config
8- from core .testcontainers .core .waiting_utils import wait_for_logs
6+ from core .testcontainers .core .waiting_utils import wait_container_is_ready , wait_for_logs
97
108class JAXContainer (DockerContainer ):
119 """
@@ -15,16 +13,15 @@ class JAXContainer(DockerContainer):
1513
1614 .. doctest::
1715
18- >>> import jax
1916 >>> from testcontainers.jax import JAXContainer
2017
2118 >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
2219 ... # Connect to the container
2320 ... jax_container.connect()
2421 ...
2522 ... # Run a simple JAX computation
26- ... result = jax.numpy.add(1, 1)
27- ... assert result == 2
23+ ... result = jax_container.run_jax_command("import jax; print(jax .numpy.add(1, 1))" )
24+ ... assert "2" in result.output
2825
2926 .. auto-class:: JAXContainer
3027 :members:
@@ -34,31 +31,49 @@ class JAXContainer(DockerContainer):
3431
3532 def __init__ (self , image = "nvcr.io/nvidia/jax:23.08-py3" , ** kwargs ):
3633 super ().__init__ (image , ** kwargs )
37- self .with_exposed_ports (8888 ) # Expose Jupyter notebook port
3834 self .with_env ("NVIDIA_VISIBLE_DEVICES" , "all" )
3935 self .with_env ("CUDA_VISIBLE_DEVICES" , "all" )
4036 self .with_kwargs (runtime = "nvidia" ) # Use NVIDIA runtime for GPU support
41- self .start_timeout = 600 # 10 minutes
37+ self .start_timeout = 600 # 10 minutes
38+ self .connection_retries = 5
39+ self .connection_retry_delay = 10 # seconds
4240
4341 @wait_container_is_ready (URLError )
4442 def _connect (self ):
45- url = f"http://{ self .get_container_host_ip ()} :{ self .get_exposed_port (8888 )} "
46- res = urllib .request .urlopen (url , timeout = self .start_timeout )
47- if res .status != 200 :
48- raise Exception (f"Failed to connect to JAX container. Status: { res .status } " )
43+ for attempt in range (self .connection_retries ):
44+ try :
45+ # Check if JAX is properly installed and functioning
46+ result = self .run_jax_command (
47+ "import jax; import jaxlib; "
48+ "print(f'JAX version: {jax.__version__}'); "
49+ "print(f'JAXlib version: {jaxlib.__version__}'); "
50+ "print(f'Available devices: {jax.devices()}'); "
51+ "print(jax.numpy.add(1, 1))"
52+ )
53+
54+ if "JAX version" in result .output and "Available devices" in result .output :
55+ logging .info (f"JAX environment verified:\n { result .output } " )
56+ return True
57+ else :
58+ raise Exception ("JAX environment check failed" )
59+
60+ except Exception as e :
61+ if attempt < self .connection_retries - 1 :
62+ logging .warning (f"Connection attempt { attempt + 1 } failed. Retrying in { self .connection_retry_delay } seconds..." )
63+ time .sleep (self .connection_retry_delay )
64+ else :
65+ raise Exception (f"Failed to connect to JAX container after { self .connection_retries } attempts: { str (e )} " )
66+
67+ return False
4968
5069 def connect (self ):
5170 """
5271 Connect to the JAX container and ensure it's ready.
72+ This method verifies that JAX is properly installed and functioning.
73+ It also checks for available devices, including GPUs if applicable.
5374 """
5475 self ._connect ()
55- logging .info ("Successfully connected to JAX container" )
56-
57- def get_jupyter_url (self ):
58- """
59- Get the URL for accessing the Jupyter notebook server.
60- """
61- return f"http://{ self .get_container_host_ip ()} :{ self .get_exposed_port (8888 )} "
76+ logging .info ("Successfully connected to JAX container and verified the environment" )
6277
6378 def run_jax_command (self , command ):
6479 """
@@ -68,15 +83,15 @@ def run_jax_command(self, command):
6883 return exec_result
6984
7085 def _wait_for_container_to_be_ready (self ):
71- wait_for_logs (self , "Jupyter Server " , timeout = self .start_timeout )
86+ wait_for_logs (self , "JAX is ready " , timeout = self .start_timeout )
7287
7388 def start (self ):
7489 """
7590 Start the JAX container and wait for it to be ready.
7691 """
7792 super ().start ()
7893 self ._wait_for_container_to_be_ready ()
79- logging .info (f "JAX container started. Jupyter URL: { self . get_jupyter_url () } " )
94+ logging .info ("JAX container started and ready. " )
8095 return self
8196
8297 def stop (self , force = True ):
0 commit comments