Commit bf03ad4

feat: enable deepgemm jit for fp8 block-scale on SM90 (#1969)
## 📌 Description

Enable JIT compilation for the FP8 DeepGEMM kernels. NVRTC is currently disabled, so NVCC is used by default.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Refactor**
  * JIT include directory discovery now uses the flashinfer-python package instead of the previous package.
  * Updated resolved include path to the flashinfer data location.
  * Runtime compilation now consistently uses NVCC; the prior environment-variable toggle was removed.
  * Updated warning text when the expected package installation cannot be found.

---------

Signed-off-by: Duncan Moss <[email protected]>
1 parent 77091d4 commit bf03ad4

2 files changed: +11 -10 lines changed


csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh

Lines changed: 4 additions & 8 deletions
@@ -125,7 +125,7 @@ std::vector<std::filesystem::path> getJitIncludeDirs() {
   static std::vector<std::filesystem::path> includeDirs;
   if (includeDirs.empty()) {
     // Command to execute
-    char const* cmd = "pip show tensorrt_llm 2>/dev/null";
+    char const* cmd = "pip show flashinfer-python 2>/dev/null";

     // Buffer to store the output
     std::array<char, 128> buffer;
@@ -174,15 +174,11 @@ std::vector<std::filesystem::path> getJitIncludeDirs() {
         location.erase(location.find_last_not_of(" \n\r\t") + 1);

         // Set the include directory based on the package location
-        includeDirs.push_back(std::filesystem::path(location) / "tensorrt_llm" / "include");
-
-        if (!kJitUseNvcc) {
-          includeDirs.push_back(std::filesystem::path(location) / "tensorrt_llm" / "include" /
-                                "cuda" / "include");
-        }
+        includeDirs.push_back(std::filesystem::path(location) / "flashinfer" / "data" / "csrc" /
+                              "nv_internal" / "tensorrt_llm");
       }
     } else {
-      TLLM_LOG_WARNING("Failed to find TensorRT LLM installation, DeepGEMM will be disabled.");
+      TLLM_LOG_WARNING("Failed to find FlashInfer installation, DeepGEMM will be disabled.");
     }
   }
   return includeDirs;
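
For readers who want to see the include-directory lookup in isolation, the following is a minimal sketch of the idea, not the code in `compiler.cuh`. It assumes that `pip show` prints a line beginning with `Location: `, and it takes the command output as a string so the parsing can be exercised without running `pip`. The helper names `parsePipShowLocation` and `deepGemmIncludeDir` are hypothetical; only the trailing-whitespace trim and the path components come from the diff above.

```cpp
#include <filesystem>
#include <optional>
#include <sstream>
#include <string>

// Hypothetical helper (not part of the commit): extract the "Location:" field
// from text in the format printed by `pip show <package>`.
std::optional<std::string> parsePipShowLocation(std::string const& pipOutput) {
  static std::string const kPrefix = "Location: ";
  std::istringstream stream(pipOutput);
  std::string line;
  while (std::getline(stream, line)) {
    // rfind(prefix, 0) == 0 is a "starts with" check.
    if (line.rfind(kPrefix, 0) == 0) {
      std::string location = line.substr(kPrefix.size());
      // Trim trailing whitespace, mirroring the erase() call in the diff.
      location.erase(location.find_last_not_of(" \n\r\t") + 1);
      return location;
    }
  }
  return std::nullopt;  // Package not installed or output not recognized.
}

// Hypothetical helper: build the include directory the new code points at.
std::filesystem::path deepGemmIncludeDir(std::string const& location) {
  return std::filesystem::path(location) / "flashinfer" / "data" / "csrc" /
         "nv_internal" / "tensorrt_llm";
}
```

Separating parsing from process execution is a design choice of this sketch; the actual function in the diff runs the command and parses its output in one place.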

csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh

Lines changed: 7 additions & 2 deletions
@@ -36,8 +36,13 @@ static bool kJitDebugging = []() {
 }();

 static bool kJitUseNvcc = []() {
-  char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC");
-  return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true");
+  // char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC");
+  // return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true");
+  // always use nvcc
+  // TODO: Enable nvrtc -- need these headers:
+  // [TensorRT-LLM][INFO] Compilation log:
+  // kernel.cu(16): catastrophic error: cannot open source file "cuda_bf16.h"
+  return true;
 }();

 static bool kJitDumpCubin = []() {
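
The commit hard-codes `kJitUseNvcc` to `true` and leaves the old environment check commented out. As an illustration only, the sketch below shows one way such a toggle could be kept while still defaulting to NVCC. The helper `envFlag` is hypothetical and is not what the commit does; the accepted values `"1"` and `"true"` are taken from the removed lines. Note that the removed code defaulted to `false` when the variable was unset, whereas this sketch lets the caller choose the default.

```cpp
#include <cstdlib>
#include <string>

// Hypothetical helper (not part of the commit): interpret an environment
// variable as a boolean, falling back to `defaultValue` when it is unset.
bool envFlag(char const* name, bool defaultValue) {
  char const* value = std::getenv(name);
  if (value == nullptr) {
    return defaultValue;
  }
  std::string const text(value);
  // Same accepted spellings as the removed check: "1" or "true".
  return text == "1" || text == "true";
}
```

Under this sketch, initializing the flag with `envFlag("TRTLLM_DG_JIT_USE_NVCC", true)` would select NVCC unless the variable is explicitly set to another value. Per the TODO in the diff, turning NVCC off would still fail until the NVRTC path can find headers such as `cuda_bf16.h`.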
