Skip to content

Commit de44b95

Browse files
authored
add StabilityAudio API nodes (#9749)
1 parent 543888d commit de44b95

File tree

4 files changed

+415
-4
lines changed

4 files changed

+415
-4
lines changed

comfy_api_nodes/apinode_utils.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,71 @@ async def upload_audio_to_comfyapi(
518518
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
519519

520520

521+
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
522+
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
523+
if wav.dtype.is_floating_point:
524+
return wav
525+
elif wav.dtype == torch.int16:
526+
return wav.float() / (2 ** 15)
527+
elif wav.dtype == torch.int32:
528+
return wav.float() / (2 ** 31)
529+
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
530+
531+
532+
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
533+
"""
534+
Decode any common audio container from bytes using PyAV and return
535+
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
536+
"""
537+
with av.open(io.BytesIO(audio_bytes)) as af:
538+
if not af.streams.audio:
539+
raise ValueError("No audio stream found in response.")
540+
stream = af.streams.audio[0]
541+
542+
in_sr = int(stream.codec_context.sample_rate)
543+
out_sr = in_sr
544+
545+
frames: list[torch.Tensor] = []
546+
n_channels = stream.channels or 1
547+
548+
for frame in af.decode(streams=stream.index):
549+
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
550+
buf = torch.from_numpy(arr)
551+
if buf.ndim == 1:
552+
buf = buf.unsqueeze(0) # [T] -> [1, T]
553+
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
554+
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
555+
elif buf.shape[0] != n_channels:
556+
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
557+
frames.append(buf)
558+
559+
if not frames:
560+
raise ValueError("Decoded zero audio frames.")
561+
562+
wav = torch.cat(frames, dim=1) # [C, T]
563+
wav = f32_pcm(wav)
564+
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
565+
566+
567+
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
568+
waveform = audio["waveform"].cpu()
569+
570+
output_buffer = io.BytesIO()
571+
output_container = av.open(output_buffer, mode='w', format="mp3")
572+
573+
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
574+
out_stream.bit_rate = 320000
575+
576+
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
577+
frame.sample_rate = audio["sample_rate"]
578+
frame.pts = 0
579+
output_container.mux(out_stream.encode(frame))
580+
output_container.mux(out_stream.encode(None))
581+
output_container.close()
582+
output_buffer.seek(0)
583+
return output_buffer
584+
585+
521586
def audio_to_base64_string(
522587
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
523588
) -> str:

comfy_api_nodes/apis/stability_api.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel):
125125

126126
class StabilityAsyncResponse(BaseModel):
127127
id: Optional[str] = Field(None)
128+
129+
130+
class StabilityTextToAudioRequest(BaseModel):
131+
model: str = Field(...)
132+
prompt: str = Field(...)
133+
duration: int = Field(190, ge=1, le=190)
134+
seed: int = Field(0, ge=0, le=4294967294)
135+
steps: int = Field(8, ge=4, le=8)
136+
output_format: str = Field("wav")
137+
138+
139+
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
140+
strength: float = Field(0.01, ge=0.01, le=1.0)
141+
142+
143+
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
144+
mask_start: int = Field(30, ge=0, le=190)
145+
mask_end: int = Field(190, ge=0, le=190)
146+
147+
148+
class StabilityAudioResponse(BaseModel):
149+
audio: Optional[str] = Field(None)

0 commit comments

Comments
 (0)