diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py
index 4d91a728..af88c95b 100644
--- a/moshi/moshi/server.py
+++ b/moshi/moshi/server.py
@@ -20,6 +20,14 @@
 import sentencepiece
 import sphn
 import torch
+import os
+
+# Disable PyTorch's dynamic compilation (dynamo).
+# This is needed because the quantized model uses custom operations
+# that aren't compatible with dynamo.
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
+
 from .client_utils import log
 from .models import loaders, MimiModel, LMModel, LMGen
 from .run_inference import get_condition_tensors
diff --git a/moshi/moshi/utils/quantize.py b/moshi/moshi/utils/quantize.py
index 9682d5b6..4bfc7e05 100644
--- a/moshi/moshi/utils/quantize.py
+++ b/moshi/moshi/utils/quantize.py
@@ -22,12 +22,62 @@ def __init__(self, linear: nn.Linear):
         weight = linear.weight
         assert weight.data.dtype.is_floating_point
         assert linear.bias is None
-        CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16))  # type: ignore
-        self.weight = nn.Parameter(CB, requires_grad=False)
-        self.weight_scb = nn.Parameter(SCB, requires_grad=False)
+
+        # Check whether the weight lives on a meta device.
+        if weight.device.type == 'meta':
+            # Meta tensors carry no data, so quantization cannot run here.
+            # Instead, allocate placeholder parameters that preserve the
+            # shape information from the linear layer.
+            out_features, in_features = weight.shape
+
+            # Create the CB tensor with shape [padded_out_features, in_features],
+            # where the first dimension is rounded up to a multiple of 8.
+            # This matches the shape that would be produced by int8_vectorwise_quant.
+            padded_out_features = ((out_features + 7) // 8) * 8
+            self.weight = nn.Parameter(
+                torch.zeros((padded_out_features, in_features),
+                            dtype=torch.int8, device='meta'),
+                requires_grad=False
+            )
+
+            # Create the SCB (per-row scale) tensor with shape [out_features].
+            self.weight_scb = nn.Parameter(
+                torch.zeros(out_features, dtype=torch.float, device='meta'),
+                requires_grad=False
+            )
+            self.is_meta = True
+        else:
+            # Normal quantization for non-meta tensors.
+            CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16))  # type: ignore
+            self.weight = nn.Parameter(CB, requires_grad=False)
+            self.weight_scb = nn.Parameter(SCB, requires_grad=False)
+            self.is_meta = False
+
+    def _check_meta_status(self):
+        """Check whether the weights are still meta tensors and update the is_meta flag accordingly."""
+        if hasattr(self, 'is_meta') and self.is_meta:
+            # The weights have been loaded once they are no longer on the meta device.
+            if self.weight.device.type != 'meta' and self.weight.numel() > 0:
+                self.is_meta = False
+
+        # Ensure the scale tensor is float32, regardless of the model's dtype.
+        if self.weight_scb.dtype != torch.float:
+            self.weight_scb.data = self.weight_scb.data.float()
 
     def forward(self, x):
         import bitsandbytes as bnb  # type: ignore
+
+        # Update the meta status based on the actual tensor properties.
+        self._check_meta_status()
+
+        # Check whether this is a meta tensor that was never initialized.
+        if hasattr(self, 'is_meta') and self.is_meta:
+            # A forward pass while still in meta mode means the real
+            # weights were never loaded, so fail with an explicit error.
+            raise RuntimeError(
+                "Attempting to run a forward pass with meta tensors. "
+                "The model weights need to be loaded before running inference.")
+
         state = bnb.MatmulLtState()
         state.CB = self.weight  # type: ignore
         assert isinstance(state.CB, torch.Tensor)
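The server.py change suppresses dynamo errors process-wide. For comparison (an alternative sketch, not part of this patch), a more targeted option would be to exclude only the quantized forward from tracing with `torch._dynamo.disable`, so the rest of the model stays compilable. `QLinearStub` below is a hypothetical stand-in for the patched `QLinear`, and the sketch assumes a working `torch.compile` setup:

```python
import torch
import torch._dynamo
from torch import nn


class QLinearStub(nn.Module):
    """Hypothetical stand-in for the patched QLinear (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4), requires_grad=False)

    @torch._dynamo.disable
    def forward(self, x):
        # dynamo skips this frame instead of trying (and failing) to trace
        # the custom int8 matmul that the real QLinear performs.
        return x @ self.weight.t()


m = QLinearStub()
compiled = torch.compile(lambda x: m(x) + 1)  # surrounding graph still compiles
print(compiled(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```

The patch instead opts for the global `suppress_errors` switch, which falls back to eager execution whenever compilation fails anywhere in the model.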
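For context on the meta-device branch in quantize.py: `is_meta` and `_check_meta_status` exist because modules can be built on the meta device (shapes and dtypes only, no storage) and receive real tensors later, e.g. via `load_state_dict(..., assign=True)`, which swaps tensors in without notifying the module; the flag therefore has to be re-derived from the tensors themselves on the next forward. A minimal sketch of that flow using a plain `nn.Linear` (illustrative only; moshi's actual loading path may differ, and `assign=True` requires PyTorch >= 2.1):

```python
import torch
from torch import nn

# Build the module on the meta device: parameters carry shape/dtype only.
with torch.device("meta"):
    layer = nn.Linear(16, 32, bias=False)
assert layer.weight.is_meta  # no storage was allocated

# Later, real weights (e.g. from a checkpoint) replace the meta tensors.
# assign=True swaps the tensors in rather than copying into them, which
# is required because meta tensors have no storage to copy into.
state = {"weight": torch.randn(32, 16)}
layer.load_state_dict(state, assign=True)
assert not layer.weight.is_meta

out = layer(torch.randn(4, 16))  # forward now runs on real data
print(out.shape)  # torch.Size([4, 32])
```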