Skip to content

Commit

Permalink
feat:support async vllm generator
Browse files Browse the repository at this point in the history
  • Loading branch information
fengyizhu committed Sep 12, 2024
1 parent 8fcc0cd commit 1e2b671
Show file tree
Hide file tree
Showing 11 changed files with 800 additions and 266 deletions.
173 changes: 89 additions & 84 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import logging
import tempfile
import uuid
from dataclasses import dataclass, asdict
from typing import Literal, Optional, List, Tuple, Dict, Union
from json import load
Expand Down Expand Up @@ -200,9 +201,10 @@ def infer(
do_homophone_replacement=True,
params_refine_text=RefineTextParams(),
params_infer_code=InferCodeParams(),
stream_batch_size=16,
):
self.context.set(False)
res_gen = self._infer(
return self._infer(
text,
stream,
lang,
Expand All @@ -213,11 +215,8 @@ def infer(
do_homophone_replacement,
params_refine_text,
params_infer_code,
stream_batch_size,
)
if stream:
return res_gen
else:
return next(res_gen)

def interrupt(self):
self.context.set(True)
Expand Down Expand Up @@ -339,7 +338,7 @@ def _load(

return self.has_loaded()

def _infer(
async def _infer(
self,
text,
stream=False,
Expand All @@ -351,6 +350,7 @@ def _infer(
do_homophone_replacement=True,
params_refine_text=RefineTextParams(),
params_infer_code=InferCodeParams(),
stream_batch_size=16,
):

assert self.has_loaded(use_decoder=use_decoder)
Expand Down Expand Up @@ -384,41 +384,38 @@ def _infer(
yield text
return

if stream:
length = 0
pass_batch_count = 0
for result in self._infer_code(
length = 0
async for result in self._infer_code(
text,
stream,
self.device,
use_decoder,
params_infer_code,
stream_batch_size,
):
wavs = self._decode_to_wavs(
result.hiddens if use_decoder else result.ids,
use_decoder,
)
result.destroy()
if stream:
pass_batch_count += 1
if pass_batch_count <= params_infer_code.pass_first_n_batches:
continue
a = length
b = a + params_infer_code.stream_speed
if b > wavs.shape[1]:
b = wavs.shape[1]
new_wavs = wavs[:, a:b]
length = b
yield new_wavs
if result.finished:
yield wavs[:, length:]
else:
yield wavs
if stream:
new_wavs = wavs[:, length:]
# Identify rows with non-zero elements using np.any
# keep_rows = np.any(array != 0, axis=1)
keep_cols = np.sum(new_wavs != 0, axis=0) > 0
# Filter both rows and columns using slicing
yield new_wavs[:][:, keep_cols]
# Hacker:Check if there are any silent segments; if so, take the last segment. Otherwise, try waiting for another loop.
keep_cols = np.sum(abs(wavs[0][length:]) > 1e-6, axis=0) > 0

import librosa
silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
silence_left = 0
if len(silence_intervals) == 0:
silence_left = len(wavs[0])
else:
for i in range(len(silence_intervals)):
silence_left = silence_intervals[i][0]
if silence_left <= 0:
continue
new_wavs = wavs[:, length : length + silence_left]
length += len(new_wavs[0])
yield new_wavs

@torch.inference_mode()
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
Expand Down Expand Up @@ -457,13 +454,14 @@ def _decode_to_wavs(
return wavs

@torch.no_grad()
def _infer_code(
async def _infer_code(
self,
text: Tuple[List[str], str],
stream: bool,
device: torch.device,
return_hidden: bool,
params: InferCodeParams,
stream_batch_size: int,
):

gpt = self.gpt
Expand Down Expand Up @@ -504,6 +502,17 @@ def _infer_code(
repetition_penalty=params.repetition_penalty,
)

speaker_embedding_param = self.embed(input_ids, text_mask)
del text_mask
if params.spk_emb is not None:
self.speaker.apply(
speaker_embedding_param,
params.spk_emb,
input_ids,
self.tokenizer.spk_emb_ids,
self.gpt.device_gpt,
)

if gpt.is_vllm:
from .model.velocity import SamplingParams

Expand All @@ -522,62 +531,58 @@ def _infer_code(
result = gpt.llm.generate(
None,
sample_params,
input_ids,
uuid.uuid4(),
speaker_embedding_param,
input_ids[0]
)

token_ids = []
hidden_states = []
for i in result:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(
i.outputs[0].hidden_states.to(torch.float32).to(self.device)
)

del text_mask, input_ids

return [
GPT.GenerationOutputs(
ids=token_ids,
hiddens=hidden_states,
attentions=[],
),
]

emb = self.embed(input_ids, text_mask)

del text_mask

if params.spk_emb is not None:
self.speaker.apply(
emb,
params.spk_emb,
async for i in result:
token_ids = []
hidden_states = []
if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(
i.outputs[0].hidden_states.to(torch.float32).to(self.device)
)
yield GPT.GenerationOutputs(
ids=token_ids,
finished=i.finished,
hiddens=hidden_states,
attentions=[],
)
else:
result = gpt.generate(
speaker_embedding_param,
input_ids,
self.tokenizer.spk_emb_ids,
self.gpt.device_gpt,
temperature=torch.tensor(temperature, device=device),
eos_token=num_code,
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_processors=(*logits_processors, *logits_warpers),
infer_text=False,
return_hidden=return_hidden,
stream=stream,
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
stream_batch=params.stream_batch,
manual_seed=params.manual_seed,
context=self.context,
)

result = gpt.generate(
emb,
input_ids,
temperature=torch.tensor(temperature, device=device),
eos_token=num_code,
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_processors=(*logits_processors, *logits_warpers),
infer_text=False,
return_hidden=return_hidden,
stream=stream,
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
stream_batch=params.stream_batch,
manual_seed=params.manual_seed,
context=self.context,
)

del emb, input_ids

return result
del speaker_embedding_param, input_ids
async for i in result:
token_ids = []
hidden_states = []
if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
token_ids.append(i.ids[0])
hidden_states.append(
i.hiddens[0].to(torch.float32).to(self.device)
)
yield GPT.GenerationOutputs(
ids=token_ids,
finished=i.finished,
hiddens=hidden_states,
attentions=[],
)

@torch.no_grad()
def _refine_text(
Expand Down
8 changes: 7 additions & 1 deletion ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def from_pretrained(
num_audio_tokens=self.num_audio_tokens,
num_text_tokens=self.num_text_tokens,
post_model_path=embed_file_path,
dtype="float32"
)
self.logger.info("vLLM model loaded")
return
Expand Down Expand Up @@ -273,6 +274,7 @@ class GenerationOutputs:
ids: List[torch.Tensor]
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
hiddens: List[torch.Tensor]
finished: bool

def destroy(self):
del_all(self.ids)
Expand All @@ -288,6 +290,7 @@ def _prepare_generation_outputs(
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
hiddens: List[torch.Tensor],
infer_text: bool,
finished: bool,
) -> GenerationOutputs:
inputs_ids = [
inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
Expand All @@ -305,10 +308,11 @@ def _prepare_generation_outputs(
ids=inputs_ids,
attentions=attentions,
hiddens=hiddens,
finished=finished,
)

@torch.no_grad()
def generate(
async def generate(
self,
emb: torch.Tensor,
inputs_ids: torch.Tensor,
Expand Down Expand Up @@ -581,6 +585,7 @@ def generate(
attentions,
hiddens,
infer_text,
False
)
del not_finished

Expand Down Expand Up @@ -610,4 +615,5 @@ def generate(
attentions,
hiddens,
infer_text,
True
)
Loading

0 comments on commit 1e2b671

Please sign in to comment.