Skip to content

Commit fbf8984

Browse files
author
刘鑫
committed
Merge branch 'main' into dev
2 parents 961569e + 41752dc commit fbf8984

File tree

5 files changed

+160
-57
lines changed

5 files changed

+160
-57
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
launch.json
2+
__pycache__
3+
voxcpm.egg-info

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,12 @@ By default, when you first run the script, the model will be downloaded automati
6262
### 2. Basic Usage
6363
```python
6464
import soundfile as sf
65+
import numpy as np
6566
from voxcpm import VoxCPM
6667
6768
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
6869
70+
# Non-streaming
6971
wav = model.generate(
7072
text="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.",
7173
prompt_wav_path=None, # optional: path to a prompt speech for voice cloning
@@ -81,6 +83,18 @@ wav = model.generate(
8183
8284
sf.write("output.wav", wav, 16000)
8385
print("saved: output.wav")
86+
87+
# Streaming
88+
chunks = []
89+
for chunk in model.generate_streaming(
90+
text = "Streaming text to speech is easy with VoxCPM!",
91+
# supports same args as above
92+
):
93+
chunks.append(chunk)
94+
wav = np.concatenate(chunks)
95+
96+
sf.write("output_streaming.wav", wav, 16000)
97+
print("saved: output_streaming.wav")
8498
```
8599

86100
### 3. CLI Usage

pyproject.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,10 @@ classifiers = [
2020
"Intended Audience :: Developers",
2121
"Operating System :: OS Independent",
2222
"Programming Language :: Python :: 3",
23-
"Programming Language :: Python :: 3.8",
24-
"Programming Language :: Python :: 3.9",
2523
"Programming Language :: Python :: 3.10",
2624
"Programming Language :: Python :: 3.11",
2725
]
28-
requires-python = ">=3.8"
26+
requires-python = ">=3.10"
2927
dependencies = [
3028
"torch>=2.5.0",
3129
"torchaudio>=2.5.0",
@@ -78,7 +76,7 @@ version_scheme = "post-release"
7876

7977
[tool.black]
8078
line-length = 120
81-
target-version = ['py38']
79+
target-version = ['py310']
8280
include = '\.pyi?$'
8381
extend-exclude = '''
8482
/(

src/voxcpm/core.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import torch
2-
import torchaudio
31
import os
42
import re
53
import tempfile
4+
import numpy as np
5+
from typing import Generator
66
from huggingface_hub import snapshot_download
77
from .model.voxcpm import VoxCPMModel
88

@@ -11,6 +11,7 @@ def __init__(self,
1111
voxcpm_model_path : str,
1212
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
1313
enable_denoiser : bool = True,
14+
optimize: bool = True,
1415
):
1516
"""Initialize VoxCPM TTS pipeline.
1617
@@ -21,9 +22,10 @@ def __init__(self,
2122
zipenhancer_model_path: ModelScope acoustic noise suppression model
2223
id or local path. If None, denoiser will not be initialized.
2324
enable_denoiser: Whether to initialize the denoiser pipeline.
25+
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
2426
"""
2527
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
26-
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
28+
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
2729
self.text_normalizer = None
2830
if enable_denoiser and zipenhancer_model_path is not None:
2931
from .zipenhancer import ZipEnhancer
@@ -43,6 +45,7 @@ def from_pretrained(cls,
4345
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
4446
cache_dir: str = None,
4547
local_files_only: bool = False,
48+
**kwargs,
4649
):
4750
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
4851
@@ -54,6 +57,8 @@ def from_pretrained(cls,
5457
cache_dir: Custom cache directory for the snapshot.
5558
local_files_only: If True, only use local files and do not attempt
5659
to download.
60+
Kwargs:
61+
Additional keyword arguments passed to the ``VoxCPM`` constructor.
5762
5863
Returns:
5964
VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
@@ -82,9 +87,16 @@ def from_pretrained(cls,
8287
voxcpm_model_path=local_path,
8388
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
8489
enable_denoiser=load_denoiser,
90+
**kwargs,
8591
)
8692

87-
def generate(self,
93+
def generate(self, *args, **kwargs) -> np.ndarray:
94+
return next(self._generate(*args, streaming=False, **kwargs))
95+
96+
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
97+
return self._generate(*args, streaming=True, **kwargs)
98+
99+
def _generate(self,
88100
text : str,
89101
prompt_wav_path : str = None,
90102
prompt_text : str = None,
@@ -96,7 +108,8 @@ def generate(self,
96108
retry_badcase : bool = True,
97109
retry_badcase_max_times : int = 3,
98110
retry_badcase_ratio_threshold : float = 6.0,
99-
):
111+
streaming: bool = False,
112+
) -> Generator[np.ndarray, None, None]:
100113
"""Synthesize speech for the given text and return a single waveform.
101114
102115
This method optionally builds and reuses a prompt cache. If an external
@@ -118,8 +131,11 @@ def generate(self,
118131
retry_badcase: Whether to retry badcase.
119132
retry_badcase_max_times: Maximum number of times to retry badcase.
120133
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
134+
streaming: Whether to return a generator of audio chunks.
121135
Returns:
122-
numpy.ndarray: 1D waveform array (float32) on CPU.
136+
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
137+
Yields audio chunks for each generation step if ``streaming=True``,
138+
otherwise yields a single array containing the final audio.
123139
"""
124140
if not text.strip() or not isinstance(text, str):
125141
raise ValueError("target text must be a non-empty string")
@@ -155,7 +171,7 @@ def generate(self,
155171
self.text_normalizer = TextNormalizer()
156172
text = self.text_normalizer.normalize(text)
157173

158-
wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
174+
generate_result = self.tts_model._generate_with_prompt_cache(
159175
target_text=text,
160176
prompt_cache=fixed_prompt_cache,
161177
min_len=2,
@@ -165,9 +181,11 @@ def generate(self,
165181
retry_badcase=retry_badcase,
166182
retry_badcase_max_times=retry_badcase_max_times,
167183
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
184+
streaming=streaming,
168185
)
169186

170-
return wav.squeeze(0).cpu().numpy()
187+
for wav, _, _ in generate_result:
188+
yield wav.squeeze(0).cpu().numpy()
171189

172190
finally:
173191
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):

0 commit comments

Comments
 (0)