Skip to content

Commit

Permalink
avoid repeated calls to 'cuDriverGetVersion' or 'cuDeviceGetCount' fu…
Browse files Browse the repository at this point in the history
…nction
  • Loading branch information
xiaomx32 committed Feb 27, 2025
1 parent c12fc6d commit 28b7b28
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions src/backends/cuda/cuda_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1029,14 +1029,18 @@ static constexpr auto required_cuda_version_major = 11;
static constexpr auto required_cuda_version_minor = 7;
static constexpr auto required_cuda_version = required_cuda_version_major * 1000 + required_cuda_version_minor * 10;

static void initialize() {
[[nodiscard]]
static std::pair<int, int> initialize() {
// global init
auto driver_version { 0 };
auto device_count { 0 };

static std::once_flag flag;
std::call_once(flag, [] {
std::call_once(flag, [&] {
// CUDA
LUISA_CHECK_CUDA(cuInit(0));

// check driver version
auto driver_version = 0;
LUISA_CHECK_CUDA(cuDriverGetVersion(&driver_version));
auto driver_version_major = driver_version / 1000;
auto driver_version_minor = (driver_version % 1000) / 10;
Expand All @@ -1048,25 +1052,23 @@ static void initialize() {
LUISA_VERBOSE("Successfully initialized CUDA "
"backend with driver version {}.{}.",
driver_version_major, driver_version_minor);

// check device count
LUISA_CHECK_CUDA(cuDeviceGetCount(&device_count));
if (device_count == 0) {
LUISA_ERROR_WITH_LOCATION("No available device found for CUDA backend.");
}
});

return { driver_version, device_count };
}

}// namespace detail

CUDADevice::Handle::Handle(size_t index) noexcept {

detail::initialize();

// cuda
auto driver_version = 0;
LUISA_CHECK_CUDA(cuDriverGetVersion(&driver_version));
auto [driver_version, device_count] = detail::initialize();
_driver_version = driver_version;

auto device_count = 0;
LUISA_CHECK_CUDA(cuDeviceGetCount(&device_count));
if (device_count == 0) {
LUISA_ERROR_WITH_LOCATION("No available device found for CUDA backend.");
}
if (index == std::numeric_limits<size_t>::max()) { index = 0; }
if (index >= device_count) {
LUISA_WARNING_WITH_LOCATION(
Expand Down Expand Up @@ -1225,9 +1227,8 @@ LUISA_EXPORT_API void destroy(luisa::compute::DeviceInterface *device) noexcept

LUISA_EXPORT_API void backend_device_names(luisa::vector<luisa::string> &names) noexcept {
names.clear();
auto device_count = 0;
luisa::compute::cuda::detail::initialize();
LUISA_CHECK_CUDA(cuDeviceGetCount(&device_count));
auto device_count { 0 };
std::tie(std::ignore, device_count) = luisa::compute::cuda::detail::initialize();
if (device_count > 0) {
names.reserve(device_count);
for (auto i = 0; i < device_count; i++) {
Expand Down

0 comments on commit 28b7b28

Please sign in to comment.