diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py index 5fcb927d4..2b61e3247 100644 --- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py @@ -2,7 +2,16 @@ from bitsandbytes.triton.triton_utils import is_triton_available -if not is_triton_available(): +use_triton = False +if is_triton_available(): + import triton + try: + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + use_triton = True + except ImportError: + print("bitsandbytes matmul_mixed warning, triton.ops.matmul_perf_model is not available anymore.") + +if use_triton == False: def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py index 05e30a4c9..2b0c31ef9 100644 --- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -2,7 +2,16 @@ from bitsandbytes.triton.triton_utils import is_triton_available -if not is_triton_available(): +use_triton = False +if is_triton_available(): + import triton + try: + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + use_triton = True + except ImportError: + print("bitsandbytes rowwise_dequantize warning, triton.ops.matmul_perf_model is not available anymore.") + +if use_triton == False: def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None