1- import torch
2- import torchaudio
31import os
42import re
53import tempfile
4+ import numpy as np
5+ from typing import Generator
66from huggingface_hub import snapshot_download
77from .model .voxcpm import VoxCPMModel
88
@@ -11,6 +11,7 @@ def __init__(self,
1111 voxcpm_model_path : str ,
1212 zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base" ,
1313 enable_denoiser : bool = True ,
14+ optimize : bool = True ,
1415 ):
1516 """Initialize VoxCPM TTS pipeline.
1617
@@ -21,9 +22,10 @@ def __init__(self,
2122 zipenhancer_model_path: ModelScope acoustic noise suppression model
2223 id or local path. If None, denoiser will not be initialized.
2324 enable_denoiser: Whether to initialize the denoiser pipeline.
25+ optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
2426 """
2527 print (f"voxcpm_model_path: { voxcpm_model_path } , zipenhancer_model_path: { zipenhancer_model_path } , enable_denoiser: { enable_denoiser } " )
26- self .tts_model = VoxCPMModel .from_local (voxcpm_model_path )
28+ self .tts_model = VoxCPMModel .from_local (voxcpm_model_path , optimize = optimize )
2729 self .text_normalizer = None
2830 if enable_denoiser and zipenhancer_model_path is not None :
2931 from .zipenhancer import ZipEnhancer
@@ -43,6 +45,7 @@ def from_pretrained(cls,
4345 zipenhancer_model_id : str = "iic/speech_zipenhancer_ans_multiloss_16k_base" ,
4446 cache_dir : str = None ,
4547 local_files_only : bool = False ,
48+ ** kwargs ,
4649 ):
4750 """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
4851
@@ -54,6 +57,8 @@ def from_pretrained(cls,
5457 cache_dir: Custom cache directory for the snapshot.
5558 local_files_only: If True, only use local files and do not attempt
5659 to download.
60+ Kwargs:
61+ Additional keyword arguments passed to the ``VoxCPM`` constructor.
5762
5863 Returns:
5964 VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
@@ -82,9 +87,16 @@ def from_pretrained(cls,
8287 voxcpm_model_path = local_path ,
8388 zipenhancer_model_path = zipenhancer_model_id if load_denoiser else None ,
8489 enable_denoiser = load_denoiser ,
90+ ** kwargs ,
8591 )
8692
87- def generate (self ,
93+ def generate (self , * args , ** kwargs ) -> np .ndarray :
94+ return next (self ._generate (* args , streaming = False , ** kwargs ))
95+
96+ def generate_streaming (self , * args , ** kwargs ) -> Generator [np .ndarray , None , None ]:
97+ return self ._generate (* args , streaming = True , ** kwargs )
98+
99+ def _generate (self ,
88100 text : str ,
89101 prompt_wav_path : str = None ,
90102 prompt_text : str = None ,
@@ -96,7 +108,8 @@ def generate(self,
96108 retry_badcase : bool = True ,
97109 retry_badcase_max_times : int = 3 ,
98110 retry_badcase_ratio_threshold : float = 6.0 ,
99- ):
111+ streaming : bool = False ,
112+ ) -> Generator [np .ndarray , None , None ]:
100113 """Synthesize speech for the given text and return a single waveform.
101114
102115 This method optionally builds and reuses a prompt cache. If an external
@@ -118,8 +131,11 @@ def generate(self,
118131 retry_badcase: Whether to retry badcase.
119132 retry_badcase_max_times: Maximum number of times to retry badcase.
120133 retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
134+ streaming: Whether to return a generator of audio chunks.
121135 Returns:
122- numpy.ndarray: 1D waveform array (float32) on CPU.
136+ Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
137+ Yields audio chunks for each generation step if ``streaming=True``,
138+ otherwise yields a single array containing the final audio.
123139 """
124140 if not text .strip () or not isinstance (text , str ):
125141 raise ValueError ("target text must be a non-empty string" )
@@ -155,7 +171,7 @@ def generate(self,
155171 self .text_normalizer = TextNormalizer ()
156172 text = self .text_normalizer .normalize (text )
157173
158- wav , target_text_token , generated_audio_feat = self .tts_model .generate_with_prompt_cache (
174+ generate_result = self .tts_model ._generate_with_prompt_cache (
159175 target_text = text ,
160176 prompt_cache = fixed_prompt_cache ,
161177 min_len = 2 ,
@@ -165,9 +181,11 @@ def generate(self,
165181 retry_badcase = retry_badcase ,
166182 retry_badcase_max_times = retry_badcase_max_times ,
167183 retry_badcase_ratio_threshold = retry_badcase_ratio_threshold ,
184+ streaming = streaming ,
168185 )
169186
170- return wav .squeeze (0 ).cpu ().numpy ()
187+ for wav , _ , _ in generate_result :
188+ yield wav .squeeze (0 ).cpu ().numpy ()
171189
172190 finally :
173191 if temp_prompt_wav_path and os .path .exists (temp_prompt_wav_path ):
0 commit comments