Skip to content

Commit 178da09

Browse files
author
Yuekai Zhang
committed
clean code
1 parent 5427c27 commit 178da09

File tree

6 files changed

+23
-29
lines changed

6 files changed

+23
-29
lines changed

runtime/triton_trtllm/client_grpc.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -692,7 +692,7 @@ async def main():
692692
model_name=args.model_name,
693693
audio_save_dir=args.log_dir,
694694
padding_duration=10,
695-
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
695+
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
696696
chunk_overlap_duration=args.chunk_overlap_duration,
697697
)
698698
)

runtime/triton_trtllm/client_http.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -162,8 +162,8 @@ def prepare_request(
162162
result = rsp.json()
163163
audio = result["outputs"][0]["data"]
164164
audio = np.array(audio, dtype=np.float32)
165-
if args.model_name == "cosyvoice2":
166-
sample_rate = 24000
167-
else:
165+
if args.model_name == "spark_tts":
168166
sample_rate = 16000
167+
else:
168+
sample_rate = 24000
169169
sf.write(args.output_audio, audio, sample_rate, "PCM_16")

runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
import numpy as np
3434
import s3tokenizer
3535

36+
ORIGINAL_VOCAB_SIZE = 151663
3637

3738
class TritonPythonModel:
3839
"""Triton Python model for audio tokenization.
@@ -81,7 +82,7 @@ def execute(self, requests):
8182

8283
mels, mels_lens = s3tokenizer.padding(mels)
8384
codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
84-
codes = codes.clone() + 151663
85+
codes = codes.clone() + ORIGINAL_VOCAB_SIZE
8586

8687
responses = []
8788
for i in range(len(requests)):

runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

Lines changed: 1 addition & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -199,8 +199,6 @@ def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_fe
199199
Returns:
200200
Generated waveform tensor
201201
"""
202-
print(prompt_speech_tokens.shape, prompt_speech_feat.shape, prompt_spk_embedding.shape, target_speech_tokens.shape)
203-
# Convert tensors to Triton format
204202
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
205203
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
206204
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
@@ -228,9 +226,7 @@ def parse_input(self, text, prompt_text, prompt_speech_tokens):
228226
prompt = self.prompt_template.format(input_text=total_text)
229227
input_ids = self.tokenizer.encode(prompt)
230228
input_ids = torch.tensor([input_ids], dtype=torch.int32)
231-
print(input_ids.shape, "before cat")
232229
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
233-
print(input_ids.shape, "after cat", prompt_speech_tokens.shape)
234230
return input_ids
235231

236232
def _extract_spk_embedding(self, speech):
@@ -271,23 +267,15 @@ def execute(self, requests):
271267
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
272268
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
273269

274-
# TODO: FIX ME
270+
275271
wav_tensor = wav.as_numpy()
276-
print(wav_tensor.shape, "wav_tensor")
277272
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
278-
print(wav_tensor.shape, "wav_tensor after")
279273
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
280274
speech_feat = self._extract_speech_feat(prompt_speech_resample)
281-
print(speech_feat.shape, "speech_feat")
282-
print(prompt_speech_tokens.shape, "prompt_speech_tokens here")
283275
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
284276
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
285277
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
286-
print(prompt_speech_tokens.shape, "prompt_speech_tokens after")
287-
print(speech_feat.shape, "speech_feat after")
288-
print(token_len, "token_len")
289278

290-
# Extract text inputs
291279
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
292280
reference_text = reference_text[0][0].decode('utf-8')
293281

runtime/triton_trtllm/model_repo/token2wav/1/model.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -38,13 +38,11 @@
3838
from hyperpyyaml import load_hyperpyyaml
3939
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
4040
from cosyvoice.utils.common import TrtContextWrapper
41-
#import sys
42-
#sys.path.append("/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS")
4341

44-
# Configure logging
4542
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
4643
logger = logging.getLogger(__name__)
4744

45+
ORIGINAL_VOCAB_SIZE = 151663
4846

4947
class CosyVoice2:
5048

@@ -162,8 +160,9 @@ def execute(self, requests):
162160
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
163161
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
164162

165-
prompt_speech_tokens = prompt_speech_tokens - 151663
166-
target_speech_tokens = target_speech_tokens - 151663
163+
# shift the speech tokens according to the original vocab size
164+
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
165+
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
167166

168167
tts_mel, _ = self.token2wav_model.model.flow.inference(
169168
token=target_speech_tokens,

runtime/triton_trtllm/run.sh

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,4 @@
1-
# huggingface-cli download --local-dir cosyvoice2_llm yuekai/cosyvoice2_llm
2-
# modelscope download --model iic/CosyVoice2-0.5B --local_dir ./CosyVoice2-0.5B/
3-
# git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
4-
# cd CosyVoice
5-
# git submodule update --init --recursive
1+
62
export CUDA_VISIBLE_DEVICES=0
73
export PYTHONPATH=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice:$PYTHONPATH
84
export PYTHONPATH=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS:$PYTHONPATH
@@ -12,11 +8,21 @@ stop_stage=$2
128
huggingface_model_local_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/cosyvoice2_llm
139
model_scope_model_local_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice2-0.5B
1410
trt_dtype=bfloat16
15-
trt_dtype=float16
1611
trt_weights_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/trt_weights_${trt_dtype}
1712
trt_engines_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/trt_engines_${trt_dtype}
1813

1914
model_repo=./model_repo_cosyvoice2
15+
16+
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
17+
echo " "
18+
huggingface-cli download --local-dir cosyvoice2_llm yuekai/cosyvoice2_llm
19+
modelscope download --model iic/CosyVoice2-0.5B --local_dir ./CosyVoice2-0.5B/
20+
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
21+
cd CosyVoice
22+
git submodule update --init --recursive
23+
fi
24+
25+
2026
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
2127
echo "Converting checkpoint to TensorRT weights"
2228
python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \

0 commit comments

Comments
 (0)