
Commit 2eb4d39 (1 parent: fbf8984)
Author: 刘鑫
FX: Add MPS support

File tree: 2 files changed, +22 -8 lines

src/voxcpm/model/voxcpm.py

Lines changed: 11 additions & 7 deletions

The hunks in `__init__` (indentation of the reconstructed diff is inferred from context):

```diff
@@ -85,11 +85,15 @@ def __init__(
         self.patch_size = config.patch_size
         self.device = config.device
         if not torch.cuda.is_available():
-            self.device = "cpu"
+            if torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
+                self.device = "cpu"
+        print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
 
         # Text-Semantic LM
         self.base_lm = MiniCPMModel(config.lm_config)
-        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
 
         self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
         self.audio_start_token = 101
@@ -100,7 +104,7 @@ def __init__(
         residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
         residual_lm_config.vocab_size = 0
         self.residual_lm = MiniCPMModel(residual_lm_config)
-        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
 
         # Local Encoder
         encoder_config = config.lm_config.model_copy(deep=True)
@@ -132,7 +136,7 @@ def __init__(
             config.lm_config.hidden_size,
             config.scalar_quantization_latent_dim,
             config.scalar_quantization_scale
-            )
+        )
         self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
         self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
         self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
```
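The first hunk changes the device fallback: instead of dropping straight to CPU when CUDA is unavailable, the model now tries Apple's Metal Performance Shaders (MPS) backend first. A minimal standalone sketch of that fallback order; the `resolve_device` helper is illustrative, not part of the repo:

```python
import torch

def resolve_device(configured: str) -> str:
    """Mirror the commit's fallback order: keep the configured device
    when CUDA is available, otherwise prefer MPS over plain CPU."""
    if torch.cuda.is_available():
        return configured            # e.g. "cuda"
    if torch.backends.mps.is_available():
        return "mps"                 # Apple Silicon GPU
    return "cpu"

print(f"Running on device: {resolve_device('cuda')}")
```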
The remaining hunks, in `_generate`, `_generate_with_prompt_cache`, and `from_local`:

```diff
@@ -271,7 +275,7 @@ def _generate(
 
         text_token = text_token.unsqueeze(0).to(self.device)
         text_mask = text_mask.unsqueeze(0).to(self.device)
-        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
         audio_mask = audio_mask.unsqueeze(0).to(self.device)
 
         target_text_length = len(self.text_tokenizer(target_text))
@@ -484,7 +488,7 @@ def _generate_with_prompt_cache(
 
         text_token = text_token.unsqueeze(0).to(self.device)
         text_mask = text_mask.unsqueeze(0).to(self.device)
-        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
         audio_mask = audio_mask.unsqueeze(0).to(self.device)
 
         # run inference
@@ -670,7 +674,7 @@ def from_local(cls, path: str, optimize: bool = True):
         )["state_dict"]
 
         model = cls(config, tokenizer, audio_vae)
-        lm_dtype = get_dtype(config.dtype)
+        lm_dtype = get_dtype(model.config.dtype)
         model = model.to(lm_dtype)
         model.audio_vae = model.audio_vae.to(torch.float32)
 
```
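These hunks replace the hard-coded `torch.bfloat16` casts with `get_dtype(self.config.dtype)`, so inference tensors use the same dtype as the caches, and they consistently read the dtype from the constructed object (`self.config` / `model.config`) rather than the raw `config` argument. This matters for MPS, where bfloat16 support depends on the macOS/PyTorch version. A hedged sketch of the pattern, with `get_dtype` written out as a hypothetical stand-in for the repo's helper of the same name:

```python
import torch

def get_dtype(name: str) -> torch.dtype:
    """Hypothetical stand-in for voxcpm's get_dtype: map a config string
    to a torch dtype so backends without bfloat16 can use another type."""
    return {"bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32}[name]

class Config:
    # float32 is a safe choice on MPS; "bfloat16" suits CUDA.
    dtype = "float32"

cfg = Config()
audio_feat = torch.randn(64, 16)
# As in _generate: cast with the configured dtype, not torch.bfloat16.
audio_feat = audio_feat.unsqueeze(0).to(get_dtype(cfg.dtype))
print(audio_feat.dtype)  # torch.float32
```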
src/voxcpm/modules/minicpm4/model.py

Lines changed: 11 additions & 1 deletion

```diff
@@ -153,7 +153,12 @@ def forward(
         cos, sin = position_emb
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
+
+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -198,6 +203,11 @@ def forward_step(
 
         attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
 
+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_cache = key_cache.contiguous()
+        value_cache = value_cache.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_cache,
```
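Both attention paths now force the query/key/value tensors into contiguous memory before calling `scaled_dot_product_attention`, working around a PyTorch MPS bug with non-contiguous inputs (pytorch/pytorch#163597). A self-contained sketch of the workaround; the shapes are illustrative, not taken from the model:

```python
import torch
import torch.nn.functional as F

device = "mps" if torch.backends.mps.is_available() else "cpu"

# (batch, seq, heads, head_dim) transposed to (batch, heads, seq, head_dim);
# the transpose leaves the tensors non-contiguous, much like the reshapes
# around rotary embeddings in a typical attention layer.
q = torch.randn(1, 128, 8, 64, device=device).transpose(1, 2)
k = torch.randn(1, 128, 8, 64, device=device).transpose(1, 2)
v = torch.randn(1, 128, 8, 64, device=device).transpose(1, 2)
assert not q.is_contiguous()

# Workaround: make the layouts contiguous before SDPA, as the commit does.
out = F.scaled_dot_product_attention(q.contiguous(), k.contiguous(), v.contiguous())
print(out.shape)  # torch.Size([1, 8, 128, 64])
```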
