From 8ed7d97b4901b4f01f9ac04262e3a34d1cb6941a Mon Sep 17 00:00:00 2001 From: Mitchell Goff Date: Tue, 18 Feb 2025 16:36:41 -0800 Subject: [PATCH] Update create_dynamic_map to always return a float32 tensor (#1521) --- bitsandbytes/functional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a5cc4a9f0..595f824ba 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -389,14 +389,14 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1, ) - boundaries = torch.linspace(0.1, 1, fraction_items) + boundaries = torch.linspace(0.1, 1, fraction_items, dtype=torch.float32) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() if signed: data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() if additional_items > 0: - boundaries = torch.linspace(0.1, 1, additional_items + 1) + boundaries = torch.linspace(0.1, 1, additional_items + 1, dtype=torch.float32) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() if signed: @@ -412,7 +412,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): data.append(0) data.sort() - return torch.tensor(data) + return torch.tensor(data, dtype=torch.float32) def create_quantile_map(A, total_bits=8):