Skip to content

Commit 0cdddf8

Browse files
committed
add jax to pyproject.toml
1 parent b701be0 commit 0cdddf8

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

modules/jax/testcontainers/jax/__init__.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,23 @@ class JAXContainer(DockerContainer):
1111
1212
Example:
1313
14-
.. doctest::
14+
.. doctest::
1515
16-
>>> import jax
17-
>>> from testcontainers.jax import JAXContainer
16+
>>> import jax
17+
>>> from testcontainers.jax import JAXContainer
1818
19-
>>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
20-
... # Connect to the container
21-
... jax_container.connect()
22-
...
23-
... # Run a simple JAX computation
24-
... result = jax.numpy.add(1, 1)
25-
... assert result == 2
19+
>>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
20+
... # Connect to the container
21+
... jax_container.connect()
22+
...
23+
... # Run a simple JAX computation
24+
... result = jax.numpy.add(1, 1)
25+
... assert result == 2
26+
27+
.. auto-class:: JAXContainer
28+
:members:
29+
:undoc-members:
30+
:show-inheritance:
2631
"""
2732

2833
def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ neo4j = ["neo4j"]
149149
nginx = []
150150
opensearch = ["opensearch-py"]
151151
ollama = []
152+
jax = ["jax"]
152153
oracle = ["sqlalchemy", "oracledb"]
153154
oracle-free = ["sqlalchemy", "oracledb"]
154155
postgres = []

0 commit comments

Comments
 (0)