From 63a761ff2bdbdd6b0ceeabb1e35fa668261a6a69 Mon Sep 17 00:00:00 2001
From: "Xia, Weiwen"
Date: Tue, 4 Jun 2024 20:55:00 -0700
Subject: [PATCH] CPU: add torch.compile for F.double_quant and F.quantize_4bit

---
 bitsandbytes/backends/cpu_xpu_common.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index 138ec72f5..396234853 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -55,7 +55,7 @@ def _maybe_torch_compile(func):
     return func
 
 
-# Don't use torch.compile for now due to PyTorch issue https://github.com/pytorch/pytorch/issues/124382
+@_maybe_torch_compile
 def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
     """
     Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8.
@@ -268,7 +268,7 @@ def mm_dequant_impl(
 }
 
 
-# It's faster not to use torch.compile
+@_maybe_torch_compile
 def quantize_4bit_impl(
     A: Tensor,
     absmax: Tensor = None,