
Commit 3b02ebf

Fixes quantizing issue
1 parent 62d0154 commit 3b02ebf

File tree: 2 files changed, +66 −3 lines

moshi/moshi/server.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -20,6 +20,19 @@
 import sentencepiece
 import sphn
 import torch
+import os
+
+# Disable PyTorch's dynamic compilation (dynamo)
+# This is needed because the quantized model uses custom operations
+# that aren't compatible with dynamo
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
+
+# Disable CUDA graph capture
+# This is needed because quantized models may use operations
+# that aren't compatible with CUDA graph capture
+os.environ["NO_CUDA_GRAPH"] = "1"
+
 from .client_utils import log
 from .models import loaders, MimiModel, LMModel, LMGen
 from .run_inference import get_condition_tensors
```
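The two guards above act at different layers: `suppress_errors` makes dynamo fall back to eager execution instead of raising when it hits an op it cannot trace, while `NO_CUDA_GRAPH` is only set here and is presumably read by moshi's graph-capture utilities. A minimal sketch of the pattern, with `cuda_graphs_disabled` as a hypothetical stand-in for that downstream check:

```python
import os

import torch
import torch._dynamo

# With suppress_errors set, ops that dynamo cannot trace fall back to
# eager execution instead of raising a hard error.
torch._dynamo.config.suppress_errors = True

os.environ["NO_CUDA_GRAPH"] = "1"

def cuda_graphs_disabled() -> bool:
    # Hypothetical stand-in: the commit only sets the flag; the code
    # that skips CUDA graph capture would read it somewhere like this.
    return bool(os.environ.get("NO_CUDA_GRAPH"))

@torch.compile  # falls back to eager on unsupported ops
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

print(double(torch.ones(4)))   # tensor([2., 2., 2., 2.])
print(cuda_graphs_disabled())  # True
```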

moshi/moshi/utils/quantize.py

Lines changed: 53 additions & 3 deletions
```diff
@@ -22,12 +22,62 @@ def __init__(self, linear: nn.Linear):
         weight = linear.weight
         assert weight.data.dtype.is_floating_point
         assert linear.bias is None
-        CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16))  # type: ignore
-        self.weight = nn.Parameter(CB, requires_grad=False)
-        self.weight_scb = nn.Parameter(SCB, requires_grad=False)
+
+        # Check if the weight is on a meta device
+        if weight.device.type == 'meta':
+            # For meta device we need to preserve the shape information,
+            # creating tensors with the same shape as the original weight
+            # using the shape information from the linear layer.
+            out_features, in_features = weight.shape
+
+            # Create the CB tensor with shape [padded_out_features, in_features],
+            # where the first dimension is rounded up to a multiple of 8.
+            # This matches the shape that would be produced by int8_vectorwise_quant.
+            padded_out_features = ((out_features + 7) // 8) * 8
+            self.weight = nn.Parameter(
+                torch.zeros((padded_out_features, in_features),
+                            dtype=torch.int8, device='meta'),
+                requires_grad=False
+            )
+
+            # Create the SCB tensor with shape [out_features]
+            self.weight_scb = nn.Parameter(
+                torch.zeros(out_features, dtype=torch.float, device='meta'),
+                requires_grad=False
+            )
+            self.is_meta = True
+        else:
+            # Normal quantization for non-meta tensors
+            CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16))  # type: ignore
+            self.weight = nn.Parameter(CB, requires_grad=False)
+            self.weight_scb = nn.Parameter(SCB, requires_grad=False)
+            self.is_meta = False
+
+    def _check_meta_status(self):
+        """Check whether the weights are still meta tensors and update the is_meta flag accordingly."""
+        if hasattr(self, 'is_meta') and self.is_meta:
+            # Check if the weights have been loaded (no longer on meta device)
+            if self.weight.device.type != 'meta' and self.weight.numel() > 0:
+                self.is_meta = False
+
+        # Ensure the scale tensor is float32, regardless of the model's dtype
+        if self.weight_scb.dtype != torch.float:
+            self.weight_scb.data = self.weight_scb.data.float()
 
     def forward(self, x):
         import bitsandbytes as bnb  # type: ignore
+
+        # Update meta status based on actual tensor properties
+        self._check_meta_status()
+
+        # Check whether this is a meta tensor that hasn't been properly initialized yet
+        if hasattr(self, 'is_meta') and self.is_meta:
+            # If we're still in meta mode but trying to do a forward pass,
+            # the weights weren't properly loaded.
+            raise RuntimeError(
+                "Attempting to run forward pass with meta tensors. "
+                "The model weights need to be loaded before running inference.")
+
         state = bnb.MatmulLtState()
         state.CB = self.weight  # type: ignore
         assert isinstance(state.CB, torch.Tensor)
```
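The meta-device branch exists so the quantized module can be constructed as a shape-only placeholder and filled in later by the checkpoint loader. A minimal sketch of that generic PyTorch idiom (not moshi's loader):

```python
import torch
from torch import nn

# Build the module skeleton on the meta device: shapes and dtypes are
# recorded, but no storage is allocated.
with torch.device("meta"):
    linear = nn.Linear(1024, 512, bias=False)
print(linear.weight.device)  # meta

# Later, real checkpoint data replaces the placeholders; assign=True swaps
# in the loaded tensors instead of copying into the (empty) meta ones.
state_dict = {"weight": torch.randn(512, 1024)}  # stand-in for checkpoint data
linear.load_state_dict(state_dict, assign=True)
print(linear.weight.device)  # cpu: the weight is now materialized
```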

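For reference on what the CB/SCB pair stores: a hedged sketch of per-row absmax int8 quantization, the convention this code appears to rely on (CB holds the int8 codes, SCB the per-row scales). The arithmetic below is illustrative, not a call into bitsandbytes:

```python
import torch

w = torch.randn(4, 8)                     # a small stand-in weight matrix
scb = w.abs().amax(dim=1)                 # per-row scales, shape [4]
cb = torch.round(w / scb[:, None] * 127).to(torch.int8)  # int8 codes

w_hat = cb.float() * scb[:, None] / 127   # dequantize
print((w - w_hat).abs().max())            # small rounding error
```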