1
1
import logging
2
- import urllib . request
2
+ import time
3
3
from urllib .error import URLError
4
4
5
5
from core .testcontainers .core .container import DockerContainer
6
- from core .testcontainers .core .waiting_utils import wait_container_is_ready
7
- from core .testcontainers .core .config import testcontainers_config
8
- from core .testcontainers .core .waiting_utils import wait_for_logs
6
+ from core .testcontainers .core .waiting_utils import wait_container_is_ready , wait_for_logs
9
7
10
8
class JAXContainer (DockerContainer ):
11
9
"""
@@ -15,16 +13,15 @@ class JAXContainer(DockerContainer):
15
13
16
14
.. doctest::
17
15
18
- >>> import jax
19
16
>>> from testcontainers.jax import JAXContainer
20
17
21
18
>>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
22
19
... # Connect to the container
23
20
... jax_container.connect()
24
21
...
25
22
... # Run a simple JAX computation
26
- ... result = jax.numpy.add(1, 1)
27
- ... assert result == 2
23
+ ... result = jax_container.run_jax_command("import jax; print(jax .numpy.add(1, 1))" )
24
+ ... assert "2" in result.output
28
25
29
26
.. autoclass:: JAXContainer
30
27
:members:
@@ -34,31 +31,49 @@ class JAXContainer(DockerContainer):
34
31
35
32
def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs):
    """
    Configure a GPU-enabled JAX container.

    :param image: Docker image to run.
    :param kwargs: Forwarded unchanged to ``DockerContainer.__init__``.
    """
    super().__init__(image, **kwargs)
    # Ask the NVIDIA container runtime to expose every host GPU.
    self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
    # CUDA_VISIBLE_DEVICES is intentionally not set: CUDA only accepts device
    # indices or UUIDs there, so the value "all" is an invalid entry and
    # would leave the process with no visible devices.
    self.with_kwargs(runtime="nvidia")  # Use NVIDIA runtime for GPU support
    self.start_timeout = 600  # seconds (10 minutes)
    self.connection_retries = 5  # attempts made by _connect()
    self.connection_retry_delay = 10  # seconds between attempts
42
40
43
41
@wait_container_is_ready(URLError)
def _connect(self):
    """
    Verify that JAX is importable and working inside the container.

    A short probe script is executed through ``run_jax_command`` up to
    ``self.connection_retries`` times, sleeping
    ``self.connection_retry_delay`` seconds between attempts.

    :return: ``True`` once the probe output contains both the
        "JAX version" and "Available devices" markers.
    :raises Exception: when no attempt succeeds (including the case of a
        non-positive retry count); chained to the last failure, if any.
    """
    probe = (
        "import jax; import jaxlib; "
        "print(f'JAX version: {jax.__version__}'); "
        "print(f'JAXlib version: {jaxlib.__version__}'); "
        "print(f'Available devices: {jax.devices()}'); "
        "print(jax.numpy.add(1, 1))"
    )

    last_error = None
    for attempt in range(1, self.connection_retries + 1):
        try:
            output = self.run_jax_command(probe).output
            # NOTE(review): run_jax_command is not fully visible here; if it
            # returns docker's ExecResult the output is bytes, and a str
            # membership test on bytes would raise TypeError every time.
            if isinstance(output, bytes):
                output = output.decode("utf-8", errors="replace")
            if "JAX version" in output and "Available devices" in output:
                logging.info("JAX environment verified:\n%s", output)
                return True
            last_error = Exception("JAX environment check failed")
        except Exception as e:
            # Any failure of the probe counts as "not ready yet".
            last_error = e

        if attempt < self.connection_retries:
            logging.warning(
                "Connection attempt %s failed. Retrying in %s seconds...",
                attempt,
                self.connection_retry_delay,
            )
            time.sleep(self.connection_retry_delay)

    # Fail loudly instead of returning False, which callers never inspected.
    raise Exception(
        f"Failed to connect to JAX container after {self.connection_retries} attempts: {last_error}"
    ) from last_error
49
68
50
69
def connect(self):
    """
    Connect to the JAX container and ensure it's ready.

    Delegates the actual verification to ``_connect`` and logs a success
    message once that check has passed.

    :raises RuntimeError: if ``_connect`` reports failure by returning
        ``False``. Exceptions raised by ``_connect`` itself propagate.
    """
    # Compare against False explicitly so that a None result is still
    # treated as success; only an explicit failure result is rejected.
    if self._connect() is False:
        raise RuntimeError("JAX container environment verification failed")
    logging.info("Successfully connected to JAX container and verified the environment")
62
77
63
78
def run_jax_command (self , command ):
64
79
"""
@@ -68,15 +83,15 @@ def run_jax_command(self, command):
68
83
return exec_result
69
84
70
85
def _wait_for_container_to_be_ready(self):
    """
    Block until the readiness marker appears in the container logs.

    The wait is bounded by ``self.start_timeout`` (seconds).
    """
    # NOTE(review): the marker must match what the image actually prints;
    # confirm the JAX image emits this line.
    ready_marker = "JAX is ready"
    wait_for_logs(self, ready_marker, timeout=self.start_timeout)
72
87
73
88
def start(self):
    """
    Launch the container, wait for readiness, and return this instance.

    :return: ``self``, so the call can be chained.
    """
    super().start()
    # Returns only after the readiness wait has completed.
    self._wait_for_container_to_be_ready()
    logging.info("JAX container started and ready.")
    return self
81
96
82
97
def stop (self , force = True ):
0 commit comments