Skip to content

Commit

Permalink
[CI] Add JAX deps in Dockerfiles (apache#14550)
Browse files Browse the repository at this point in the history
* [CI] Add JAX deps in Dockerfiles

* Specify jax/jaxlib/flax version for python3.7
  • Loading branch information
yongwww authored Apr 12, 2023
1 parent c1d1e9f commit 3ef745c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docker/Dockerfile.ci_cpu
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ RUN bash /install/ubuntu_install_tensorflow.sh
COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh
RUN bash /install/ubuntu_install_tflite.sh

# JAX deps
COPY install/ubuntu_install_jax.sh /install/ubuntu_install_jax.sh
RUN bash /install/ubuntu_install_jax.sh "cpu"

# Compute Library
COPY install/ubuntu_download_arm_compute_lib_binaries.sh /install/ubuntu_download_arm_compute_lib_binaries.sh
RUN bash /install/ubuntu_download_arm_compute_lib_binaries.sh
Expand Down
3 changes: 3 additions & 0 deletions docker/Dockerfile.ci_gpu
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ RUN bash /install/ubuntu_install_coreml.sh
COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh
RUN bash /install/ubuntu_install_tensorflow.sh

COPY install/ubuntu_install_jax.sh /install/ubuntu_install_jax.sh
RUN bash /install/ubuntu_install_jax.sh "cuda"

COPY install/ubuntu_install_darknet.sh /install/ubuntu_install_darknet.sh
RUN bash /install/ubuntu_install_darknet.sh

Expand Down
35 changes: 35 additions & 0 deletions docker/install/ubuntu_install_jax.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -e
set -u
set -o pipefail

# Install jax and jaxlib
if [ "$1" == "cuda" ]; then
pip3 install --upgrade \
jaxlib==0.3.25 \
"jax[cuda11_pip]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
else
pip3 install --upgrade \
jaxlib==0.3.25 \
"jax[cpu]==0.3.25"
fi

# Install flax
pip3 install flax==0.6.4

0 comments on commit 3ef745c

Please sign in to comment.