Skip to content

Commit

Permalink
Update create_dynamic_map to always return a float32 tensor (#1521)
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchellgoffpc authored Feb 19, 2025
1 parent 86b6c37 commit 8ed7d97
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,14 +389,14 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
if signed
else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1,
)
boundaries = torch.linspace(0.1, 1, fraction_items)
boundaries = torch.linspace(0.1, 1, fraction_items, dtype=torch.float32)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed:
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

if additional_items > 0:
boundaries = torch.linspace(0.1, 1, additional_items + 1)
boundaries = torch.linspace(0.1, 1, additional_items + 1, dtype=torch.float32)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed:
Expand All @@ -412,7 +412,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
data.append(0)

data.sort()
return torch.tensor(data)
return torch.tensor(data, dtype=torch.float32)


def create_quantile_map(A, total_bits=8):
Expand Down

0 comments on commit 8ed7d97

Please sign in to comment.