moshi/moshi/server.py (8 additions, 0 deletions)
@@ -20,6 +20,14 @@
 import sentencepiece
 import sphn
 import torch
+import os
+
+# Disable PyTorch's dynamic compilation (dynamo).
+# This is needed because the quantized model uses custom operations
+# that aren't compatible with dynamo.
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
+
 from .client_utils import log
 from .models import loaders, MimiModel, LMModel, LMGen
 from .run_inference import get_condition_tensors
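For readers unfamiliar with the flag: `torch._dynamo.config.suppress_errors = True` tells dynamo to log a compilation failure and fall back to eager execution instead of raising. A minimal sketch of the behavior, separate from moshi's code (the compiled function below is hypothetical):

import torch
import torch._dynamo

# With suppress_errors set, any dynamo/backend compilation failure is logged
# and the affected function simply runs in eager mode instead of raising.
torch._dynamo.config.suppress_errors = True

@torch.compile
def fused_step(x):
    # Hypothetical workload: if tracing or backend compilation of this body
    # fails, the call still succeeds by falling back to eager execution.
    return torch.nn.functional.gelu(x) * 2.0

print(fused_step(torch.randn(8)))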
moshi/moshi/utils/quantize.py (53 additions, 3 deletions)
@@ -22,12 +22,62 @@ def __init__(self, linear: nn.Linear):
         weight = linear.weight
         assert weight.data.dtype.is_floating_point
         assert linear.bias is None
-        CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16))  # type: ignore
-        self.weight = nn.Parameter(CB, requires_grad=False)
-        self.weight_scb = nn.Parameter(SCB, requires_grad=False)
+
+        # Check whether the weight is on a meta device.
+        if weight.device.type == 'meta':
+            # Meta tensors carry no data, so quantization must be deferred.
+            # Instead, create placeholder tensors that preserve the shape
+            # information from the linear layer.
+            out_features, in_features = weight.shape
+
+            # Create the CB tensor with shape [padded_out_features, in_features].
+            # The first dimension is rounded up to a multiple of 8, matching
+            # the shape that would be produced by int8_vectorwise_quant.
+            padded_out_features = ((out_features + 7) // 8) * 8
+            self.weight = nn.Parameter(
+                torch.zeros((padded_out_features, in_features),
+                            dtype=torch.int8, device='meta'),
+                requires_grad=False
+            )
+
+            # Create the SCB tensor: one float32 scale per output row.
+            self.weight_scb = nn.Parameter(
+                torch.zeros(out_features, dtype=torch.float, device='meta'),
+                requires_grad=False
+            )
+            self.is_meta = True
+        else:
+            # Normal quantization for non-meta tensors.
+            CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16))  # type: ignore
+            self.weight = nn.Parameter(CB, requires_grad=False)
+            self.weight_scb = nn.Parameter(SCB, requires_grad=False)
+            self.is_meta = False
+
+    def _check_meta_status(self):
+        """Check whether the weights are still meta tensors and update the is_meta flag accordingly."""
+        if hasattr(self, 'is_meta') and self.is_meta:
+            # The weights have been loaded once they are no longer on the meta device.
+            if self.weight.device.type != 'meta' and self.weight.numel() > 0:
+                self.is_meta = False
+
+        # Ensure the scale tensor is float32, regardless of the model's dtype.
+        if self.weight_scb.dtype != torch.float:
+            self.weight_scb.data = self.weight_scb.data.float()
+
     def forward(self, x):
         import bitsandbytes as bnb  # type: ignore
 
+        # Update the meta status based on actual tensor properties.
+        self._check_meta_status()
+
+        # If this module is still in meta mode during a forward pass, the
+        # real weights were never loaded, so fail with a clear error rather
+        # than handing empty placeholder tensors to bitsandbytes.
+        if hasattr(self, 'is_meta') and self.is_meta:
+            raise RuntimeError(
+                "Attempting to run forward pass with meta tensors. "
+                "The model weights need to be loaded before running inference.")
+
         state = bnb.MatmulLtState()
         state.CB = self.weight  # type: ignore
         assert isinstance(state.CB, torch.Tensor)
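As a quick sanity check on the padding arithmetic in `__init__`: `((out_features + 7) // 8) * 8` rounds up to the next multiple of 8. A tiny worked example with a hypothetical dimension:

# Worked example of the round-up-to-a-multiple-of-8 padding (value hypothetical).
out_features = 4099
padded_out_features = ((out_features + 7) // 8) * 8
# (4099 + 7) // 8 = 4106 // 8 = 513, and 513 * 8 = 4104
assert padded_out_features == 4104
assert padded_out_features % 8 == 0 and padded_out_features >= out_features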
Expand Down
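To make the intended lifecycle of the meta-device path concrete, here is a hedged sketch; it assumes the wrapper class in moshi/moshi/utils/quantize.py is importable as `QLinear` (the name and import path are assumptions, not confirmed by this diff):

# Hedged sketch of the meta-device lifecycle, not code from this PR.
import torch
import torch.nn as nn
from moshi.utils.quantize import QLinear  # class name assumed

with torch.device('meta'):
    linear = nn.Linear(16, 32, bias=False)   # no real data allocated

q = QLinear(linear)   # takes the meta branch: int8/float32 placeholders
assert q.is_meta      # forward() would raise RuntimeError at this point

# A checkpoint loader later materializes real tensors into q.weight and
# q.weight_scb (for example, load_state_dict with assign=True). The next
# forward() call runs _check_meta_status(), which clears is_meta and casts
# weight_scb back to float32 if loading changed its dtype.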