Skip to content

Commit

Permalink
Merge branch 'dev' into supprot_aysnc_vllm
Browse files Browse the repository at this point in the history
  • Loading branch information
fengyizhu authored Sep 9, 2024
2 parents d1a117b + 8fcc0cd commit b839200
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
7 changes: 4 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def _load(
vq_config=asdict(self.config.dvae.vq),
dim=self.config.dvae.decoder.idim,
coef=coef,
device=self.device,
device=device,
)
.to(device)
.eval()
Expand All @@ -290,8 +290,8 @@ def _load(
self.config.embed.num_text_tokens,
self.config.embed.num_vq,
)
embed.from_pretrained(embed_path, device=self.device)
self.embed = embed.to(self.device)
embed.from_pretrained(embed_path, device=device)
self.embed = embed.to(device)
self.logger.log(logging.INFO, "embed loaded.")

gpt = GPT(
Expand Down Expand Up @@ -319,6 +319,7 @@ def _load(
decoder_config=asdict(self.config.decoder),
dim=self.config.decoder.idim,
coef=coef,
device=device,
)
.to(device)
.eval()
Expand Down
4 changes: 2 additions & 2 deletions ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(
hop_length=256,
n_mels=100,
padding: Literal["center", "same"] = "center",
device: torch.device = torch.device("cuda"),
device: torch.device = torch.device("cpu"),
):
super().__init__()
self.device = device
Expand Down Expand Up @@ -213,7 +213,7 @@ def __init__(
vq_config: Optional[dict] = None,
dim=512,
coef: Optional[str] = None,
device: torch.device = torch.device("cuda"),
device: torch.device = torch.device("cpu"),
):
super().__init__()
if coef is None:
Expand Down
2 changes: 1 addition & 1 deletion examples/ipynb/colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@
"metadata": {},
"outputs": [],
"source": [
"from tools.audio import load_audio\n",
"from ChatTTS.tools.audio import load_audio\n",
"\n",
"spk_smp = chat.sample_audio_speaker(load_audio(\"sample.mp3\", 24000))\n",
"print(spk_smp) # save it in order to load the speaker without sample audio next time\n",
Expand Down

0 comments on commit b839200

Please sign in to comment.