From b64adaaf899c199fa947c501b93eec77921b0bd9 Mon Sep 17 00:00:00 2001 From: Mika Laitio Date: Sat, 25 Jan 2025 21:00:51 -0800 Subject: [PATCH] bitsandbytes, test existence of triton.ops.matmul_perf_model This fixes bitsandbytes to work with newer triton versions, which no longer include triton.ops.matmul_perf_model; that module has been moved to triton-lang's kernels project. https://github.com/triton-lang/kernels The fix is to simply catch the import error if triton is available but triton.ops.matmul_perf_model is not. In that case the logic is implemented in the same way as before when triton was not available at all. To my understanding there is currently no way to install opt_matmul_perf.py from the new triton-lang kernels project directly with pip, so in the future it would probably be good to include that python class directly in bitsandbytes. https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1492 Signed-off-by: Mika Laitio --- bitsandbytes/triton/int8_matmul_mixed_dequantize.py | 11 ++++++++++- bitsandbytes/triton/int8_matmul_rowwise_dequantize.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py index 5fcb927d4..2b61e3247 100644 --- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py @@ -2,7 +2,16 @@ from bitsandbytes.triton.triton_utils import is_triton_available -if not is_triton_available(): +use_triton = False +if is_triton_available(): + import triton + try: + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + use_triton = True + except ImportError: + print("bitsandbytes matmul_mixed warning, triton.ops.matmul_perf_model is not available anymore.") + +if use_triton == False: def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None diff --git 
a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py index 05e30a4c9..2b0c31ef9 100644 --- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -2,7 +2,16 @@ from bitsandbytes.triton.triton_utils import is_triton_available -if not is_triton_available(): +use_triton = False +if is_triton_available(): + import triton + try: + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + use_triton = True + except ImportError: + print("bitsandbytes rowwise_dequantize warning, triton.ops.matmul_perf_model is not available anymore.") + +if use_triton == False: def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None