diff --git a/build/build.py b/build/build.py
index d059251552eb..9fddc62d925a 100755
--- a/build/build.py
+++ b/build/build.py
@@ -80,8 +80,6 @@
   "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
 }

-_JAX_CUDA_VERSION = "12"
-
 def add_global_arguments(parser: argparse.ArgumentParser):
   """Adds all the global arguments that applies to all the CLI subcommands."""
   parser.add_argument(
@@ -642,6 +640,13 @@ async def main():
       # https://peps.python.org/pep-0440/
       wheel_git_hash = option.split("=")[-1].lstrip('0')[:9]

+  if args.cuda_version:
+    cuda_major_version = args.cuda_version.split(".")[0]
+  else:
+    cuda_major_version = args.cuda_major_version
+  if "cuda" in args.wheels:
+    wheel_build_command_base.append(f"--//jax:cuda_major_version={cuda_major_version}")
+
   with open(".jax_configure.bazelrc", "w") as f:
     jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule)
     if not jax_configure_options:
@@ -709,10 +714,6 @@ async def main():

       if "cuda" in wheel:
         wheel_build_command.append("--enable-cuda=True")
-        if args.cuda_version:
-          cuda_major_version = args.cuda_version.split(".")[0]
-        else:
-          cuda_major_version = args.cuda_major_version
         wheel_build_command.append(f"--platform_version={cuda_major_version}")

       if "rocm" in wheel:
@@ -738,7 +739,7 @@ async def main():
       else:
         bazel_dir = jaxlib_and_plugins_bazel_dir
       if "cuda" in wheel:
-        wheel_dir = wheel.replace("cuda", f"cuda{_JAX_CUDA_VERSION}").replace(
+        wheel_dir = wheel.replace("cuda", f"cuda{cuda_major_version}").replace(
            "-", "_"
        )
       else:
diff --git a/build/requirements.in b/build/requirements.in
index c1be7a250bff..35251bb31d0a 100644
--- a/build/requirements.in
+++ b/build/requirements.in
@@ -20,10 +20,13 @@ jaxlib==0.6.1

 # The with-cuda extra also includes NVIDIA's pip packages.
 jax-cuda12-plugin[with-cuda]==0.6.1 ; sys_platform == "linux"
+jax-cuda13-plugin ; sys_platform == "does_not_exist"
 jax-cuda12-pjrt==0.6.1 ; sys_platform == "linux"
+jax-cuda13-pjrt ; sys_platform == "does_not_exist"

 # TPU dependencies
 libtpu ; sys_platform == "linux" and platform_machine == "x86_64"

 # For Mosaic GPU collectives
 nvidia-nvshmem-cu12>=3.2.5 ; sys_platform == "linux"
+nvidia-nvshmem-cu13
diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt
index 04c6990da696..d397f20e1a68 100644
--- a/build/requirements_lock_3_12.txt
+++ b/build/requirements_lock_3_12.txt
@@ -492,6 +492,10 @@ nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \
     # via
     #   -r build/requirements.in
     #   jax-cuda12-plugin
+nvidia-nvshmem-cu13==0.0.0a0 \
+    --hash=sha256:84d265d7b97dae6ee74139f8f7e37fc65a63e4ebb7287b987a4dca0c0625673d \
+    --hash=sha256:b6900e44e6be1e0e7be6059c5b7a397fb3cb84914784571ab7e20a35bb2b140d
+    # via -r build/requirements.in
 opt-einsum==3.3.0 \
     --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
     --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549
diff --git a/jax/BUILD b/jax/BUILD
index 396e6fdf6ed4..8472042640ad 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -84,6 +84,30 @@ config_setting(
     },
 )

+# Which major version of CUDA to target
+string_flag(
+    name = "cuda_major_version",
+    build_setting_default = "12",
+    values = [
+        "12",
+        "13",
+    ],
+)
+
+config_setting(
+    name = "config_cuda_major_version_12",
+    flag_values = {
+        ":cuda_major_version": "12",
+    }
+)
+
+config_setting(
+    name = "config_cuda_major_version_13",
+    flag_values = {
+        ":cuda_major_version": "13",
+    }
+)
+
 # The flag controls whether jax should be built by Bazel.
 # If ":build_jax=true", then jax will be built.
 # If ":build_jax=false", then jax is not built. It is assumed that the pre-built jax wheel
diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py
index 8de05061ec99..f6f5c4ea9468 100644
--- a/jax/_src/lib/__init__.py
+++ b/jax/_src/lib/__init__.py
@@ -17,6 +17,7 @@

 from __future__ import annotations

+import importlib
 import gc
 import os
 import pathlib
@@ -113,13 +114,15 @@ def _xla_gc_callback(*args):
   xla_client._xla.collect_garbage()
 gc.callbacks.append(_xla_gc_callback)

-try:
-  import jaxlib.cuda._versions as cuda_versions  # pytype: disable=import-error  # noqa: F401
-except ImportError:
+for pkg_name in ['jax_cuda13_plugin', 'jax_cuda12_plugin', 'jaxlib.cuda']:
   try:
-    import jax_cuda12_plugin._versions as cuda_versions  # pytype: disable=import-error  # noqa: F401
+    cuda_versions = importlib.import_module(
+        f'{pkg_name}._versions'
+    )
   except ImportError:
     cuda_versions = None
+  else:
+    break

 import jaxlib.gpu_solver as gpu_solver  # pytype: disable=import-error  # noqa: F401
 import jaxlib.gpu_sparse as gpu_sparse  # pytype: disable=import-error  # noqa: F401
diff --git a/jax/_src/lib/mosaic_gpu.py b/jax/_src/lib/mosaic_gpu.py
index 494112093029..37c190a409c5 100644
--- a/jax/_src/lib/mosaic_gpu.py
+++ b/jax/_src/lib/mosaic_gpu.py
@@ -18,6 +18,9 @@
   try:
     from jaxlib.mosaic.gpu import _mosaic_gpu_ext  # pytype: disable=import-error
   except ImportError:
-    from jax_cuda12_plugin import _mosaic_gpu_ext  # pytype: disable=import-error
+    try:
+      from jax_cuda12_plugin import _mosaic_gpu_ext  # pytype: disable=import-error
+    except ImportError:
+      from jax_cuda13_plugin import _mosaic_gpu_ext  # pytype: disable=import-error
 except ImportError as e:
   raise ModuleNotFoundError("Failed to import the Mosaic GPU bindings") from e
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index ad2b3ad6aa75..915a0b7232c4 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -73,7 +73,7 @@

 export = set_module('jax.numpy')

-for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
+for pkg_name in ['jax_cuda13_plugin', 'jax_cuda12_plugin', 'jaxlib.cuda']:
   try:
     cuda_plugin_extension = importlib.import_module(
         f'{pkg_name}.cuda_plugin_extension'
diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index 9464bb587c71..ad72f3e2c885 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -784,7 +784,10 @@ def as_torch_gpu_kernel(

   # Get our hands on the compilation and unload functions
   try:
-    import jax_plugins.xla_cuda12 as cuda_plugin
+    try:
+      import jax_plugins.xla_cuda13 as cuda_plugin
+    except ImportError:
+      import jax_plugins.xla_cuda12 as cuda_plugin
   except ImportError:
     raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds "
                        "that use backend plugins")
diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py
index 02bcbcf16dbc..59119a23df41 100644
--- a/jax_plugins/cuda/__init__.py
+++ b/jax_plugins/cuda/__init__.py
@@ -26,7 +26,7 @@

 # cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
 # preinstalled jax cuda plugin packages.
-for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
+for pkg_name in ['jax_cuda13_plugin', 'jax_cuda12_plugin', 'jaxlib.cuda']:
   try:
     cuda_plugin_extension = importlib.import_module(
         f'{pkg_name}.cuda_plugin_extension'
diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD
index eabb3157ecca..744f1e561e9d 100644
--- a/jaxlib/cuda/BUILD
+++ b/jaxlib/cuda/BUILD
@@ -682,3 +682,11 @@ py_library(
         ":cuda_plugin_extension",
     ]),
 )
+
+py_library(
+    name = "nvidia_nvshmem",
+    deps = select({
+        "//jax:config_cuda_major_version_12": ["@pypi//nvidia_nvshmem_cu12"],
+        "//jax:config_cuda_major_version_13": ["@pypi//nvidia_nvshmem_cu13"],
+    })
+)
diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index 678d92bc434a..1964f06c025d 100644
--- a/jaxlib/jax.bzl
+++ b/jaxlib/jax.bzl
@@ -180,6 +180,12 @@ def _cpu_test_deps():
         "//jax:config_build_jaxlib_wheel": ["//jaxlib/tools:jaxlib_py_import"],
     })

+def cuda_major_version():
+    return select({
+        "//jax:config_cuda_major_version_12": "12",
+        "//jax:config_cuda_major_version_13": "13",
+    })
+
 def _gpu_test_deps():
     """Returns the additional dependencies needed for a GPU test."""
     return select({
@@ -187,17 +193,17 @@ def _gpu_test_deps():
         "//jax:config_build_jaxlib_true": [
             "//jaxlib/cuda:gpu_only_test_deps",
             "//jaxlib/rocm:gpu_only_test_deps",
             "//jax_plugins:gpu_plugin_only_test_deps",
-            "@pypi//nvidia_nvshmem_cu12",
+            "//jaxlib/cuda:nvidia_nvshmem",
         ],
         "//jax:config_build_jaxlib_false": [
             "//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps",
             "//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps",
-            "@pypi//nvidia_nvshmem_cu12",
+            "//jaxlib/cuda:nvidia_nvshmem",
         ],
         "//jax:config_build_jaxlib_wheel": [
             "//jaxlib/tools:jax_cuda_plugin_py_import",
             "//jaxlib/tools:jax_cuda_pjrt_py_import",
-            "@pypi//nvidia_nvshmem_cu12",
+            "//jaxlib/cuda:nvidia_nvshmem",
         ],
     })
diff --git a/jaxlib/plugin_support.py b/jaxlib/plugin_support.py
index ea24dc181be0..1d629d17e64a 100644
--- a/jaxlib/plugin_support.py
+++ b/jaxlib/plugin_support.py
@@ -21,9 +21,9 @@

 from .version import __version__ as jaxlib_version

-_PLUGIN_MODULE_NAME = {
-    "cuda": "jax_cuda12_plugin",
-    "rocm": "jax_rocm60_plugin",
+_PLUGIN_MODULE_NAMES = {
+    "cuda": ["jax_cuda13_plugin", "jax_cuda12_plugin"],
+    "rocm": ["jax_rocm60_plugin"],
 }


@@ -44,10 +44,10 @@ def import_from_plugin(
     The imported submodule, or None if the plugin is not installed or if the
     versions are incompatible.
   """
-  if plugin_name not in _PLUGIN_MODULE_NAME:
+  if plugin_name not in _PLUGIN_MODULE_NAMES:
     raise ValueError(f"Unknown plugin: {plugin_name}")
   return maybe_import_plugin_submodule(
-      [f".{plugin_name}", _PLUGIN_MODULE_NAME[plugin_name]],
+      [f".{plugin_name}"] + _PLUGIN_MODULE_NAMES[plugin_name],
       submodule_name,
       check_version=check_version,
   )
diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel
index 7f8a85a5d9ab..6b080e42d450 100644
--- a/jaxlib/tools/BUILD.bazel
+++ b/jaxlib/tools/BUILD.bazel
@@ -35,6 +35,7 @@ load(
     "pytype_strict_library",
     "pytype_test",
     "wheel_sources",
+    "cuda_major_version",
 )

 licenses(["notice"])  # Apache 2
@@ -330,22 +331,26 @@ jax_wheel(
     name = "jax_cuda_plugin_wheel",
     enable_cuda = True,
     no_abi = False,
-    # TODO(b/371217563) May use hermetic cuda version here.
-    platform_version = "12",
+    platform_version = cuda_major_version(),
     source_files = [":jax_plugin_sources"],
     wheel_binary = ":build_gpu_kernels_wheel_tool",
-    wheel_name = "jax_cuda12_plugin",
+    wheel_name = select({
+        "//jax:config_cuda_major_version_12": "jax_cuda12_plugin",
+        "//jax:config_cuda_major_version_13": "jax_cuda13_plugin",
+    }),
 )

 jax_wheel(
     name = "jax_cuda_plugin_wheel_editable",
     editable = True,
     enable_cuda = True,
-    # TODO(b/371217563) May use hermetic cuda version here.
-    platform_version = "12",
+    platform_version = cuda_major_version(),
     source_files = [":jax_plugin_sources"],
     wheel_binary = ":build_gpu_kernels_wheel_tool",
-    wheel_name = "jax_cuda12_plugin",
+    wheel_name = select({
+        "//jax:config_cuda_major_version_12": "jax_cuda12_plugin",
+        "//jax:config_cuda_major_version_13": "jax_cuda13_plugin",
+    }),
 )

 jax_wheel(
@@ -417,22 +422,26 @@ jax_wheel(
     name = "jax_cuda_pjrt_wheel",
     enable_cuda = True,
     no_abi = True,
-    # TODO(b/371217563) May use hermetic cuda version here.
-    platform_version = "12",
+    platform_version = cuda_major_version(),
     source_files = [":jax_pjrt_sources"],
     wheel_binary = ":build_gpu_plugin_wheel_tool",
-    wheel_name = "jax_cuda12_pjrt",
+    wheel_name = select({
+        "//jax:config_cuda_major_version_12": "jax_cuda12_pjrt",
+        "//jax:config_cuda_major_version_13": "jax_cuda13_pjrt",
+    }),
 )

 jax_wheel(
     name = "jax_cuda_pjrt_wheel_editable",
     editable = True,
     enable_cuda = True,
-    # TODO(b/371217563) May use hermetic cuda version here.
-    platform_version = "12",
+    platform_version = cuda_major_version(),
     source_files = [":jax_pjrt_sources"],
     wheel_binary = ":build_gpu_plugin_wheel_tool",
-    wheel_name = "jax_cuda12_pjrt",
+    wheel_name = select({
+        "//jax:config_cuda_major_version_12": "jax_cuda12_pjrt",
+        "//jax:config_cuda_major_version_13": "jax_cuda13_pjrt",
+    }),
 )

 jax_wheel(
@@ -458,18 +467,32 @@
 # Py_import targets.
 filegroup(
     name = "nvidia_wheel_deps",
-    srcs = [
-        "@pypi_nvidia_cublas_cu12//:whl",
-        "@pypi_nvidia_cuda_cupti_cu12//:whl",
-        "@pypi_nvidia_cuda_nvcc_cu12//:whl",
-        "@pypi_nvidia_cuda_runtime_cu12//:whl",
-        "@pypi_nvidia_cudnn_cu12//:whl",
-        "@pypi_nvidia_cufft_cu12//:whl",
-        "@pypi_nvidia_cusolver_cu12//:whl",
-        "@pypi_nvidia_cusparse_cu12//:whl",
-        "@pypi_nvidia_nccl_cu12//:whl",
-        "@pypi_nvidia_nvjitlink_cu12//:whl",
-    ],
+    srcs = select({
+        "//jax:config_cuda_major_version_12": [
+            "@pypi_nvidia_cublas_cu12//:whl",
+            "@pypi_nvidia_cuda_cupti_cu12//:whl",
+            "@pypi_nvidia_cuda_nvcc_cu12//:whl",
+            "@pypi_nvidia_cuda_runtime_cu12//:whl",
+            "@pypi_nvidia_cudnn_cu12//:whl",
+            "@pypi_nvidia_cufft_cu12//:whl",
+            "@pypi_nvidia_cusolver_cu12//:whl",
+            "@pypi_nvidia_cusparse_cu12//:whl",
+            "@pypi_nvidia_nccl_cu12//:whl",
+            "@pypi_nvidia_nvjitlink_cu12//:whl",
+        ],
+        "//jax:config_cuda_major_version_13": [
+            "@pypi_nvidia_cublas_cu13//:whl",
+            "@pypi_nvidia_cuda_cupti_cu13//:whl",
+            "@pypi_nvidia_cuda_nvcc_cu13//:whl",
+            "@pypi_nvidia_cuda_runtime_cu13//:whl",
+            "@pypi_nvidia_cudnn_cu13//:whl",
+            "@pypi_nvidia_cufft_cu13//:whl",
+            "@pypi_nvidia_cusolver_cu13//:whl",
+            "@pypi_nvidia_cusparse_cu13//:whl",
+            "@pypi_nvidia_nccl_cu13//:whl",
+            "@pypi_nvidia_nvjitlink_cu13//:whl",
+        ],
+    }),
 )

 # The flag configures whether to add the pypi NVIDIA CUDA deps to py_import.
@@ -506,13 +529,19 @@ py_import(
 # The targets below are used for GPU tests with `--//jax:build_jaxlib=false`.
 py_import(
     name = "pypi_jax_cuda_plugin_with_cuda_deps",
-    wheel = "@pypi_jax_cuda12_plugin//:whl",
+    wheel = select({
+        "//jax:config_cuda_major_version_12": "@pypi_jax_cuda12_plugin//:whl",
+        "//jax:config_cuda_major_version_13": "@pypi_jax_cuda13_plugin//:whl",
+    }),
     wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]),
 )

 py_import(
     name = "pypi_jax_cuda_pjrt_with_cuda_deps",
-    wheel = "@pypi_jax_cuda12_pjrt//:whl",
+    wheel = select({
+        "//jax:config_cuda_major_version_12": "@pypi_jax_cuda12_pjrt//:whl",
+        "//jax:config_cuda_major_version_13": "@pypi_jax_cuda13_pjrt//:whl",
+    }),
     wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]),
 )

diff --git a/pyproject.toml b/pyproject.toml
index d48351197b54..4b9e81e19851 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ module = [
     "jax.experimental.jax2tf.tests.back_compat_testdata",
     "jax.experimental.jax2tf.tests.flax_models",
     "jax_cuda12_plugin.*",
+    "jax_cuda13_plugin.*",
     "jaxlib.cpu_feature_guard",
     "jaxlib.cuda.*",
     "jaxlib.mlir.*",
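Usage sketch (an assumption, not part of the patch): with the new --//jax:cuda_major_version flag defined in jax/BUILD and the plumbing in build/build.py above, the CUDA major version would be selected through the pre-existing --cuda_version / --cuda_major_version options of the build CLI and the existing jax-cuda-plugin / jax-cuda-pjrt wheel names, for example:

    # Assumed invocation: build.py derives the major version "13" from --cuda_version
    # and, because CUDA wheels are requested, appends --//jax:cuda_major_version=13
    # to the generated .jax_configure.bazelrc options (see the @@ -642 hunk above).
    python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt --cuda_version=13.0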