diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 138ec72f5..396234853 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -55,7 +55,7 @@ def _maybe_torch_compile(func): return func -# Don't use torch.compile for now due to PyTorch issue https://github.com/pytorch/pytorch/issues/124382 +@_maybe_torch_compile def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): """ Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8. @@ -268,7 +268,7 @@ def mm_dequant_impl( } -# It's faster not to use torch.compile +@_maybe_torch_compile def quantize_4bit_impl( A: Tensor, absmax: Tensor = None,