diff --git a/src/backends/cuda/cuda_device.cpp b/src/backends/cuda/cuda_device.cpp
index 2c71511a4..99dbd3590 100644
--- a/src/backends/cuda/cuda_device.cpp
+++ b/src/backends/cuda/cuda_device.cpp
@@ -1029,14 +1029,22 @@ static constexpr auto required_cuda_version_major = 11;
 static constexpr auto required_cuda_version_minor = 7;
 static constexpr auto required_cuda_version = required_cuda_version_major * 1000 + required_cuda_version_minor * 10;
 
-static void initialize() {
+// Initializes the CUDA driver API once and returns { driver_version, device_count }.
+// The results are kept in function-local statics: std::call_once skips the lambda
+// on every call after the first, so plain locals would make all subsequent calls
+// return { 0, 0 } (zero devices, driver version 0).
+[[nodiscard]]
+static std::pair<int, int> initialize() {
     // global init
+    static auto driver_version { 0 };
+    static auto device_count { 0 };
+
     static std::once_flag flag;
-    std::call_once(flag, [] {
+    std::call_once(flag, [&] {
         // CUDA
         LUISA_CHECK_CUDA(cuInit(0));
+
         // check driver version
-        auto driver_version = 0;
         LUISA_CHECK_CUDA(cuDriverGetVersion(&driver_version));
         auto driver_version_major = driver_version / 1000;
         auto driver_version_minor = (driver_version % 1000) / 10;
@@ -1048,25 +1056,23 @@ static void initialize() {
         LUISA_VERBOSE("Successfully initialized CUDA "
                       "backend with driver version {}.{}.",
                       driver_version_major, driver_version_minor);
+
+        // check device count
+        LUISA_CHECK_CUDA(cuDeviceGetCount(&device_count));
+        if (device_count == 0) {
+            LUISA_ERROR_WITH_LOCATION("No available device found for CUDA backend.");
+        }
     });
+
+    return { driver_version, device_count };
 }
 
 }// namespace detail
 
 CUDADevice::Handle::Handle(size_t index) noexcept {
-
-    detail::initialize();
-
-    // cuda
-    auto driver_version = 0;
-    LUISA_CHECK_CUDA(cuDriverGetVersion(&driver_version));
+    auto [driver_version, device_count] = detail::initialize();
     _driver_version = driver_version;
 
-    auto device_count = 0;
-    LUISA_CHECK_CUDA(cuDeviceGetCount(&device_count));
-    if (device_count == 0) {
-        LUISA_ERROR_WITH_LOCATION("No available device found for CUDA backend.");
-    }
     if (index == std::numeric_limits<size_t>::max()) { index = 0; }
     if (index >= device_count) {
         LUISA_WARNING_WITH_LOCATION(
@@ -1225,9 +1231,8 @@ LUISA_EXPORT_API void destroy(luisa::compute::DeviceInterface *device) noexcept
 
 LUISA_EXPORT_API void backend_device_names(luisa::vector<luisa::string> &names) noexcept {
     names.clear();
-    auto device_count = 0;
-    luisa::compute::cuda::detail::initialize();
-    LUISA_CHECK_CUDA(cuDeviceGetCount(&device_count));
+    auto device_count { 0 };
+    std::tie(std::ignore, device_count) = luisa::compute::cuda::detail::initialize();
     if (device_count > 0) {
         names.reserve(device_count);
         for (auto i = 0; i < device_count; i++) {