Skip to content

Commit 4d0dfd4

Browse files
authored
Merge branch 'main' into support-32647
2 parents 8e9c8d6 + 4cfd54d commit 4d0dfd4

File tree

1,288 files changed

+191872
-97592
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,288 files changed

+191872
-97592
lines changed

.bazelrc

Lines changed: 304 additions & 197 deletions
Large diffs are not rendered by default.

.editorconfig

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,7 @@ indent_style = space
55
end_of_line = lf
66
trim_trailing_whitespace = true
77
insert_final_newline = true
8-
9-
[*.py]
10-
max_line_length = 79
11-
indent_size = 2
12-
13-
[*.rst]
14-
max_line_length = 79
158
indent_size = 2
169

17-
[*.md]
18-
max_line_length = 79
19-
indent_size = 2
20-
21-
[*.yml]
22-
indent_size = 2
10+
[*.py]
11+
max_line_length = 80

.github/ISSUE_TEMPLATE/bug-report.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ body:
2424
2525
[issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues
2626
27-
[Raw report]: http://github.com/jax-ml/jax/issues/new
27+
[Raw report]: https://github.com/jax-ml/jax/issues/new?template=none
2828
- type: textarea
2929
attributes:
3030
label: Description

.github/actionlint.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Configuration related to self-hosted runner.
2+
self-hosted-runner:
3+
labels:
4+
- "linux-x86-n4-16" # Linux X86 runner using the 16 vcpu n4-standard-16 machine.
5+
- "linux-x86-n4-32" # Linux X86 runner using the 32 vcpu n4-standard-32 machine.
6+
- "linux-x86-n4-64" # Linux X86 runner using the 64 vcpu n2-standard-64 machine.
7+
- "linux-x86-g2-16-l4-1gpu" # Linux X86 GPU runner using g2-standard-16 machine with 1 NVIDIA L4 GPU attached.
8+
- "linux-x86-g2-48-l4-4gpu" # Linux X86 GPU runner using g2-standard-48 machine with 4 NVIDIA L4 GPUs attached.
9+
- "linux-x86-ct5lp-224-8tpu" # Linux X86 TPU runner using ct5lp-hightpu-8t machine with 2x4 topology.
10+
- "linux-arm64-c4a-16" # Linux ARM64 CPU Runner using the 16 vcpu c4a-standard-16 machine.
11+
- "linux-arm64-c4a-64" # Linux ARM64 CPU Runner using the 64 vcpu c4a-standard-64 machine.
12+
- "windows-x86-n2-16" # Windows X86 runner using n2-standard-16 machine.
13+
- "windows-x86-n2-64" # Windows X86 runner using n2-standard-64 machine.
14+
- "linux-x86-a4-224-b200-1gpu" # Linux X86 GPU runner using 1 B200 GPU and 1/8 the resources of a a4-highgpu-8g machine
15+
- "linux-x86-a3-8g-h100-8gpu" # Linux X86 GPU runner using a3-highgpu-8g machine with 8 NVIDIA H100 GPUs attached.
16+
- "linux-x86-ct6e-180-8tpu" # Linux X86 TPU runner using ct6e-hightpu-8t machine with 2x4 topology.
17+
- "linux-x86-ct6e-180-4tpu" # Linux X86 TPU runner using ct6e-hightpu-4t machine with 2x2 topology.
18+
- "linux-x86-ct4p-240-4tpu" # Linux X86 TPU runner using ct4p-hightpu-4t machine with 2x2x1 topology.
19+
- "linux-x86_64-cirrascale-64-8gpu-amd-mi250" # AMD runner
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Composite action to download the jax and jaxlib wheels
2+
name: Download JAX CPU wheels
3+
4+
inputs:
5+
runner:
6+
description: "Which runner type should the wheels be downloaded for?"
7+
type: string
8+
default: "linux-x86-n4-16"
9+
python:
10+
description: "Which python version should the artifact be downloaded for?"
11+
required: true
12+
type: string
13+
jaxlib-version:
14+
description: "Which jaxlib version to download? (head/pypi_latest)"
15+
type: string
16+
default: "head"
17+
skip-download-jaxlib-from-gcs:
18+
description: "Whether to skip downloading the jaxlib artifact from GCS (e.g for testing a jax only release)"
19+
default: '0'
20+
type: string
21+
gcs_download_uri:
22+
description: "GCS location prefix from where the artifacts should be downloaded"
23+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
24+
type: string
25+
permissions: {}
26+
runs:
27+
using: "composite"
28+
29+
steps:
30+
# Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow.
31+
- name: Set env vars for use in artifact download URL
32+
shell: bash
33+
run: |
34+
os=$(uname -s | awk '{print tolower($0)}')
35+
arch=$(uname -m)
36+
37+
# Adjust os and arch for Windows
38+
if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then
39+
os="win"
40+
arch="amd64"
41+
fi
42+
43+
# Get the major and minor version of Python.
44+
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
45+
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
46+
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')
47+
48+
echo "OS=${os}" >> $GITHUB_ENV
49+
echo "ARCH=${arch}" >> $GITHUB_ENV
50+
# Python wheels follow a naming convention: standard wheels use the pattern
51+
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
52+
# `*-cp<py_version>-cp<py_version>t-*`.
53+
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
54+
- name: Download wheels from GCS (non-Windows runs)
55+
shell: bash
56+
id: download-wheel-artifacts-nw
57+
# Set continue-on-error to true to prevent actions from failing the workflow if this step
58+
# fails. Instead, we verify the outcome in the step below so that we can print a more
59+
# informative error message.
60+
continue-on-error: true
61+
if: ${{ !contains(inputs.runner, 'windows-x86') }}
62+
run: |
63+
mkdir -p $(pwd)/dist
64+
gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
65+
66+
if [[ "${{ inputs.skip-download-jaxlib-from-gcs }}" == "1" ]]; then
67+
echo "JAX only release. Only downloading the jax wheel from the release bucket."
68+
else
69+
if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then
70+
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
71+
elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then
72+
PYTHON=python${{ inputs.python }}
73+
$PYTHON -m pip download jaxlib --dest $(pwd)/dist/
74+
else
75+
echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}"
76+
exit 1
77+
fi
78+
fi
79+
- name: Download wheels from GCS (Windows runs)
80+
shell: cmd
81+
id: download-wheel-artifacts-w
82+
# Set continue-on-error to true to prevent actions from failing the workflow if this step
83+
# fails. Instead, we verify the outcome in step below so that we can print a more
84+
# informative error message.
85+
continue-on-error: true
86+
if: ${{ contains(inputs.runner, 'windows-x86') }}
87+
run: |
88+
mkdir dist
89+
@REM Use `call` so that we can run sequential gcloud storage commands on Windows
90+
@REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
91+
call gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
92+
93+
if "${{ inputs.skip-download-jaxlib-from-gcs }}"=="1" (
94+
echo "JAX only release. Only downloading the jax wheel from the release bucket."
95+
) else (
96+
call gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
97+
)
98+
- name: Skip the test run if the wheel artifacts were not downloaded successfully
99+
shell: bash
100+
if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
101+
run: |
102+
echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
103+
echo "built successfully by the artifact build jobs and are available in the GCS bucket."
104+
echo "Skipping the test run."
105+
exit 1
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Composite action to download the jax, jaxlib, and the CUDA plugin wheels
2+
name: Download JAX CUDA wheels
3+
4+
inputs:
5+
python:
6+
description: "Which python version should the artifact be downloaded for?"
7+
type: string
8+
required: true
9+
cuda-version:
10+
description: "Which cuda version should the artifact be downloaded for?"
11+
type: string
12+
default: "12"
13+
use-nvidia-pip-wheels:
14+
description: "Whether to download Nvidia CUDA packages from PyPI?"
15+
type: boolean
16+
default: false
17+
jaxlib-version:
18+
description: "Which jaxlib version to download? (head/pypi_latest)"
19+
type: string
20+
default: "head"
21+
download-jax-from-gcs:
22+
description: "Whether to download the jax wheel from GCS"
23+
default: '1'
24+
type: string
25+
skip-download-jaxlib-and-cuda-plugins-from-gcs:
26+
description: "Whether to skip downloading the jaxlib and cuda plugins from GCS (e.g for testing a jax only release)"
27+
default: '0'
28+
type: string
29+
gcs_download_uri:
30+
description: "GCS location prefix from where the artifacts should be downloaded"
31+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
32+
type: string
33+
permissions: {}
34+
runs:
35+
using: "composite"
36+
37+
steps:
38+
# Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow.
39+
- name: Set env vars for use in artifact download URL
40+
shell: bash
41+
run: |
42+
os=$(uname -s | awk '{print tolower($0)}')
43+
arch=$(uname -m)
44+
45+
# Get the major and minor version of Python.
46+
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311
47+
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
48+
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')
49+
50+
echo "OS=${os}" >> $GITHUB_ENV
51+
echo "ARCH=${arch}" >> $GITHUB_ENV
52+
# Python wheels follow a naming convention: standard wheels use the pattern
53+
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
54+
# `*-cp<py_version>-cp<py_version>t-*`.
55+
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
56+
57+
# Get the CUDA major version only
58+
full_cuda_version="${{ inputs.cuda-version }}"
59+
echo "JAXCI_CUDA_VERSION=${full_cuda_version%%.*}" >> $GITHUB_ENV
60+
- name: Download wheels
61+
shell: bash
62+
id: download-wheel-artifacts
63+
# Set continue-on-error to true to prevent actions from failing the workflow if this step
64+
# fails. Instead, we verify the outcome in the next step so that we can print a more
65+
# informative error message.
66+
continue-on-error: true
67+
run: |
68+
mkdir -p $(pwd)/dist
69+
if [[ "${{ inputs.download-jax-from-gcs }}" == "1" ]]; then
70+
gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
71+
else
72+
echo "JAX wheel won't be downloaded, only jaxlib pre-built wheel is tested."
73+
fi
74+
75+
# Do not download the jaxlib and CUDA plugin artifacts if we are testing a jax only
76+
# release.
77+
if [[ "${{ inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs }}" == "1" ]]; then
78+
echo "JAX only release. Only downloading the jax wheel from the release bucket."
79+
else
80+
if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then
81+
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
82+
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda${JAXCI_CUDA_VERSION}*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
83+
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda${JAXCI_CUDA_VERSION}*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
84+
elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then
85+
PYTHON=python${{ inputs.python }}
86+
$PYTHON -m pip download jaxlib jax-cuda${JAXCI_CUDA_VERSION}-pjrt jax-cuda${JAXCI_CUDA_VERSION}-plugin --dest $(pwd)/dist/
87+
else
88+
echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}"
89+
exit 1
90+
fi
91+
fi
92+
- name: Skip the test run if the wheel artifacts were not downloaded successfully
93+
shell: bash
94+
if: steps.download-wheel-artifacts.outcome == 'failure'
95+
run: |
96+
echo "Failed to download wheel artifacts. Please check if the wheels were"
97+
echo "built successfully by the artifact build jobs and are available in the GCS bucket if
98+
echo "downloading from GCS."
99+
echo "Skipping the test run."
100+
exit 1

.github/workflows/asan.yaml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,17 @@ on:
1313
- main
1414
paths:
1515
- '**/workflows/asan.yaml'
16+
permissions: {}
17+
18+
env:
19+
UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"
20+
PIP_INDEX_URL: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"
1621

1722
jobs:
1823
asan:
1924
# Don't execute in fork due to runner type
2025
if: github.repository == 'jax-ml/jax'
21-
runs-on: linux-x86-n2-64
26+
runs-on: linux-x86-n4-64
2227
container:
2328
image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04
2429
strategy:
@@ -38,14 +43,16 @@ jobs:
3843
zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
3944
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
4045
libffi-dev liblzma-dev
41-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
46+
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
4247
with:
4348
path: jax
44-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
49+
persist-credentials: false
50+
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
4551
with:
4652
repository: python/cpython
4753
path: cpython
4854
ref: v3.13.0
55+
persist-credentials: false
4956
- name: Build CPython with ASAN enabled
5057
env:
5158
ASAN_OPTIONS: detect_leaks=0
@@ -71,9 +78,9 @@ jobs:
7178
source ${GITHUB_WORKSPACE}/venv/bin/activate
7279
cd jax
7380
python build/build.py build --wheels=jaxlib --verbose \
81+
--bazel_options=--config=use_tar_archive_files \
7482
--bazel_options=--color=yes \
75-
--bazel_options=--copt=-fsanitize=address \
76-
--clang_path=/usr/bin/clang-18
83+
--bazel_options=--copt=-fsanitize=address
7784
uv pip install dist/jaxlib-*.whl \
7885
-e .
7986
- name: Run tests

.github/workflows/bazel_cpu.yml

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# CI - Bazel CPU tests (RBE)
2+
#
3+
# This workflow runs the Bazel CPU tests with wheel dependencies. It can only be triggered by
4+
# other workflows via `workflow_call`. It is used by the `CI - Wheel Tests (Continuous)` and
5+
# `CI - Wheel Tests (Nightly/Release)` workflows to run the Bazel CPU tests.
6+
#
7+
# It consists of the following job:
8+
# run-tests:
9+
# - Downloads the jax, jaxlib from a GCS bucket if build_jaxlib is false. Otherwise,
10+
# the artifacts are built from source.
11+
# - Executes the `run_bazel_test_cpu_rbe.sh` script, which performs the following actions:
12+
# - `build_jaxlib=wheel`: Runs the Bazel CPU tests with py_import dependencies.
13+
# - `build_jaxlib=false`: Runs the Bazel CPU tests with downloaded wheel dependencies.
14+
# - `build_jaxlib=true`: Runs the Bazel CPU tests with individual Bazel target dependencies.
15+
16+
name: CI - Bazel CPU tests with wheel dependencies (RBE)
17+
permissions: {}
18+
on:
19+
workflow_call:
20+
inputs:
21+
runner:
22+
description: "Which runner should the workflow run on?"
23+
type: string
24+
default: "linux-x86-n4-16"
25+
python:
26+
description: "Which python version to test?"
27+
type: string
28+
default: "3.12"
29+
enable-x64:
30+
description: "Should x64 mode be enabled?"
31+
type: string
32+
default: "0"
33+
halt-for-connection:
34+
description: 'Should this workflow run wait for a remote connection?'
35+
type: string
36+
default: 'no'
37+
build_jaxlib:
38+
description: 'Should jaxlib be built from source?'
39+
required: true
40+
type: string
41+
build_jax:
42+
description: 'Should jax be built from source?'
43+
required: true
44+
type: string
45+
gcs_download_uri:
46+
description: "GCS location prefix from where the artifacts should be downloaded"
47+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
48+
type: string
49+
50+
jobs:
51+
run-tests:
52+
defaults:
53+
run:
54+
# Explicitly set the shell to bash
55+
shell: bash
56+
runs-on: ${{ inputs.runner }}
57+
container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') ||
58+
(contains(inputs.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') ||
59+
(contains(inputs.runner, 'windows-x86') && null) }}
60+
env:
61+
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
62+
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
63+
JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
64+
JAXCI_BUILD_JAX: ${{ inputs.build_jax }}
65+
66+
# Begin Presubmit Naming Check - name modification requires internal check to be updated
67+
name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') ||
68+
(contains(inputs.runner, 'linux-arm64') && 'linux arm64') ||
69+
(contains(inputs.runner, 'windows-x86') && 'windows x86') }}, Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}"
70+
# End Presubmit Naming Check github-cpu-presubmits
71+
72+
steps:
73+
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
74+
with:
75+
persist-credentials: false
76+
- name: Download JAX CPU wheels
77+
if: inputs.build_jaxlib == 'false'
78+
uses: ./.github/actions/download-jax-cpu-wheels
79+
with:
80+
runner: ${{ inputs.runner }}
81+
python: ${{ inputs.python }}
82+
gcs_download_uri: ${{ inputs.gcs_download_uri }}
83+
# Halt for testing
84+
- name: Wait For Connection
85+
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
86+
with:
87+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
88+
- name: "Bazel CPU tests with build_jaxlib=${{ format('{0}', inputs.build_jaxlib) }}"
89+
timeout-minutes: 60
90+
run: ./ci/run_bazel_test_cpu_rbe.sh

0 commit comments

Comments
 (0)