Skip to content

Add extra cuBLAS/cuDNN version checks #28931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions jax_plugins/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _version_check(name: str,
get_version,
get_build_version,
scale_for_comparison: int = 1,
min_supported_version: int = 0):
min_supported_version: int = 0) -> int | None:
"""Checks the runtime CUDA component version against the JAX one.

Args:
Expand All @@ -125,6 +125,8 @@ def _version_check(name: str,
min_supported_version: An absolute minimum version required. Must be
passed without rounding down.

Returns: the runtime version, or None if the component is not found.

Raises:
RuntimeError: If the component is not found, or is of unsupported version,
and if raising the error is not deferred till later.
Expand Down Expand Up @@ -162,12 +164,13 @@ def _version_check(name: str,
"version": version,
"minimum_supported": min_supported_version}
results.append(record)
return version

_version_check("CUDA", cuda_versions.cuda_runtime_get_version,
cuda_versions.cuda_runtime_build_version,
scale_for_comparison=10,
min_supported_version=12010)
_version_check(
cudnn_version = _version_check(
"cuDNN",
cuda_versions.cudnn_get_version,
cuda_versions.cudnn_build_version,
Expand All @@ -191,7 +194,7 @@ def _version_check(name: str,
_version_check("cuPTI", cuda_versions.cupti_get_version,
cuda_versions.cupti_build_version,
min_supported_version=18)
_version_check("cuBLAS", cuda_versions.cublas_get_version,
cublas_version = _version_check("cuBLAS", cuda_versions.cublas_get_version,
cuda_versions.cublas_build_version,
# Ignore patch versions.
scale_for_comparison=100,
Expand All @@ -202,6 +205,35 @@ def _version_check(name: str,
scale_for_comparison=100,
min_supported_version=12100)

# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-10-1
if (cudnn_version is not None and cudnn_version == 91000
and cuda_versions.cudnn_build_version() != 91000):
msg = ("cuDNN 9.10.0 had a binary backward-compatibility issue due to reordered enum "
f"values affecting block-scale datatypes. Found runtime version {cudnn_version} "
f"and build version {cuda_versions.cudnn_build_version()}. Please upgrade to "
"9.10.1 or above.")
if raise_on_first_error:
raise RuntimeError(msg)
else:
results.append({"installed": True, "msg": msg, "passed": False})
# xb.local_device_count() cannot safely be called at this point
if xb.CUDA_VISIBLE_DEVICES.value == "all":
local_device_count = cuda_versions.cuda_device_count()
else:
local_device_count = len(xb.CUDA_VISIBLE_DEVICES.value.split(","))
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-10-0
if (cudnn_version is not None and cudnn_version < 91001
and cublas_version is not None and cublas_version >= 120900
and local_device_count > 1):
msg = (f"cuDNN < 9.10.0 ({cudnn_version} found) had an issue that caused some multi-GPU "
"matmuls, in which the same finalized execution plan is used across different "
f"GPUs, to be functionally incorrect when run with cublasLt >= 12.9 ({cublas_version} "
"found). Please upgrade to 9.10.1 or above.")
if raise_on_first_error:
raise RuntimeError(msg)
else:
results.append({"installed": True, "msg": msg, "passed": False})

errors = []
debug_results = []
for result in results:
Expand Down