@@ -85,11 +85,15 @@ def __init__(
8585 self .patch_size = config .patch_size
8686 self .device = config .device
8787 if not torch .cuda .is_available ():
88- self .device = "cpu"
88+ if torch .backends .mps .is_available ():
89+ self .device = "mps"
90+ else :
91+ self .device = "cpu"
92+ print (f"Running on device: { self .device } , dtype: { self .config .dtype } " )
8993
9094 # Text-Semantic LM
9195 self .base_lm = MiniCPMModel (config .lm_config )
92- self .base_lm .setup_cache (1 , config .max_length , self .device , get_dtype (config .dtype ))
96+ self .base_lm .setup_cache (1 , config .max_length , self .device , get_dtype (self . config .dtype ))
9397
9498 self .text_tokenizer = mask_multichar_chinese_tokens (tokenizer )
9599 self .audio_start_token = 101
@@ -100,7 +104,7 @@ def __init__(
100104 residual_lm_config .num_hidden_layers = config .residual_lm_num_layers
101105 residual_lm_config .vocab_size = 0
102106 self .residual_lm = MiniCPMModel (residual_lm_config )
103- self .residual_lm .setup_cache (1 , config .max_length , self .device , get_dtype (config .dtype ))
107+ self .residual_lm .setup_cache (1 , config .max_length , self .device , get_dtype (self . config .dtype ))
104108
105109 # Local Encoder
106110 encoder_config = config .lm_config .model_copy (deep = True )
@@ -132,7 +136,7 @@ def __init__(
132136 config .lm_config .hidden_size ,
133137 config .scalar_quantization_latent_dim ,
134138 config .scalar_quantization_scale
135- )
139+ )
136140 self .enc_to_lm_proj = nn .Linear (config .encoder_config .hidden_dim , config .lm_config .hidden_size )
137141 self .lm_to_dit_proj = nn .Linear (config .lm_config .hidden_size , config .dit_config .hidden_dim )
138142 self .res_to_dit_proj = nn .Linear (config .lm_config .hidden_size , config .dit_config .hidden_dim )
@@ -271,7 +275,7 @@ def _generate(
271275
272276 text_token = text_token .unsqueeze (0 ).to (self .device )
273277 text_mask = text_mask .unsqueeze (0 ).to (self .device )
274- audio_feat = audio_feat .unsqueeze (0 ).to (self .device ).to (torch . bfloat16 )
278+ audio_feat = audio_feat .unsqueeze (0 ).to (self .device ).to (get_dtype ( self . config . dtype ) )
275279 audio_mask = audio_mask .unsqueeze (0 ).to (self .device )
276280
277281 target_text_length = len (self .text_tokenizer (target_text ))
@@ -484,7 +488,7 @@ def _generate_with_prompt_cache(
484488
485489 text_token = text_token .unsqueeze (0 ).to (self .device )
486490 text_mask = text_mask .unsqueeze (0 ).to (self .device )
487- audio_feat = audio_feat .unsqueeze (0 ).to (self .device ).to (torch . bfloat16 )
491+ audio_feat = audio_feat .unsqueeze (0 ).to (self .device ).to (get_dtype ( self . config . dtype ) )
488492 audio_mask = audio_mask .unsqueeze (0 ).to (self .device )
489493
490494 # run inference
@@ -670,7 +674,7 @@ def from_local(cls, path: str, optimize: bool = True):
670674 )["state_dict" ]
671675
672676 model = cls (config , tokenizer , audio_vae )
673- lm_dtype = get_dtype (config .dtype )
677+ lm_dtype = get_dtype (model . config .dtype )
674678 model = model .to (lm_dtype )
675679 model .audio_vae = model .audio_vae .to (torch .float32 )
676680