Skip to content

Commit 54f842d

Browse files
committed
add connect method , remove jupyter port connect
1 parent 0b4033e commit 54f842d

File tree

1 file changed

+37
-22
lines changed

1 file changed

+37
-22
lines changed

modules/jax/testcontainers/jax_cuda/__init__.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import logging
2-
import urllib.request
2+
import time
33
from urllib.error import URLError
44

55
from core.testcontainers.core.container import DockerContainer
6-
from core.testcontainers.core.waiting_utils import wait_container_is_ready
7-
from core.testcontainers.core.config import testcontainers_config
8-
from core.testcontainers.core.waiting_utils import wait_for_logs
6+
from core.testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs
97

108
class JAXContainer(DockerContainer):
119
"""
@@ -15,16 +13,15 @@ class JAXContainer(DockerContainer):
1513
1614
.. doctest::
1715
18-
>>> import jax
1916
>>> from testcontainers.jax import JAXContainer
2017
2118
>>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
2219
... # Connect to the container
2320
... jax_container.connect()
2421
...
2522
... # Run a simple JAX computation
26-
... result = jax.numpy.add(1, 1)
27-
... assert result == 2
23+
... result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))")
24+
... assert "2" in result.output
2825
2926
.. auto-class:: JAXContainer
3027
:members:
@@ -34,31 +31,49 @@ class JAXContainer(DockerContainer):
3431

3532
def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs):
3633
super().__init__(image, **kwargs)
37-
self.with_exposed_ports(8888) # Expose Jupyter notebook port
3834
self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
3935
self.with_env("CUDA_VISIBLE_DEVICES", "all")
4036
self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support
41-
self.start_timeout = 600 # 10 minutes
37+
self.start_timeout = 600 # 10 minutes
38+
self.connection_retries = 5
39+
self.connection_retry_delay = 10 # seconds
4240

4341
@wait_container_is_ready(URLError)
4442
def _connect(self):
45-
url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
46-
res = urllib.request.urlopen(url, timeout=self.start_timeout)
47-
if res.status != 200:
48-
raise Exception(f"Failed to connect to JAX container. Status: {res.status}")
43+
for attempt in range(self.connection_retries):
44+
try:
45+
# Check if JAX is properly installed and functioning
46+
result = self.run_jax_command(
47+
"import jax; import jaxlib; "
48+
"print(f'JAX version: {jax.__version__}'); "
49+
"print(f'JAXlib version: {jaxlib.__version__}'); "
50+
"print(f'Available devices: {jax.devices()}'); "
51+
"print(jax.numpy.add(1, 1))"
52+
)
53+
54+
if "JAX version" in result.output and "Available devices" in result.output:
55+
logging.info(f"JAX environment verified:\n{result.output}")
56+
return True
57+
else:
58+
raise Exception("JAX environment check failed")
59+
60+
except Exception as e:
61+
if attempt < self.connection_retries - 1:
62+
logging.warning(f"Connection attempt {attempt + 1} failed. Retrying in {self.connection_retry_delay} seconds...")
63+
time.sleep(self.connection_retry_delay)
64+
else:
65+
raise Exception(f"Failed to connect to JAX container after {self.connection_retries} attempts: {str(e)}")
66+
67+
return False
4968

5069
def connect(self):
5170
"""
5271
Connect to the JAX container and ensure it's ready.
72+
This method verifies that JAX is properly installed and functioning.
73+
It also checks for available devices, including GPUs if applicable.
5374
"""
5475
self._connect()
55-
logging.info("Successfully connected to JAX container")
56-
57-
def get_jupyter_url(self):
58-
"""
59-
Get the URL for accessing the Jupyter notebook server.
60-
"""
61-
return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
76+
logging.info("Successfully connected to JAX container and verified the environment")
6277

6378
def run_jax_command(self, command):
6479
"""
@@ -68,15 +83,15 @@ def run_jax_command(self, command):
6883
return exec_result
6984

7085
def _wait_for_container_to_be_ready(self):
71-
wait_for_logs(self, "Jupyter Server", timeout=self.start_timeout)
86+
wait_for_logs(self, "JAX is ready", timeout=self.start_timeout)
7287

7388
def start(self):
7489
"""
7590
Start the JAX container and wait for it to be ready.
7691
"""
7792
super().start()
7893
self._wait_for_container_to_be_ready()
79-
logging.info(f"JAX container started. Jupyter URL: {self.get_jupyter_url()}")
94+
logging.info("JAX container started and ready.")
8095
return self
8196

8297
def stop(self, force=True):

0 commit comments

Comments
 (0)