Merged
2 changes: 1 addition & 1 deletion .github/workflows/ci_test.yml
@@ -51,7 +51,7 @@ jobs:
id: cpp_files
run: |
FILES=$(git diff --name-only --diff-filter=ACMR origin/${{ github.base_ref }}...HEAD -- \
'*.c' '*.cc' '*.cpp' '*.cxx' | tr '\n' ' ')
src/ tests/ | grep -E '\.(c|cc|cpp|cxx)$' | tr '\n' ' ')
echo "files=$FILES" >> $GITHUB_OUTPUT
[ -n "$FILES" ] && echo "changed=true" >> $GITHUB_OUTPUT || echo "changed=false" >> $GITHUB_OUTPUT
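The updated step narrows `git diff` to `src/` and `tests/` and moves the extension matching into `grep`. The filter pipeline can be sketched in isolation like this (the filenames below are illustrative, not from the repository):

```shell
# Feed a sample file list through the same filter used in the workflow:
# keep only C/C++ sources, then join the names onto one line.
FILES=$(printf '%s\n' src/main.cpp src/util.h tests/test_a.cc docs/readme.md \
  | grep -E '\.(c|cc|cpp|cxx)$' | tr '\n' ' ')
echo "$FILES"
```

Note that `tr '\n' ' '` leaves a trailing space, which is harmless when the list is later expanded as command arguments.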
2 changes: 1 addition & 1 deletion docs/guides/cubin_launcher.rst
@@ -376,7 +376,7 @@ To use dynamic shared memory, specify the size in the :cpp:func:`tvm::ffi::Cubin

// Allocate 1KB of dynamic shared memory
uint32_t shared_mem_bytes = 1024;
cudaError_t result = kernel.Launch(args, grid, block, stream, shared_mem_bytes);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream, shared_mem_bytes));

Integration with Different Compilers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 changes: 2 additions & 2 deletions examples/cubin_launcher/dynamic_cubin/src/lib_dynamic.cc
@@ -85,7 +85,7 @@ void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {

// Launch kernel
tvm::ffi::cuda_api::ResultType result = g_add_one_kernel->Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
Contributor (severity: medium):

For conciseness, you could combine the kernel launch on line 87 and this error check into a single line. This would also remove the need for the result variable. Other parts of this PR follow this more concise pattern (e.g., in docs/guides/cubin_launcher.rst).

For example:

TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(g_add_one_kernel->Launch(args, grid, block, stream));

This would replace lines 87 and 88.

}

} // namespace cubin_dynamic
@@ -125,7 +125,7 @@ void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {

// Launch kernel
tvm::ffi::cuda_api::ResultType result = g_mul_two_kernel->Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
Contributor (severity: medium):

Similar to my other comment, you can combine the kernel launch on line 127 and this error check into one line for better readability and conciseness.

TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(g_mul_two_kernel->Launch(args, grid, block, stream));

This would replace lines 127 and 128.

}

// Export TVM-FFI functions
@@ -73,7 +73,7 @@ void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {

// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
Contributor (severity: medium):

To improve conciseness, consider combining the kernel launch on line 75 with this error check. This pattern is used in other files in this PR, for example in examples/cubin_launcher/example_nvrtc_cubin.py.

TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));

}

} // namespace cubin_embedded
@@ -112,7 +112,7 @@ void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {

// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
Contributor (severity: medium):

Similar to my other comment, you can combine the kernel launch on line 114 and this error check into a single statement to make the code more concise.

TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));

}

// Export TVM-FFI functions
@@ -70,7 +70,7 @@ void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {

// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
Contributor (severity: medium):

For conciseness, you could combine the kernel launch on line 72 and this error check into a single line. This would also remove the need for the result variable. Other parts of this PR follow this pattern.

TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));

}

} // namespace cubin_embedded
@@ -109,7 +109,7 @@ void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {

// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
Contributor (severity: medium):

Similar to my other comment, you can combine the kernel launch on line 111 and this error check into one line for better readability and conciseness.

TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));

}

// Export TVM-FFI functions
@@ -70,7 +70,7 @@ void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {

// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
Contributor (severity: medium):

To improve conciseness, consider combining the kernel launch on line 72 with this error check. This pattern is used in other files in this PR.

TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));

}

} // namespace cubin_embedded
@@ -109,7 +109,7 @@ void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {

// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
Contributor (severity: medium):

Similar to my other comment, you can combine the kernel launch on line 111 and this error check into a single statement to make the code more concise.

TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));

}

// Export TVM-FFI functions
6 changes: 2 additions & 4 deletions examples/cubin_launcher/example_nvrtc_cubin.py
@@ -129,8 +129,7 @@ def use_cubin_kernel(cubin_bytes: bytes) -> int:
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

// Launch kernel
cudaError_t result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));
}

void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
@@ -158,8 +157,7 @@ def use_cubin_kernel(cubin_bytes: bytes) -> int:
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

// Launch kernel
cudaError_t result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));
}

} // namespace nvrtc_loader
3 changes: 1 addition & 2 deletions examples/cubin_launcher/example_triton_cubin.py
@@ -133,8 +133,7 @@ def use_cubin_kernel(cubin_bytes: bytes) -> int:
DLDevice device = x.device();
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

cudaError_t result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));
}

} // namespace triton_loader
22 changes: 22 additions & 0 deletions include/tvm/ffi/extra/cuda/base.h
@@ -23,9 +23,31 @@
#ifndef TVM_FFI_EXTRA_CUDA_BASE_H_
#define TVM_FFI_EXTRA_CUDA_BASE_H_

#include <cuda_runtime.h>
#include <tvm/ffi/error.h>

namespace tvm {
namespace ffi {

/*!
* \brief Macro for checking CUDA runtime API errors.
*
* This macro checks the return value of CUDA runtime API calls and throws
* a RuntimeError with detailed error information if the call fails.
*
* \param stmt The CUDA runtime API call to check.
*/
#define TVM_FFI_CHECK_CUDA_ERROR(stmt) \
do { \
cudaError_t __err = (stmt); \
if (__err != cudaSuccess) { \
const char* __err_name = cudaGetErrorName(__err); \
const char* __err_str = cudaGetErrorString(__err); \
TVM_FFI_THROW(RuntimeError) << "CUDA Runtime Error: " << __err_name << " (" \
<< static_cast<int>(__err) << "): " << __err_str; \
} \
} while (0)

/*!
* \brief A simple 3D dimension type for CUDA kernel launch configuration.
*
17 changes: 8 additions & 9 deletions include/tvm/ffi/extra/cuda/cubin_launcher.h
@@ -29,7 +29,7 @@
#ifndef TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
#define TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_

#include <cuda.h>
#include <cuda.h> // NOLINT(clang-diagnostic-error)
#include <cuda_runtime.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
@@ -234,7 +234,7 @@
* TVMFFIEnvGetStream(device.device_type, device.device_id));
*
* cudaError_t result = kernel.Launch(args, grid, block, stream);
* TVM_FFI_CHECK_CUDA_ERROR(result);
* TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
* }
* \endcode
*
@@ -295,7 +295,7 @@ class CubinModule {
* \param bytes CUBIN binary data as a Bytes object.
*/
explicit CubinModule(const Bytes& bytes) {
TVM_FFI_CHECK_CUDA_ERROR(cuda_api::LoadLibrary(&library_, bytes.data()));
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(cuda_api::LoadLibrary(&library_, bytes.data()));
}

/*!
@@ -305,7 +305,7 @@
* \note The `code` buffer points to an ELF image.
*/
explicit CubinModule(const char* code) {
TVM_FFI_CHECK_CUDA_ERROR(cuda_api::LoadLibrary(&library_, code));
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(cuda_api::LoadLibrary(&library_, code));
}

/*!
Expand All @@ -315,7 +315,7 @@ class CubinModule {
* \note The `code` buffer points to an ELF image.
*/
explicit CubinModule(const unsigned char* code) {
TVM_FFI_CHECK_CUDA_ERROR(cuda_api::LoadLibrary(&library_, code));
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(cuda_api::LoadLibrary(&library_, code));
}

/*! \brief Destructor unloads the library */
@@ -418,7 +418,7 @@ class CubinModule {
* // Launch on stream
* cudaStream_t stream = ...;
* cudaError_t result = kernel.Launch(args, grid, block, stream);
* TVM_FFI_CHECK_CUDA_ERROR(result);
* TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
* \endcode
*
* \note This class is movable but not copyable.
@@ -434,7 +434,7 @@ class CubinKernel {
* \param name Name of the kernel function.
*/
CubinKernel(cuda_api::LibraryHandle library, const char* name) {
TVM_FFI_CHECK_CUDA_ERROR(cuda_api::GetKernel(&kernel_, library, name));
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(cuda_api::GetKernel(&kernel_, library, name));
}

/*! \brief Destructor (kernel handle doesn't need explicit cleanup) */
@@ -464,8 +464,7 @@ class CubinKernel {
* \par Error Checking
* Always check the returned cudaError_t:
* \code{.cpp}
* cudaError_t result = kernel.Launch(args, grid, block, stream);
* TVM_FFI_CHECK_CUDA_ERROR(result);
* TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));
* \endcode
*
* \param args Array of pointers to kernel arguments (must point to actual values).
2 changes: 1 addition & 1 deletion include/tvm/ffi/extra/cuda/device_guard.h
@@ -23,7 +23,7 @@
#ifndef TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_
#define TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_

#include <tvm/ffi/extra/cuda/internal/unified_api.h>
#include <tvm/ffi/extra/cuda/base.h>

namespace tvm {
namespace ffi {
8 changes: 5 additions & 3 deletions include/tvm/ffi/extra/cuda/internal/unified_api.h
@@ -74,7 +74,7 @@ using DeviceAttrType = CUdevice_attribute;
constexpr ResultType kSuccess = CUDA_SUCCESS;

// Driver API Functions
#define _TVM_FFI_CUDA_FUNC(name) cu##name
#define _TVM_FFI_CUDA_FUNC(name) cu##name // NOLINT(bugprone-reserved-identifier)

#else

@@ -110,7 +110,9 @@ inline void GetErrorString(ResultType err, const char** name, const char** str)
#endif
}

#define TVM_FFI_CHECK_CUDA_ERROR(stmt) \
// This macro is only used to check CUDA errors in the cubin launcher,
// where we might switch between the driver and runtime API.
#define TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(stmt) \
do { \
::tvm::ffi::cuda_api::ResultType __err = (stmt); \
if (__err != ::tvm::ffi::cuda_api::kSuccess) { \
@@ -143,7 +145,7 @@ inline DeviceHandle GetDeviceHandle(int device_id) {
CUdevice dev;
// Note: We use CHECK here because this conversion usually shouldn't fail if ID is valid
// and we need to return a value.
TVM_FFI_CHECK_CUDA_ERROR(cuDeviceGet(&dev, device_id));
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(cuDeviceGet(&dev, device_id));
return dev;
#else
return device_id;
8 changes: 4 additions & 4 deletions tests/python/test_cubin_launcher.py
@@ -158,8 +158,8 @@ def test_cubin_launcher_add_one() -> None:
DLDevice device = x.device();
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

cudaError_t result = g_kernel_add_one->Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
auto result = g_kernel_add_one->Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
Comment on lines +161 to +162
Contributor (severity: medium):

For conciseness, you can combine these two lines into one, as has been done in other example files in this PR. This also removes the need for the result variable.

Suggested change
auto result = g_kernel_add_one->Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(g_kernel_add_one->Launch(args, grid, block, stream));

}

void LaunchMulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
@@ -184,8 +184,8 @@ def test_cubin_launcher_add_one() -> None:
DLDevice device = x.device();
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

cudaError_t result = g_kernel_mul_two->Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_ERROR(result);
auto result = g_kernel_mul_two->Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
Comment on lines +187 to +188
Contributor (severity: medium):

Similar to my other comment, these two lines can be combined into a single statement for better readability and conciseness.

Suggested change
auto result = g_kernel_mul_two->Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(g_kernel_mul_two->Launch(args, grid, block, stream));

}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(load_cubin_data, cubin_test::LoadCubinData);