95 changes: 89 additions & 6 deletions modules/chatterbox_handler.py
@@ -24,6 +24,44 @@
CHATTERBOX_FILES_TO_DOWNLOAD = ["ve.pt", "t3_cfg.pt", "s3gen.pt", "tokenizer.json", "conds.pt"]
DEFAULT_MODEL_PACK_NAME = "resembleai_default_voice"

def clear_gpu_cache():
"""Clear GPU cache for all available devices."""
if torch.cuda.is_available():
torch.cuda.empty_cache()
if torch.backends.mps.is_available():
torch.mps.empty_cache()

def unload_chatterbox_tts_model(model_pack_name, device_str="cuda"):
"""Unloads a specific TTS model from cache."""
cache_key = (model_pack_name, device_str, "tts")
if cache_key in TTS_MODEL_CACHE:
print(f"ChatterboxTTS: Unloading TTS model '{model_pack_name}' from '{device_str}' cache.")
del TTS_MODEL_CACHE[cache_key]
clear_gpu_cache()
return True
return False

def unload_chatterbox_vc_model(model_pack_name, device_str="cuda"):
"""Unloads a specific VC model from cache."""
cache_key = (model_pack_name, device_str, "vc")
if cache_key in VC_MODEL_CACHE:
print(f"ChatterboxVC: Unloading VC model '{model_pack_name}' from '{device_str}' cache.")
del VC_MODEL_CACHE[cache_key]
clear_gpu_cache()
return True
return False

def unload_all_chatterbox_models():
"""Unloads all cached Chatterbox models."""
tts_count = len(TTS_MODEL_CACHE)
vc_count = len(VC_MODEL_CACHE)

TTS_MODEL_CACHE.clear()
VC_MODEL_CACHE.clear()
clear_gpu_cache()

print(f"ChatterboxTTS/VC: Unloaded {tts_count} TTS models and {vc_count} VC models from cache.")

def get_chatterbox_model_pack_names():
"""Returns a list of available Chatterbox model pack names (subdirectories)."""
chatterbox_models_base_path = os.path.join(folder_paths.models_dir, CHATTERBOX_MODEL_SUBDIR)
@@ -115,7 +153,7 @@ def load_chatterbox_tts_model(model_pack_name, device_str="cuda"):
raise
return model

def get_cached_chatterbox_tts_model(model_pack_name, device_str="cuda"):
def get_cached_chatterbox_tts_model(model_pack_name, device_str="cuda", keep_model_loaded=True):
"""Loads and caches the ChatterboxTTS model."""
if not model_pack_name:
available_packs = get_chatterbox_model_pack_names()
@@ -140,8 +178,30 @@ def get_cached_chatterbox_tts_model(model_pack_name, device_str="cuda"):
else:
print(f"ChatterboxTTS: TTS Model for '{model_pack_name}' on '{device_str}' not in cache. Loading...")

TTS_MODEL_CACHE[cache_key] = load_chatterbox_tts_model(model_pack_name, device_str)
return TTS_MODEL_CACHE[cache_key]
model = load_chatterbox_tts_model(model_pack_name, device_str)

# Only cache the model if keep_model_loaded is True
if keep_model_loaded:
TTS_MODEL_CACHE[cache_key] = model
print(f"ChatterboxTTS: Model '{model_pack_name}' cached for reuse.")
else:
print(f"ChatterboxTTS: Model '{model_pack_name}' loaded but not cached (keep_model_loaded=False).")

return model

def get_cached_chatterbox_tts_model_with_fallback(model_pack_name, device_str="cuda", keep_model_loaded=True):
"""Loads and caches the ChatterboxTTS model with automatic fallback to CPU on GPU errors."""
try:
return get_cached_chatterbox_tts_model(model_pack_name, device_str, keep_model_loaded)
except RuntimeError as e:
error_str = str(e)
if ("CUDA" in error_str or "MPS" in error_str) and device_str != "cpu":
print(f"ChatterboxTTS: GPU error detected, falling back to CPU: {e}")
# Unload any existing model on the failed device
unload_chatterbox_tts_model(model_pack_name, device_str)
return get_cached_chatterbox_tts_model(model_pack_name, "cpu", keep_model_loaded)
else:
raise


def load_chatterbox_vc_model(model_pack_name, device_str="cuda"):
@@ -174,7 +234,7 @@ def load_chatterbox_vc_model(model_pack_name, device_str="cuda"):
raise
return model

def get_cached_chatterbox_vc_model(model_pack_name, device_str="cuda"):
def get_cached_chatterbox_vc_model(model_pack_name, device_str="cuda", keep_model_loaded=True):
"""Loads and caches the ChatterboxVC model."""
if not model_pack_name:
available_packs = get_chatterbox_model_pack_names()
@@ -198,8 +258,31 @@ def get_cached_chatterbox_vc_model(model_pack_name, device_str="cuda"):
print(f"ChatterboxVC: Device mismatch for cached VC model '{model_pack_name}'. Reloading.")
else:
print(f"ChatterboxVC: VC Model for '{model_pack_name}' on '{device_str}' not in cache. Loading...")
VC_MODEL_CACHE[cache_key] = load_chatterbox_vc_model(model_pack_name, device_str)
return VC_MODEL_CACHE[cache_key]

model = load_chatterbox_vc_model(model_pack_name, device_str)

# Only cache the model if keep_model_loaded is True
if keep_model_loaded:
VC_MODEL_CACHE[cache_key] = model
print(f"ChatterboxVC: Model '{model_pack_name}' cached for reuse.")
else:
print(f"ChatterboxVC: Model '{model_pack_name}' loaded but not cached (keep_model_loaded=False).")

return model

def get_cached_chatterbox_vc_model_with_fallback(model_pack_name, device_str="cuda", keep_model_loaded=True):
"""Loads and caches the ChatterboxVC model with automatic fallback to CPU on GPU errors."""
try:
return get_cached_chatterbox_vc_model(model_pack_name, device_str, keep_model_loaded)
except RuntimeError as e:
error_str = str(e)
if ("CUDA" in error_str or "MPS" in error_str) and device_str != "cpu":
print(f"ChatterboxVC: GPU error detected, falling back to CPU: {e}")
# Unload any existing model on the failed device
unload_chatterbox_vc_model(model_pack_name, device_str)
return get_cached_chatterbox_vc_model(model_pack_name, "cpu", keep_model_loaded)
else:
raise


def set_chatterbox_seed(seed: int):
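For context, a minimal sketch of how the new handler API is meant to be driven (the import path and the call site are assumptions for illustration; the function names and signatures come from this diff). The VC helpers (get_cached_chatterbox_vc_model_with_fallback, unload_chatterbox_vc_model) mirror the same pattern under the "vc" cache key:

    from modules.chatterbox_handler import (
        get_cached_chatterbox_tts_model_with_fallback,
        unload_chatterbox_tts_model,
        unload_all_chatterbox_models,
    )

    # Ask for the model on CUDA; a CUDA/MPS RuntimeError during load makes the
    # wrapper evict any stale entry for that device and retry on CPU.
    model = get_cached_chatterbox_tts_model_with_fallback(
        "resembleai_default_voice",    # DEFAULT_MODEL_PACK_NAME from this module
        device_str="cuda",
        keep_model_loaded=True,        # cache under (pack, device, "tts") for reuse
    )

    # ... run synthesis with `model` ...

    # Evict one (pack, device) entry, or drop every cached model at once;
    # on a successful evict, clear_gpu_cache() releases CUDA/MPS memory.
    unload_chatterbox_tts_model("resembleai_default_voice", device_str="cuda")
    unload_all_chatterbox_models()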
25 changes: 19 additions & 6 deletions nodes.py
@@ -7,9 +7,12 @@

from .modules.chatterbox_handler import (
get_chatterbox_model_pack_names,
get_cached_chatterbox_tts_model,
get_cached_chatterbox_vc_model,
get_cached_chatterbox_tts_model_with_fallback,
get_cached_chatterbox_vc_model_with_fallback,
set_chatterbox_seed,
unload_chatterbox_tts_model,
unload_chatterbox_vc_model,
clear_gpu_cache,
CHATTERBOX_MODEL_SUBDIR,
DEFAULT_MODEL_PACK_NAME
)
@@ -34,6 +37,7 @@ def INPUT_TYPES(cls):
},
"optional": {
"audio_prompt": ("AUDIO",),
"keep_model_loaded": ("BOOLEAN", {"default": False}),
}
}

@@ -43,15 +47,15 @@ def INPUT_TYPES(cls):
CATEGORY = "audio/generation"
OUTPUT_NODE = True

def synthesize(self, model_pack_name, text, exaggeration, temperature, cfg_weight, seed, device, audio_prompt=None):
def synthesize(self, model_pack_name, text, exaggeration, temperature, cfg_weight, seed, device, audio_prompt=None, keep_model_loaded=False):
if not text.strip():
#print("Chatterbox TTS: Empty text provided, returning silent audio.")
dummy_sr = 24000
silent_waveform = torch.zeros((1, dummy_sr), dtype=torch.float32, device="cpu")
return ({"waveform": silent_waveform.unsqueeze(0), "sample_rate": dummy_sr},)

try:
chatterbox_model = get_cached_chatterbox_tts_model(model_pack_name, device_str=device)
chatterbox_model = get_cached_chatterbox_tts_model_with_fallback(model_pack_name, device_str=device, keep_model_loaded=keep_model_loaded)
except Exception as e:
print(f"ChatterboxTTS: Error loading/downloading TTS model pack '{model_pack_name}': {e}")
dummy_sr = 24000
@@ -105,6 +109,10 @@ def synthesize(self, model_pack_name, text, exaggeration, temperature, cfg_weigh
os.remove(audio_prompt_path_temp)
except Exception as e:
print(f"ChatterboxTTS: Error removing temp audio prompt file {audio_prompt_path_temp}: {e}")

# Unload model if keep_model_loaded is False
if not keep_model_loaded:
unload_chatterbox_tts_model(model_pack_name, device_str=device)

wav_tensor_comfy = wav_tensor_chatterbox.cpu().unsqueeze(0)
return ({"waveform": wav_tensor_comfy, "sample_rate": chatterbox_model.sr},)
@@ -126,6 +134,7 @@ def INPUT_TYPES(cls):
},
"optional": {
"target_voice_audio": ("AUDIO",), # Optional: if not provided, uses default voice from conds.pt
"keep_model_loaded": ("BOOLEAN", {"default": False}),
}
}

@@ -171,15 +180,15 @@ def _save_audio_to_temp_file(self, audio_data, prefix=""):
os.remove(temp_file_path)
return None

def convert_voice(self, model_pack_name, source_audio, device, target_voice_audio=None):
def convert_voice(self, model_pack_name, source_audio, device, target_voice_audio=None, keep_model_loaded=False):
if source_audio is None or source_audio.get("waveform") is None or source_audio["waveform"].numel() == 0:
print("ChatterboxVC: No source audio provided, returning silent audio.")
dummy_sr = 24000
silent_waveform = torch.zeros((1, dummy_sr), dtype=torch.float32, device="cpu")
return ({"waveform": silent_waveform.unsqueeze(0), "sample_rate": dummy_sr},)

try:
vc_model = get_cached_chatterbox_vc_model(model_pack_name, device_str=device)
vc_model = get_cached_chatterbox_vc_model_with_fallback(model_pack_name, device_str=device, keep_model_loaded=keep_model_loaded)
except Exception as e:
print(f"ChatterboxVC: Error loading/downloading VC model pack '{model_pack_name}': {e}")
dummy_sr = 24000
@@ -224,6 +233,10 @@ def convert_voice(self, model_pack_name, source_audio, device, target_voice_audi
os.remove(target_voice_path_temp)
except Exception as e:
print(f"ChatterboxVC: Error removing temp target audio file {target_voice_path_temp}: {e}")

# Unload model if keep_model_loaded is False
if not keep_model_loaded:
unload_chatterbox_vc_model(model_pack_name, device_str=device)

# ComfyUI AUDIO format: {"waveform": tensor (B, C, T), "sample_rate": int}
# ChatterboxVC output: tensor (1, T)
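Taken together, the node-level lifecycle introduced by this PR is: load with CPU fallback, run inference, then optionally evict. A rough sketch of that pattern, assuming a model object with a generate(text) method as in upstream Chatterbox (the helper name run_tts_once is hypothetical):

    from modules.chatterbox_handler import (
        get_cached_chatterbox_tts_model_with_fallback,
        unload_chatterbox_tts_model,
    )

    def run_tts_once(pack, text, device="cuda", keep_model_loaded=False):
        # Falls back to CPU automatically on CUDA/MPS RuntimeErrors.
        model = get_cached_chatterbox_tts_model_with_fallback(
            pack, device_str=device, keep_model_loaded=keep_model_loaded
        )
        try:
            wav = model.generate(text)  # assumed API, not part of this diff
        finally:
            if not keep_model_loaded:
                # Mirrors the node: free VRAM after each run. The cache lookup
                # is a no-op here (the model was never cached), but it keeps
                # the eviction path in one place.
                unload_chatterbox_tts_model(pack, device_str=device)
        return wav

With keep_model_loaded defaulting to False, repeated queue runs trade reload time for VRAM headroom between executions; setting it to True restores the previous always-cache behavior.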