Commit

bitsandbytes, test existence of triton.ops.matmul_perf_model
This fixes bitsandbytes to work with newer triton versions,
which no longer include triton.ops.matmul_perf_model;
it has been moved to triton-lang's kernels project:
https://github.com/triton-lang/kernels

The fix is to simply catch the import error when triton is available
but triton.ops.matmul_perf_model is not. In that case the code
behaves the same way as it did before when triton
was not available at all.
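
The guarded-import pattern described above can be sketched roughly as follows. This is a minimal illustration, not the code from this commit: the helper `load_optional` and the use of `logging` are assumptions made for the example, and only the module and function names taken from the diff are real.

```python
import importlib
import logging

logger = logging.getLogger(__name__)


def load_optional(module_name, attr_names):
    """Hypothetical helper: return the named attributes as a tuple,
    or None if the module or any of the attributes is missing."""
    try:
        module = importlib.import_module(module_name)
        return tuple(getattr(module, name) for name in attr_names)
    except (ImportError, AttributeError):
        # ModuleNotFoundError is a subclass of ImportError, so this also
        # covers the case where triton itself is not installed.
        logger.warning("%s is not available, using fallback", module_name)
        return None


# Try to load the perf-model helpers that newer triton releases no longer ship.
perf_model = load_optional(
    "triton.ops.matmul_perf_model",
    ["early_config_prune", "estimate_matmul_time"],
)
use_triton = perf_model is not None

if not use_triton:

    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
        # Same stub as the "triton not installed" path: signal "unsupported".
        return None
```

Callers then only need to handle a `None` result, regardless of whether triton is missing entirely or only the perf-model module is gone.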

To my understanding there is currently no way to install
opt_matmul_perf.py from the new triton-lang kernels project directly with pip,
so in the future it would probably be good to include that Python class
directly in bitsandbytes.

bitsandbytes-foundation#1492

Signed-off-by: Mika Laitio <[email protected]>
lamikr committed Jan 26, 2025
1 parent 8cd7793 commit b64adaa
Showing 2 changed files with 20 additions and 2 deletions.
11 changes: 10 additions & 1 deletion bitsandbytes/triton/int8_matmul_mixed_dequantize.py
@@ -2,7 +2,16 @@

 from bitsandbytes.triton.triton_utils import is_triton_available

-if not is_triton_available():
+use_triton = False
+if is_triton_available():
+    import triton
+    try:
+        from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+        use_triton = True
+    except ImportError:
+        print("bitsandbytes matmul_mixed warning, triton.ops.matmul_perf_model is not available anymore.")
+
+if use_triton == False:

     def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
         return None
11 changes: 10 additions & 1 deletion bitsandbytes/triton/int8_matmul_rowwise_dequantize.py
@@ -2,7 +2,16 @@

 from bitsandbytes.triton.triton_utils import is_triton_available

-if not is_triton_available():
+use_triton = False
+if is_triton_available():
+    import triton
+    try:
+        from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+        use_triton = True
+    except ImportError:
+        print("bitsandbytes rowwise_dequantize warning, triton.ops.matmul_perf_model is not available anymore.")
+
+if use_triton == False:

     def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
         return None
