diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 02bcbcf16dbc..f1e3c55811dc 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -114,7 +114,7 @@ def _version_check(name: str, get_version, get_build_version, scale_for_comparison: int = 1, - min_supported_version: int = 0): + min_supported_version: int = 0) -> int | None: """Checks the runtime CUDA component version against the JAX one. Args: @@ -125,6 +125,8 @@ def _version_check(name: str, min_supported_version: An absolute minimum version required. Must be passed without rounding down. + Returns: the runtime version, or None if the component is not found. + Raises: RuntimeError: If the component is not found, or is of unsupported version, and if raising the error is not deferred till later. @@ -162,12 +164,13 @@ def _version_check(name: str, "version": version, "minimum_supported": min_supported_version} results.append(record) + return version _version_check("CUDA", cuda_versions.cuda_runtime_get_version, cuda_versions.cuda_runtime_build_version, scale_for_comparison=10, min_supported_version=12010) - _version_check( + cudnn_version = _version_check( "cuDNN", cuda_versions.cudnn_get_version, cuda_versions.cudnn_build_version, @@ -191,7 +194,7 @@ def _version_check(name: str, _version_check("cuPTI", cuda_versions.cupti_get_version, cuda_versions.cupti_build_version, min_supported_version=18) - _version_check("cuBLAS", cuda_versions.cublas_get_version, + cublas_version = _version_check("cuBLAS", cuda_versions.cublas_get_version, cuda_versions.cublas_build_version, # Ignore patch versions. scale_for_comparison=100, @@ -202,6 +205,35 @@ def _version_check(name: str, scale_for_comparison=100, min_supported_version=12100) + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-10-1 + if (cudnn_version is not None and cudnn_version == 91000 + and cuda_versions.cudnn_build_version() != 91000): + msg = ("cuDNN 9.10.0 had a binary backward-compatibility issue due to reordered enum " + f"values affecting block-scale datatypes. Found runtime version {cudnn_version} " + f"and build version {cuda_versions.cudnn_build_version()}. Please upgrade to " + "9.10.1 or above.") + if raise_on_first_error: + raise RuntimeError(msg) + else: + results.append({"installed": True, "msg": msg, "passed": False}) + # xb.local_device_count() cannot safely be called at this point + if xb.CUDA_VISIBLE_DEVICES.value == "all": + local_device_count = cuda_versions.cuda_device_count() + else: + local_device_count = len(xb.CUDA_VISIBLE_DEVICES.value.split(",")) + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-10-0 + if (cudnn_version is not None and cudnn_version < 91001 + and cublas_version is not None and cublas_version >= 120900 + and local_device_count > 1): + msg = (f"cuDNN < 9.10.0 ({cudnn_version} found) had an issue that caused some multi-GPU " + "matmuls, in which the same finalized execution plan is used across different " + f"GPUs, to be functionally incorrect when run with cublasLt >= 12.9 ({cublas_version} " + "found). Please upgrade to 9.10.1 or above.") + if raise_on_first_error: + raise RuntimeError(msg) + else: + results.append({"installed": True, "msg": msg, "passed": False}) + errors = [] debug_results = [] for result in results: