Skip to content

Parametrize build system on CUDA major version #28968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@
"jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
}

_JAX_CUDA_VERSION = "12"

def add_global_arguments(parser: argparse.ArgumentParser):
"""Adds all the global arguments that apply to all the CLI subcommands."""
parser.add_argument(
Expand Down Expand Up @@ -642,6 +640,13 @@ async def main():
# https://peps.python.org/pep-0440/
wheel_git_hash = option.split("=")[-1].lstrip('0')[:9]

if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
if "cuda" in args.wheels:
wheel_build_command_base.append(f"--//jax:cuda_major_version={cuda_major_version}")

with open(".jax_configure.bazelrc", "w") as f:
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule)
if not jax_configure_options:
Expand Down Expand Up @@ -709,10 +714,6 @@ async def main():

if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")

if "rocm" in wheel:
Expand All @@ -738,7 +739,7 @@ async def main():
else:
bazel_dir = jaxlib_and_plugins_bazel_dir
if "cuda" in wheel:
wheel_dir = wheel.replace("cuda", f"cuda{_JAX_CUDA_VERSION}").replace(
wheel_dir = wheel.replace("cuda", f"cuda{cuda_major_version}").replace(
"-", "_"
)
else:
Expand Down
3 changes: 3 additions & 0 deletions build/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ jaxlib

# The with-cuda extra also includes NVIDIA's pip packages.
jax-cuda12-plugin[with-cuda] ; sys_platform == "linux"
jax-cuda13-plugin ; sys_platform == "does_not_exist"
jax-cuda12-pjrt ; sys_platform == "linux"
jax-cuda13-pjrt ; sys_platform == "does_not_exist"

# TPU dependencies
libtpu ; sys_platform == "linux" and platform_machine == "x86_64"

# For Mosaic GPU collectives
nvidia-nvshmem-cu12>=3.2.5 ; sys_platform == "linux"
nvidia-nvshmem-cu13
4 changes: 4 additions & 0 deletions build/requirements_lock_3_12.txt
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,10 @@ nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \
--hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \
--hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00
# via -r build/requirements.in
nvidia-nvshmem-cu13==0.0.0a0 \
--hash=sha256:84d265d7b97dae6ee74139f8f7e37fc65a63e4ebb7287b987a4dca0c0625673d \
--hash=sha256:b6900e44e6be1e0e7be6059c5b7a397fb3cb84914784571ab7e20a35bb2b140d
# via -r build/requirements.in
opt-einsum==3.3.0 \
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549
Expand Down
24 changes: 24 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,30 @@ config_setting(
},
)

# Major version of CUDA to target. Set on the command line with
# --//jax:cuda_major_version=<N>; defaults to CUDA 12.
string_flag(
    name = "cuda_major_version",
    build_setting_default = "12",
    values = ["12", "13"],
)

# Matches when the build targets CUDA 12; intended for use as a select() key.
config_setting(
    name = "config_cuda_major_version_12",
    flag_values = {
        ":cuda_major_version": "12",
    },
)

# Matches when the build targets CUDA 13; intended for use as a select() key.
config_setting(
    name = "config_cuda_major_version_13",
    flag_values = {
        ":cuda_major_version": "13",
    },
)

# The flag controls whether jax should be built by Bazel.
# If ":build_jax=true", then jax will be built.
# If ":build_jax=false", then jax is not built. It is assumed that the pre-built jax wheel
Expand Down
11 changes: 7 additions & 4 deletions jax/_src/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import importlib
import gc
import os
import pathlib
Expand Down Expand Up @@ -113,13 +114,15 @@ def _xla_gc_callback(*args):
xla_client._xla.collect_garbage()
gc.callbacks.append(_xla_gc_callback)

try:
import jaxlib.cuda._versions as cuda_versions # pytype: disable=import-error # noqa: F401
except ImportError:
for pkg_name in ['jax_cuda13_plugin', 'jax_cuda12_plugin', 'jaxlib.cuda']:
try:
import jax_cuda12_plugin._versions as cuda_versions # pytype: disable=import-error # noqa: F401
cuda_versions = importlib.import_module(
f'{pkg_name}._versions'
)
except ImportError:
cuda_versions = None
else:
break

import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error # noqa: F401
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/lib/mosaic_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
try:
from jaxlib.mosaic.gpu import _mosaic_gpu_ext # pytype: disable=import-error
except ImportError:
from jax_cuda12_plugin import _mosaic_gpu_ext # pytype: disable=import-error
try:
from jax_cuda12_plugin import _mosaic_gpu_ext # pytype: disable=import-error
except ImportError:
from jax_cuda13_plugin import _mosaic_gpu_ext # pytype: disable=import-error
except ImportError as e:
raise ModuleNotFoundError("Failed to import the Mosaic GPU bindings") from e
2 changes: 1 addition & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@

export = set_module('jax.numpy')

for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
for pkg_name in ['jax_cuda13_plugin', 'jax_cuda12_plugin', 'jaxlib.cuda']:
try:
cuda_plugin_extension = importlib.import_module(
f'{pkg_name}.cuda_plugin_extension'
Expand Down
5 changes: 4 additions & 1 deletion jax/experimental/mosaic/gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,10 @@ def as_torch_gpu_kernel(

# Get our hands on the compilation and unload functions
try:
import jax_plugins.xla_cuda12 as cuda_plugin
try:
import jax_plugins.xla_cuda13 as cuda_plugin
except ImportError:
import jax_plugins.xla_cuda12 as cuda_plugin
except ImportError:
raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds "
"that use backend plugins")
Expand Down
2 changes: 1 addition & 1 deletion jax_plugins/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# cuda_plugin_extension is located inside jaxlib. `jaxlib` is for testing
# without preinstalled jax cuda plugin packages.
for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
for pkg_name in ['jax_cuda13_plugin', 'jax_cuda12_plugin', 'jaxlib.cuda']:
try:
cuda_plugin_extension = importlib.import_module(
f'{pkg_name}.cuda_plugin_extension'
Expand Down
8 changes: 8 additions & 0 deletions jaxlib/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -682,3 +682,11 @@ py_library(
":cuda_plugin_extension",
]),
)

# Resolves to the nvshmem pip package that matches the CUDA major version
# chosen through //jax:cuda_major_version.
py_library(
    name = "nvidia_nvshmem",
    deps = select({
        "//jax:config_cuda_major_version_12": ["@pypi//nvidia_nvshmem_cu12"],
        "//jax:config_cuda_major_version_13": ["@pypi//nvidia_nvshmem_cu13"],
    }),
)
12 changes: 9 additions & 3 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -180,24 +180,30 @@ def _cpu_test_deps():
"//jax:config_build_jaxlib_wheel": ["//jaxlib/tools:jaxlib_py_import"],
})

def cuda_major_version():
    """Returns a select() that yields the targeted CUDA major version as a string."""

    # One entry per config_setting declared for the //jax:cuda_major_version flag.
    version_by_config = {
        "//jax:config_cuda_major_version_12": "12",
        "//jax:config_cuda_major_version_13": "13",
    }
    return select(version_by_config)

def _gpu_test_deps():
"""Returns the additional dependencies needed for a GPU test."""
return select({
"//jax:config_build_jaxlib_true": [
"//jaxlib/cuda:gpu_only_test_deps",
"//jaxlib/rocm:gpu_only_test_deps",
"//jax_plugins:gpu_plugin_only_test_deps",
"@pypi//nvidia_nvshmem_cu12",
"//jaxlib/cuda:nvidia_nvshmem",
],
"//jax:config_build_jaxlib_false": [
"//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps",
"//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps",
"@pypi//nvidia_nvshmem_cu12",
"//jaxlib/cuda:nvidia_nvshmem",
],
"//jax:config_build_jaxlib_wheel": [
"//jaxlib/tools:jax_cuda_plugin_py_import",
"//jaxlib/tools:jax_cuda_pjrt_py_import",
"@pypi//nvidia_nvshmem_cu12",
"//jaxlib/cuda:nvidia_nvshmem",
],
})

Expand Down
10 changes: 5 additions & 5 deletions jaxlib/plugin_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from .version import __version__ as jaxlib_version


_PLUGIN_MODULE_NAME = {
"cuda": "jax_cuda12_plugin",
"rocm": "jax_rocm60_plugin",
_PLUGIN_MODULE_NAMES = {
"cuda": ["jax_cuda13_plugin", "jax_cuda12_plugin"],
"rocm": ["jax_rocm60_plugin"],
}


Expand All @@ -44,10 +44,10 @@ def import_from_plugin(
The imported submodule, or None if the plugin is not installed or if the
versions are incompatible.
"""
if plugin_name not in _PLUGIN_MODULE_NAME:
if plugin_name not in _PLUGIN_MODULE_NAMES:
raise ValueError(f"Unknown plugin: {plugin_name}")
return maybe_import_plugin_submodule(
[f".{plugin_name}", _PLUGIN_MODULE_NAME[plugin_name]],
[f".{plugin_name}"] + _PLUGIN_MODULE_NAMES[plugin_name],
submodule_name,
check_version=check_version,
)
Expand Down
81 changes: 55 additions & 26 deletions jaxlib/tools/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ load(
"pytype_strict_library",
"pytype_test",
"wheel_sources",
"cuda_major_version",
)

licenses(["notice"]) # Apache 2
Expand Down Expand Up @@ -330,22 +331,26 @@ jax_wheel(
name = "jax_cuda_plugin_wheel",
enable_cuda = True,
no_abi = False,
# TODO(b/371217563) May use hermetic cuda version here.
platform_version = "12",
platform_version = cuda_major_version(),
source_files = [":jax_plugin_sources"],
wheel_binary = ":build_gpu_kernels_wheel_tool",
wheel_name = "jax_cuda12_plugin",
wheel_name = select({
"//jax:config_cuda_major_version_12": "jax_cuda12_plugin",
"//jax:config_cuda_major_version_13": "jax_cuda13_plugin",
}),
)

jax_wheel(
name = "jax_cuda_plugin_wheel_editable",
editable = True,
enable_cuda = True,
# TODO(b/371217563) May use hermetic cuda version here.
platform_version = "12",
platform_version = cuda_major_version(),
source_files = [":jax_plugin_sources"],
wheel_binary = ":build_gpu_kernels_wheel_tool",
wheel_name = "jax_cuda12_plugin",
wheel_name = select({
"//jax:config_cuda_major_version_12": "jax_cuda12_plugin",
"//jax:config_cuda_major_version_13": "jax_cuda13_plugin",
}),
)

jax_wheel(
Expand Down Expand Up @@ -417,22 +422,26 @@ jax_wheel(
name = "jax_cuda_pjrt_wheel",
enable_cuda = True,
no_abi = True,
# TODO(b/371217563) May use hermetic cuda version here.
platform_version = "12",
platform_version = cuda_major_version(),
source_files = [":jax_pjrt_sources"],
wheel_binary = ":build_gpu_plugin_wheel_tool",
wheel_name = "jax_cuda12_pjrt",
wheel_name = select({
"//jax:config_cuda_major_version_12": "jax_cuda12_pjrt",
"//jax:config_cuda_major_version_13": "jax_cuda13_pjrt",
}),
)

jax_wheel(
name = "jax_cuda_pjrt_wheel_editable",
editable = True,
enable_cuda = True,
# TODO(b/371217563) May use hermetic cuda version here.
platform_version = "12",
platform_version = cuda_major_version(),
source_files = [":jax_pjrt_sources"],
wheel_binary = ":build_gpu_plugin_wheel_tool",
wheel_name = "jax_cuda12_pjrt",
wheel_name = select({
"//jax:config_cuda_major_version_12": "jax_cuda12_pjrt",
"//jax:config_cuda_major_version_13": "jax_cuda13_pjrt",
}),
)

jax_wheel(
Expand All @@ -458,18 +467,32 @@ jax_wheel(
# Py_import targets.
filegroup(
name = "nvidia_wheel_deps",
srcs = [
"@pypi_nvidia_cublas_cu12//:whl",
"@pypi_nvidia_cuda_cupti_cu12//:whl",
"@pypi_nvidia_cuda_nvcc_cu12//:whl",
"@pypi_nvidia_cuda_runtime_cu12//:whl",
"@pypi_nvidia_cudnn_cu12//:whl",
"@pypi_nvidia_cufft_cu12//:whl",
"@pypi_nvidia_cusolver_cu12//:whl",
"@pypi_nvidia_cusparse_cu12//:whl",
"@pypi_nvidia_nccl_cu12//:whl",
"@pypi_nvidia_nvjitlink_cu12//:whl",
],
srcs = select({
"//jax:config_cuda_major_version_12": [
"@pypi_nvidia_cublas_cu12//:whl",
"@pypi_nvidia_cuda_cupti_cu12//:whl",
"@pypi_nvidia_cuda_nvcc_cu12//:whl",
"@pypi_nvidia_cuda_runtime_cu12//:whl",
"@pypi_nvidia_cudnn_cu12//:whl",
"@pypi_nvidia_cufft_cu12//:whl",
"@pypi_nvidia_cusolver_cu12//:whl",
"@pypi_nvidia_cusparse_cu12//:whl",
"@pypi_nvidia_nccl_cu12//:whl",
"@pypi_nvidia_nvjitlink_cu12//:whl",
],
"//jax:config_cuda_major_version_13": [
"@pypi_nvidia_cublas_cu13//:whl",
"@pypi_nvidia_cuda_cupti_cu13//:whl",
"@pypi_nvidia_cuda_nvcc_cu13//:whl",
"@pypi_nvidia_cuda_runtime_cu13//:whl",
"@pypi_nvidia_cudnn_cu13//:whl",
"@pypi_nvidia_cufft_cu13//:whl",
"@pypi_nvidia_cusolver_cu13//:whl",
"@pypi_nvidia_cusparse_cu13//:whl",
"@pypi_nvidia_nccl_cu13//:whl",
"@pypi_nvidia_nvjitlink_cu13//:whl",
],
}),
)

# The flag configures whether to add the pypi NVIDIA CUDA deps to py_import.
Expand Down Expand Up @@ -506,13 +529,19 @@ py_import(
# The targets below are used for GPU tests with `--//jax:build_jaxlib=false`.
py_import(
name = "pypi_jax_cuda_plugin_with_cuda_deps",
wheel = "@pypi_jax_cuda12_plugin//:whl",
wheel = select({
"//jax:config_cuda_major_version_12": "@pypi_jax_cuda12_plugin//:whl",
"//jax:config_cuda_major_version_13": "@pypi_jax_cuda13_plugin//:whl",
}),
wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]),
)

py_import(
name = "pypi_jax_cuda_pjrt_with_cuda_deps",
wheel = "@pypi_jax_cuda12_pjrt//:whl",
wheel = select({
"//jax:config_cuda_major_version_12": "@pypi_jax_cuda12_pjrt//:whl",
"//jax:config_cuda_major_version_13": "@pypi_jax_cuda13_pjrt//:whl",
}),
wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]),
)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ module = [
"jax.experimental.jax2tf.tests.back_compat_testdata",
"jax.experimental.jax2tf.tests.flax_models",
"jax_cuda12_plugin.*",
"jax_cuda13_plugin.*",
"jaxlib.cpu_feature_guard",
"jaxlib.cuda.*",
"jaxlib.mlir.*",
Expand Down