Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable certain CUDA kernels to accept specified cuda stream #1330

Merged

Conversation

jeejeelee
Copy link
Contributor

@jeejeelee jeejeelee commented Aug 21, 2024

FIX #1308

By passing specified stream to certain kernel functions, cudagraph can correctly capture these kernels, enabling downstream repo vLLM to run inference in cudagraph mode, resulting in significant speed improvements for BNB models.
ping @matthewdouglas @Titus-von-Koeller @TimDettmers
cc @chenqianfzh

@Titus-von-Koeller
Copy link
Collaborator

Dear @jeejeelee,

Really cool, we weren't aware vLLM uses cudagraph. Just looked over this with Tim and overall, especially given the performance benefits this may have, this is a very strong contribution, thanks!

I checked out your branch and tried running the tests, but do get the below segfault, which doesn't happen on main. Rerunning the tests gives the same result. Could you please look into this, can you reproduce on your machine? I have a quad L4 setup with CC8.9, CUDA 12.4, Pytorch 2.4.

tests/test_autograd.py::test_matmul_fp8[matmul_fp8_mixed-fp16-transpose=FT-req_grad=TTT-dim4=61-dim3=59-dim2=43-dim1=17] Fatal Python error: Segmentation fault

Thread 0x00007deaa34006c0 (most recent call first):
<no Python frame>

Thread 0x00007dea95e006c0 (most recent call first):
<no Python frame>

Thread 0x00007deaa2a006c0 (most recent call first):
<no Python frame>

Thread 0x00007deaa3e006c0 (most recent call first):
<no Python frame>

Current thread 0x00007dec9d834740 (most recent call first):
  File "/home/ubuntu/src/bnb/bitsandbytes/functional.py", line 1535 in dequantize_no_absmax
  File "/home/ubuntu/src/bnb/bitsandbytes/functional.py", line 1476 in dequantize
  File "/home/ubuntu/src/bnb/bitsandbytes/research/autograd/_functions.py", line 42 in forward
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/torch/autograd/function.py", line 574 in apply
  File "/home/ubuntu/src/bnb/bitsandbytes/research/autograd/_functions.py", line 407 in matmul_fp8_mixed
  File "/home/ubuntu/src/bnb/tests/test_autograd.py", line 456 in test_matmul_fp8
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/python.py", line 159 in pytest_pyfunc_call
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/python.py", line 1627 in runtest
  File "/home/ubuntu/src/bnb/tests/conftest.py", line 9 in pytest_runtest_call
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/runner.py", line 242 in <lambda>
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/runner.py", line 341 in from_call
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/runner.py", line 241 in call_and_report
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/runner.py", line 132 in runtestprotocol
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/runner.py", line 113 in pytest_runtest_protocol
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/main.py", line 362 in pytest_runtestloop
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/main.py", line 337 in _main
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/main.py", line 283 in wrap_session
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/main.py", line 330 in pytest_cmdline_main
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/config/__init__.py", line 175 in main
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/config/__init__.py", line 201 in console_main
  File "/home/ubuntu/.condax/mamba/envs/bnb/bin/pytest", line 10 in <module>
Segmentation fault (core dumped)

Please also be sure to install the pre-commit hooks 🤗

@jeejeelee
Copy link
Contributor Author

@Titus-von-Koeller , Thank you for the feedback, I've corrected the error mentioned above. I'm verifying whether all the unit tests are passing.

@jeejeelee
Copy link
Contributor Author

On my machine with a 3090 GPU, my test results are as follows:

=========================================================== 13 failed, 3264 passed, 35 skipped, 1061 warnings, 16 errors in 875.66s (0:14:35) ==========================================================

All tests in test_generation.py failed due to a network connection error.
All tests in test_triton.py failed because my local triton version is 3.0.0.
The other errors are likely due to precision issues. I'm not certain if these are caused by this PR

@jeejeelee jeejeelee force-pushed the cuda-kernel-cudagraph branch from 3617b6e to 49ffcdc Compare August 21, 2024 17:15
@jeejeelee
Copy link
Contributor Author

@Titus-von-Koeller please review again, thanks~

@matthewdouglas
Copy link
Member

@danielhanchen I believe you're directly calling some of these C-API functions in Unsloth, so I want to make sure you've got a heads up here since this changes their signatures.

bitsandbytes/functional.py Outdated Show resolved Hide resolved
@matthewdouglas
Copy link
Member

@jeejeelee Thank you for the contribution! The only nit I have is the one that I noted about using c_void_p instead of uint64.

A few test failures in test_kbit_backprop and test_gemv_4bit is OK and not related to this PR. I see similar results on my 4090. The generation tests passed for me. Looks nice!

@danielhanchen
Copy link

@danielhanchen I believe you're directly calling some of these C-API functions in Unsloth, so I want to make sure you've got a heads up here since this changes their signatures.

Super thanks for the heads up!! Yep we use the C API directly!

@Titus-von-Koeller
Copy link
Collaborator

I'll be off until Monday, @matthewdouglas will be taking the lead. Thanks both!

@matthewdouglas matthewdouglas merged commit a685654 into bitsandbytes-foundation:main Aug 22, 2024
28 checks passed
matthewdouglas pushed a commit to matthewdouglas/bitsandbytes that referenced this pull request Oct 28, 2024
…ytes-foundation#1330)

* Done

* fix format

* fix format

* fix format

* fix format

* Address format error and fix default arg bug

* Refine stream argument passing mechanism

* Fix bug

* Delete unused code
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

dequantize_4bit() gives wrong output when working in cuda graph mode
4 participants