diff --git a/src/scope/core/nodes/__init__.py b/src/scope/core/nodes/__init__.py index 993af53b1..22fe9c951 100644 --- a/src/scope/core/nodes/__init__.py +++ b/src/scope/core/nodes/__init__.py @@ -11,16 +11,18 @@ """ from .base import BaseNode, Node, NodeDefinition, NodeParam, NodePort, Requirements -from .builtins import SchedulerNode +from .builtins import AudioSourceNode, SchedulerNode from .registry import NodeRegistry def register_builtin_nodes() -> None: """Register all built-in node types shipped with the foundation.""" NodeRegistry.register(SchedulerNode) + NodeRegistry.register(AudioSourceNode) __all__ = [ + "AudioSourceNode", "BaseNode", "Node", "NodeDefinition", diff --git a/src/scope/core/nodes/builtins/__init__.py b/src/scope/core/nodes/builtins/__init__.py index eb5d79632..eec3bf677 100644 --- a/src/scope/core/nodes/builtins/__init__.py +++ b/src/scope/core/nodes/builtins/__init__.py @@ -1,5 +1,6 @@ """Built-in nodes shipped with the foundation abstraction.""" +from .audio_io import AudioSourceNode from .scheduler import SchedulerNode -__all__ = ["SchedulerNode"] +__all__ = ["AudioSourceNode", "SchedulerNode"] diff --git a/src/scope/core/nodes/builtins/audio_io.py b/src/scope/core/nodes/builtins/audio_io.py new file mode 100644 index 000000000..4a4e4da84 --- /dev/null +++ b/src/scope/core/nodes/builtins/audio_io.py @@ -0,0 +1,222 @@ +"""Built-in audio I/O nodes: AudioSource (load a WAV file once). + +Terminal audio output is handled by the regular Sink node: audio edges +into a Sink are routed straight to the WebRTC audio track via the +session's audio_output_queue, with no intermediate node needed. 
+""" + +from __future__ import annotations + +import logging +import os +import struct +from pathlib import Path +from typing import Any, ClassVar + +import numpy as np +import torch + +from ..base import BaseNode, NodeDefinition, NodeParam, NodePort + +logger = logging.getLogger(__name__) + +SAMPLE_RATE = 48000 + + +def _read_wav_float32(path: str) -> tuple[np.ndarray, int]: + """Parse a WAV file into float32 samples without the stdlib ``wave`` + module, which rejects IEEE-float (format 3) files. + + Returns (data, sample_rate) where ``data`` has shape (samples, channels). + Supports formats 1 (PCM int) and 3 (IEEE float) — the two common cases. + WAVE_FORMAT_EXTENSIBLE (0xFFFE) is unwrapped to its underlying format. + """ + with open(path, "rb") as f: + header = f.read(12) + if len(header) < 12 or header[:4] != b"RIFF" or header[8:12] != b"WAVE": + raise ValueError(f"Not a WAV file: {path}") + + fmt_code: int | None = None + n_channels = 1 + sample_rate = 0 + bits_per_sample = 0 + pcm_bytes = b"" + + while True: + chunk_header = f.read(8) + if len(chunk_header) < 8: + break + chunk_id, chunk_size = struct.unpack("<4sI", chunk_header) + chunk_data = f.read(chunk_size) + if chunk_size % 2 == 1: + f.read(1) # RIFF pads odd-sized chunks + + if chunk_id == b"fmt " and len(chunk_data) >= 16: + ( + fmt_code, + n_channels, + sample_rate, + _byte_rate, + _block_align, + bits_per_sample, + ) = struct.unpack("= 26: + fmt_code = struct.unpack(" NodeDefinition: + return NodeDefinition( + node_type_id=cls.node_type_id, + display_name="Audio Source", + category="audio", + description="Load audio from a WAV file at 48kHz stereo.", + # continuous=False so NodeProcessor only re-runs us when a + # parameter actually changes; otherwise the worker would call + # execute() every tick and either flood the graph (on success) + # or flood the log (on missing-file). 
def _load_audio(self, file_path: str, duration: float) -> None:
    """Decode ``file_path`` and cache it as 48 kHz stereo audio.

    The decoded clip is normalised to exactly two channels (mono is
    duplicated, channels beyond the second are dropped), linearly
    resampled to ``SAMPLE_RATE`` when the file uses a different rate, and
    truncated to at most ``duration`` seconds. The cache attributes
    (``_audio_data``, ``_loaded_file``, ``_loaded_duration``) are only
    written after every step succeeded, so a failed load leaves the
    previous clip in place.

    Args:
        file_path: Path of the WAV file to decode.
        duration: Maximum clip length in seconds. Zero or negative values
            yield an empty clip.

    Raises:
        ValueError: If the decoder returns no channels. Errors raised by
            ``_read_wav_float32`` (e.g. ``ValueError`` for a non-WAV
            file) propagate unchanged.
    """
    data, sr = _read_wav_float32(file_path)  # (samples, channels)

    if data.ndim != 2 or data.shape[1] == 0:
        raise ValueError(f"No audio channels decoded from {file_path}")

    # Normalise the channel layout to stereo.
    if data.shape[1] == 1:
        data = np.concatenate([data, data], axis=1)
    elif data.shape[1] > 2:
        data = data[:, :2]
    data = data.T  # (channels, samples)

    num_samples = data.shape[1]
    # Skip resampling for an empty clip: np.interp rejects empty sample
    # points, and there is nothing to interpolate anyway.
    if sr != SAMPLE_RATE and sr > 0 and num_samples > 0:
        new_len = int(num_samples * SAMPLE_RATE / sr)
        source_positions = np.arange(num_samples)  # shared by all channels
        target_positions = np.linspace(0, num_samples - 1, new_len)
        resampled = np.zeros((data.shape[0], new_len), dtype=np.float32)
        for ch, channel in enumerate(data):
            resampled[ch] = np.interp(target_positions, source_positions, channel)
        data = resampled

    # Clamp at zero: a negative sample count would otherwise act as a
    # "drop the last N samples" slice instead of producing an empty clip.
    max_samples = max(0, int(duration * SAMPLE_RATE))
    if data.shape[1] > max_samples:
        data = data[:, :max_samples]

    self._audio_data = data
    self._loaded_file = file_path
    self._loaded_duration = duration
    logger.info(
        "AudioSource loaded: %s (%.1fs)",
        file_path,
        data.shape[1] / SAMPLE_RATE,
    )
@staticmethod
def _resolve_path(file_id: str) -> str | None:
    """Resolve ``file_id`` to the path of an existing audio file.

    ``file_id`` is tried as given first (absolute, or relative to the
    current working directory) and then as a name inside
    ``~/.daydream-scope/assets``. Only regular files are accepted, so a
    directory that happens to match is reported as "not found" here
    instead of failing later inside the decoder.

    Args:
        file_id: File path or asset name supplied through the node's
            ``file_id`` parameter.

    Returns:
        The absolute path for a direct match, the path under the assets
        directory for a fallback match, or ``None`` (after logging a
        warning) when neither is a regular file.
    """
    if os.path.isfile(file_id):
        return os.path.abspath(file_id)
    candidate = Path.home() / ".daydream-scope" / "assets" / file_id
    if candidate.is_file():
        return str(candidate)
    logger.warning("AudioSource: file not found: %s", file_id)
    return None
+ self.audio_output_queue: queue.Queue = queue.Queue(maxsize=1) self.worker_thread: threading.Thread | None = None self.shutdown_event = threading.Event() @@ -56,6 +64,10 @@ def __init__( self._source_executed = False self._has_executed = False self._continuous = definition.continuous + # Latch of last-seen inputs per port, so static upstreams (one-shot + # model/vae/clip handles) survive across param-triggered re-runs. + self._last_inputs: dict[str, Any] = {} + self._needs_rerun = False # PipelineProcessor interface compatibility: graph_executor populates # this for every processor; kept as an empty dict so that write is safe. @@ -91,7 +103,15 @@ def stop(self) -> None: logger.info("NodeProcessor stopped: %s", self.node_id) def update_parameters(self, parameters: dict[str, Any]) -> None: + # FrameProcessor broadcasts node-less updates to every processor; + # only mark ourselves dirty when a value we actually declare changes. + changed = any( + key in self._declared_param_names and self.parameters.get(key) != value + for key, value in parameters.items() + ) self.parameters.update(parameters) + if changed: + self._needs_rerun = True def set_beat_cache_reset_rate(self, rate): # PipelineProcessor compat pass @@ -124,31 +144,40 @@ def _process_once(self) -> None: is_source_node = not all_queues # Source nodes execute once; continuous=True nodes re-execute every - # tick (for streaming I/O). - if is_source_node and self._source_executed and not self._continuous: + # tick (for streaming I/O). A pending parameter change also re-wakes. + if ( + is_source_node + and self._source_executed + and not self._continuous + and not self._needs_rerun + ): self.shutdown_event.wait(1.0) return - # Gather inputs. Continuous nodes consume whatever's available - # (empty inputs stay absent). Non-continuous nodes wait until every - # input queue has data, so they execute with a complete input set. 
+ # Drain fresh values into the latch cache; ports whose upstream has + # already gone quiet (e.g. one-shot model handles) replay from cache. + fresh: dict[str, Any] = {} inputs: dict[str, Any] = {} if all_queues: - if self._continuous: - for port_name, q in all_queues.items(): - try: - inputs[port_name] = q.get_nowait() - except queue.Empty: - pass - else: - if any(q.empty() for q in all_queues.values()): - self.shutdown_event.wait(SLEEP_TIME) - return - inputs = {name: q.get_nowait() for name, q in all_queues.items()} - - # Non-continuous nodes skip re-execution when no new inputs arrived - # and they already have a cached output. - if self._has_executed and not inputs and not self._continuous: + for port_name, q in all_queues.items(): + try: + fresh[port_name] = q.get_nowait() + except queue.Empty: + pass + self._last_inputs.update(fresh) + inputs = dict(self._last_inputs) + # First run: wait until every port has been seen at least once. + if not self._has_executed and set(all_queues.keys()) - inputs.keys(): + self.shutdown_event.wait(SLEEP_TIME) + return + + # Non-continuous nodes skip when nothing changed since last run. + if ( + self._has_executed + and not self._continuous + and not fresh + and not self._needs_rerun + ): self.shutdown_event.wait(SLEEP_TIME) return @@ -156,6 +185,7 @@ def _process_once(self) -> None: if is_source_node: self._source_executed = True + self._needs_rerun = False if not outputs: self.shutdown_event.wait(SLEEP_TIME) @@ -169,6 +199,10 @@ def _route_outputs(self, outputs: dict[str, Any]) -> None: if value is None: continue + # Sink-bound audio also goes to audio_output_queue for WebRTC. + if port_name in self.audio_sink_ports: + self._route_audio(value) + # Fan out to all downstream queues on this port. Block briefly # when queues are full so producers throttle to consumer pace # and GPU tensors don't pile up in memory. 
@@ -181,3 +215,21 @@ def _route_outputs(self, outputs: dict[str, Any]) -> None: break except queue.Full: continue + + def _route_audio(self, value: Any) -> None: + """Convert a node output value and push it to audio_output_queue.""" + # Lazy import keeps ``scope.core`` from reaching back into + # ``scope.server`` at module load (disallowed by the project layout). + from scope.server.media_packets import audio_packet_from_node_output + + packet = audio_packet_from_node_output(value) + if packet is None: + return + # Blocking put with retry: stalls the worker when the audio track + # hasn't drained the previous chunk — this is the backpressure. + while not self.shutdown_event.is_set(): + try: + self.audio_output_queue.put(packet, timeout=0.1) + break + except queue.Full: + continue diff --git a/src/scope/server/audio_track.py b/src/scope/server/audio_track.py index 0568e477f..81730a2cd 100644 --- a/src/scope/server/audio_track.py +++ b/src/scope/server/audio_track.py @@ -25,6 +25,25 @@ AUDIO_MAX_BUFFER_SAMPLES = AUDIO_CLOCK_RATE * 60 +# Playhead handle for graph nodes that need the current audio position +# (e.g. DEMON's StreamVAEDecode skip gate, which mirrors the realtime +# demo's ``audio_eng.position / SAMPLE_RATE``). Scope serves one session +# at a time, so a single optional reference is enough — each new +# AudioProcessingTrack overwrites it during __init__. +_current_track: "AudioProcessingTrack | None" = None + + +def get_current_playhead_seconds() -> float | None: + """Playhead of the live audio track in seconds, or None if none is live. + + Callers should treat None as "skip gate disabled this tick". + """ + track = _current_track + if track is None or track.readyState != "live": + return None + return track.playhead_seconds + + class AudioProcessingTrack(MediaStreamTrack): """WebRTC audio track that streams generated audio from the pipeline. 
@@ -56,6 +75,24 @@ def __init__( self._start: float | None = None self._timestamp: int = 0 self._last_preserved_pts: int | None = None + # Per-channel sample index where the next contiguous chunk is + # expected to begin. Used to detect and trim overlap when a + # streaming decoder (e.g. ACEStep StreamVAEDecode with + # follow_playhead) emits windows that overlap in time. None + # means "no PTS reference yet" — the next valid PTS sets it. + self._next_expected_pts: int | None = None + + global _current_track + _current_track = self + + @property + def playhead_seconds(self) -> float: + """Current playback position in seconds (monotonic recv timestamp). + + Mirrors DEMON's ``audio_eng.position / SAMPLE_RATE`` read. Value + is 0 before the first ``recv`` call. + """ + return self._timestamp / AUDIO_CLOCK_RATE @staticmethod def _resample_audio( @@ -224,8 +261,29 @@ def _drain_audio_packets(self) -> None: audio_packet.timestamp.time_base ) chunk_pts = int(round(media_ts * AUDIO_CLOCK_RATE)) + + # Overlap trim: when a streaming decoder emits windows whose + # start_sample retreats into already-buffered territory (e.g. + # ACEStep StreamVAEDecode with vae_overlap > 0 + follow_playhead), + # drop the duplicate prefix so the listener hears each sample + # exactly once. Without this the overlap region plays twice. + if ( + chunk_pts is not None + and self._next_expected_pts is not None + and chunk_pts < self._next_expected_pts + ): + overlap_per_ch = self._next_expected_pts - chunk_pts + trim_samples = overlap_per_ch * self.channels + if trim_samples >= len(interleaved): + # Entire chunk lies in the past — drop it. + continue + interleaved = interleaved[trim_samples:] + chunk_pts = self._next_expected_pts + self._chunks.append((interleaved, chunk_pts)) self._buffered_samples += len(interleaved) + if chunk_pts is not None: + self._next_expected_pts = chunk_pts + len(interleaved) // self.channels def _trim_buffer(self) -> None: # Cap buffer to prevent unbounded growth. 
def get_sink_modalities(self) -> tuple[bool, bool]:
    """Report which media kinds reach a sink as ``(has_video, has_audio)``.

    Only ``"stream"`` edges whose target is a sink node are considered.
    Such an edge counts as audio when either of its ports is named
    ``"audio"``; every other stream edge into a sink counts as video.
    A graph without sink nodes yields ``(False, False)``.
    """
    sinks = set(self.get_sink_node_ids())
    if not sinks:
        return (False, False)

    # One flag per stream edge that ends in a sink: True means audio.
    audio_flags = [
        "audio" in (edge.from_port, edge.to_port)
        for edge in self.edges
        if edge.kind == "stream" and edge.to_node in sinks
    ]
    video_seen = any(not is_audio for is_audio in audio_flags)
    audio_seen = any(audio_flags)
    return (video_seen, audio_seen)
self._closed: + if self._closed or not self._expect_video: return import av @@ -135,7 +137,17 @@ def on_audio_chunk( import numpy as np with self._lock: - if self._closed or not self._initialized or self._audio_stream is None: + if self._closed: + return + if not self._initialized: + # Audio-only graphs (e.g. DEMON music covers) never deliver + # a video frame to bootstrap the container, so init from + # the first audio chunk. + if not self._expect_video: + self._init_container(0, 0) + else: + return + if self._audio_stream is None: return audio_np = audio_tensor.numpy() if audio_np.ndim == 1: @@ -330,11 +342,13 @@ def __init__( self, frame_processor: "FrameProcessor", expect_audio: bool = False, + expect_video: bool = True, ): from .frame_processor import FrameProcessor self.frame_processor: FrameProcessor = frame_processor self.expect_audio = expect_audio + self.expect_video = expect_video # In graph mode this tracks the most recently consumed frame across all # sink queues, not a canonical sink. Callers that need stable per-sink # capture should pass sink_node_id to get_last_frame(). 
@@ -452,7 +466,10 @@ def remove_media_sink(self, sink: HeadlessMediaSink) -> None: self._media_sinks.remove(sink) def create_ts_streamer(self) -> HeadlessTsStreamer: - streamer = HeadlessTsStreamer(expect_audio=self.expect_audio) + streamer = HeadlessTsStreamer( + expect_audio=self.expect_audio, + expect_video=self.expect_video, + ) self.add_media_sink(streamer) return streamer diff --git a/src/scope/server/mcp_router.py b/src/scope/server/mcp_router.py index e0286293e..2a426593f 100644 --- a/src/scope/server/mcp_router.py +++ b/src/scope/server/mcp_router.py @@ -333,9 +333,25 @@ async def start_stream( detail="FrameProcessor failed to start (check logs for details)", ) + expect_audio = NodeRegistry.chain_produces_audio(pipeline_id_list) + expect_video = True + # Graph-only audio: when no config-driven pipeline declares audio + # but the graph itself carries an audio edge into a sink (e.g. a + # DEMON node graph), the session still expects audio. When the + # graph carries audio but no video into any sink, the headless TS + # streamer must skip the video track entirely. 
def audio_packet_from_node_output(value: Any) -> AudioPacket | None:
    """Build an ``AudioPacket`` from a graph node's audio output port value.

    Accepts the two shapes that built-in and plugin nodes emit:
      * ``(tensor, sample_rate)`` tuples (e.g. ``AudioSourceNode``).
      * Objects with ``.waveform`` / ``.sample_rate`` / optional
        ``.start_sample`` attributes (e.g. ACEStep ``StreamVAEDecode``,
        whose ``start_sample`` is forwarded as PTS so
        ``AudioProcessingTrack`` can trim overlapping windows downstream).

    Tensors are normalised so a later ``.numpy()`` call works: detached
    when they require grad, moved to CPU from any other device, stripped
    of a leading ``(1, C, T)`` batch dimension, and upcast from
    bfloat16/float16 to float32. A tensor that needs none of these steps
    is passed through as the same object.

    Args:
        value: The raw value a node placed on an audio output port.

    Returns:
        The packet, or ``None`` when no audio tensor can be extracted.
        A missing or ``None`` sample rate falls back to 48000.
    """
    default_sample_rate = 48000
    start_sample: int | None = None
    if isinstance(value, tuple) and len(value) == 2:
        audio_tensor, audio_sr = value
    else:
        audio_tensor = getattr(value, "waveform", None)
        audio_sr = getattr(value, "sample_rate", default_sample_rate)
        start_sample = getattr(value, "start_sample", None)

    if audio_tensor is None:
        return None
    if audio_sr is None:
        # Same default as a missing attribute; int(None) would raise below.
        audio_sr = default_sample_rate

    if isinstance(audio_tensor, torch.Tensor):
        # Tensor.numpy() rejects tensors that require grad on every device,
        # not only CUDA, so detach whenever autograd is tracking the value.
        if audio_tensor.requires_grad:
            audio_tensor = audio_tensor.detach()
        # Covers CUDA as well as other accelerator backends.
        if audio_tensor.device.type != "cpu":
            audio_tensor = audio_tensor.cpu()
        if audio_tensor.dim() == 3 and audio_tensor.shape[0] == 1:
            audio_tensor = audio_tensor.squeeze(0)
        if audio_tensor.dtype in (torch.bfloat16, torch.float16):
            audio_tensor = audio_tensor.float()

    timestamp = (
        MediaTimestamp(pts=int(start_sample), time_base=Fraction(1, int(audio_sr)))
        if start_sample is not None and audio_sr
        else MediaTimestamp()
    )
    return AudioPacket(
        audio=audio_tensor, sample_rate=int(audio_sr), timestamp=timestamp
    )
When a graph is present, add anything the graph declares + # via sink edges, and drop video only if the graph is explicitly + # audio-only — that's the DEMON case where ``pipeline_ids`` can + # be a stale ``longlive`` from a previous workflow load. pipeline_ids = initial_parameters.get("pipeline_ids", []) produces_video = NodeRegistry.chain_produces_video(pipeline_ids) produces_audio = NodeRegistry.chain_produces_audio(pipeline_ids) + graph_data = initial_parameters.get("graph") + graph_config = ( + GraphConfig.model_validate(graph_data) + if isinstance(graph_data, dict) + else None + ) + if graph_config is not None and graph_config.get_sink_node_ids(): + graph_video, graph_audio = graph_config.get_sink_modalities() + produces_audio = produces_audio or graph_audio + if graph_audio and not graph_video: + produces_video = False + else: + produces_video = produces_video or graph_video # Parse graph from initial parameters to find sink/source/record node IDs ( diff --git a/tests/test_audio_packets.py b/tests/test_audio_packets.py index 6ed87a1b1..eaa164b75 100644 --- a/tests/test_audio_packets.py +++ b/tests/test_audio_packets.py @@ -8,7 +8,11 @@ from scope.server.cloud_relay import CloudRelay from scope.server.frame_processor import FrameProcessor -from scope.server.media_packets import AudioPacket, MediaTimestamp +from scope.server.media_packets import ( + AudioPacket, + MediaTimestamp, + audio_packet_from_node_output, +) def _make_frame_processor_with_audio_queue(items): @@ -82,3 +86,34 @@ def remove_audio_callback(self, callback): # pragma: no cover - unused in test assert packet is not None assert packet.sample_rate == 48_000 assert packet.timestamp == MediaTimestamp(pts=321, time_base=Fraction(1, 48_000)) + + +def test_audio_packet_from_tuple(): + audio = torch.ones((2, 16)) + packet = audio_packet_from_node_output((audio, 48_000)) + assert packet == AudioPacket(audio=audio, sample_rate=48_000) + + +def test_audio_packet_from_waveform_object_carries_pts(): + 
"""``start_sample`` flows through as PTS so AudioProcessingTrack can trim.""" + audio = torch.zeros((2, 32)) + value = SimpleNamespace(waveform=audio, sample_rate=48_000, start_sample=500) + packet = audio_packet_from_node_output(value) + assert packet is not None + assert packet.sample_rate == 48_000 + assert packet.timestamp == MediaTimestamp(pts=500, time_base=Fraction(1, 48_000)) + + +def test_audio_packet_drops_cuda_batch_dim_and_upcasts_bfloat16(): + # Skip CUDA-only steps; cover shape + dtype normalization (bf16 → fp32 + + # ``(1, C, T)`` squeeze) which AudioProcessingTrack.numpy() requires. + audio = torch.zeros((1, 2, 16), dtype=torch.bfloat16) + packet = audio_packet_from_node_output((audio, 48_000)) + assert packet is not None + assert packet.audio.shape == (2, 16) + assert packet.audio.dtype == torch.float32 + + +def test_audio_packet_from_missing_waveform_is_none(): + value = SimpleNamespace(something_else=True) + assert audio_packet_from_node_output(value) is None diff --git a/tests/test_audio_processing_track.py b/tests/test_audio_processing_track.py index 9e6b14ae2..ccea1471a 100644 --- a/tests/test_audio_processing_track.py +++ b/tests/test_audio_processing_track.py @@ -362,3 +362,83 @@ def test_caps_buffer_at_max(self): # After cap-trimming and consuming one 20ms frame, must be under the cap assert track._buffered_samples <= max_interleaved + + +# --------------------------------------------------------------------------- +# Overlap-trim — streaming VAE decoders (ACEStep StreamVAEDecode + +# follow_playhead) re-emit windows whose start_sample retreats into already +# buffered territory; the track must drop the duplicate prefix. 
+# --------------------------------------------------------------------------- + + +class TestOverlapTrim: + def _packet(self, start_sample: int, n_samples: int) -> AudioPacket: + return AudioPacket( + audio=torch.ones(2, n_samples), + sample_rate=48000, + timestamp=MediaTimestamp( + pts=start_sample, time_base=fractions.Fraction(1, 48000) + ), + ) + + def test_trims_partial_overlap(self): + """Second window starts 100 samples before the first one ended; + those 100 samples (× 2 channels) are dropped from the front.""" + track = _make_track() + first = self._packet(start_sample=0, n_samples=500) + # Next chunk overlaps last 100 samples of `first`. + second = self._packet(start_sample=400, n_samples=500) + + chunks = iter([first, second, None]) + track.frame_processor.get_audio_packet = MagicMock( + side_effect=lambda: next(chunks, None) + ) + track._drain_audio_packets() + + # First chunk: 500 × 2 = 1000 samples, pts 0. + # Second chunk: trimmed by 100 samples × 2 channels = 200 → 800 samples, + # pts advances to 500 (the next expected sample after the first chunk). + assert len(track._chunks) == 2 + first_buf, first_pts = track._chunks[0] + second_buf, second_pts = track._chunks[1] + assert len(first_buf) == 1000 + assert first_pts == 0 + assert len(second_buf) == 800 + assert second_pts == 500 + assert track._buffered_samples == 1800 + + def test_drops_chunk_entirely_in_past(self): + """Window whose end lies before the current expected PTS is dropped.""" + track = _make_track() + first = self._packet(start_sample=0, n_samples=500) + # Entirely behind the playhead: start=0, end=100 < expected 500. + stale = self._packet(start_sample=0, n_samples=100) + + chunks = iter([first, stale, None]) + track.frame_processor.get_audio_packet = MagicMock( + side_effect=lambda: next(chunks, None) + ) + track._drain_audio_packets() + + assert len(track._chunks) == 1 + assert track._chunks[0][1] == 0 + # _next_expected_pts is still the end of the first chunk: 500. 
+ assert track._next_expected_pts == 500 + + def test_contiguous_chunks_not_trimmed(self): + """Back-to-back windows (no overlap) flow through untouched.""" + track = _make_track() + first = self._packet(start_sample=0, n_samples=500) + second = self._packet(start_sample=500, n_samples=500) + + chunks = iter([first, second, None]) + track.frame_processor.get_audio_packet = MagicMock( + side_effect=lambda: next(chunks, None) + ) + track._drain_audio_packets() + + assert len(track._chunks) == 2 + assert len(track._chunks[0][0]) == 1000 + assert len(track._chunks[1][0]) == 1000 + assert track._chunks[1][1] == 500 + assert track._next_expected_pts == 1000 diff --git a/tests/test_audio_source_node.py b/tests/test_audio_source_node.py new file mode 100644 index 000000000..c113aa757 --- /dev/null +++ b/tests/test_audio_source_node.py @@ -0,0 +1,127 @@ +"""Tests for the built-in AudioSource node and its WAV decoder.""" + +from __future__ import annotations + +import struct +import wave +from pathlib import Path + +import numpy as np +import pytest + +from scope.core.nodes.builtins.audio_io import ( + SAMPLE_RATE, + AudioSourceNode, + _read_wav_float32, +) + + +def _write_pcm16(path: Path, samples: np.ndarray, sample_rate: int) -> None: + """Write float samples to a 16-bit PCM WAV via stdlib ``wave``.""" + pcm = np.clip(samples, -1.0, 1.0) + pcm = (pcm * 32767.0).astype(" None: + """Hand-roll an IEEE-float32 WAV (stdlib ``wave`` rejects format 3).""" + if samples.ndim == 1: + samples = samples[:, None] + n_channels = samples.shape[1] + data = samples.astype(" None: + sr = 16000 + samples = np.random.uniform(-0.5, 0.5, (sr, 2)).astype(np.float32) + wav = tmp_path / "pcm.wav" + _write_pcm16(wav, samples, sr) + + data, decoded_sr = _read_wav_float32(str(wav)) + assert decoded_sr == sr + assert data.shape == (sr, 2) + # 16-bit quantisation noise at most ~1/32767. 
+ assert np.max(np.abs(data - samples)) < 1.5 / 32767 + + +def test_read_wav_float32_mono(tmp_path: Path) -> None: + sr = 22050 + samples = np.random.uniform(-0.9, 0.9, sr).astype(np.float32) + wav = tmp_path / "float.wav" + _write_float32(wav, samples, sr) + + data, decoded_sr = _read_wav_float32(str(wav)) + assert decoded_sr == sr + assert data.shape == (sr, 1) + np.testing.assert_allclose(data[:, 0], samples, atol=1e-6) + + +def test_read_wav_rejects_non_wav(tmp_path: Path) -> None: + bad = tmp_path / "bad.wav" + bad.write_bytes(b"NOT A WAV FILE") + with pytest.raises(ValueError): + _read_wav_float32(str(bad)) + + +def test_audio_source_emits_full_clip(tmp_path: Path) -> None: + """``execute`` returns the entire decoded clip as a (channels, samples) tensor.""" + sr = SAMPLE_RATE + duration_s = 1.0 + samples = np.random.uniform(-0.3, 0.3, (int(sr * duration_s), 2)).astype(np.float32) + wav = tmp_path / "clip.wav" + _write_pcm16(wav, samples, sr) + + node = AudioSourceNode(node_id="audio_src") + out = node.execute({}, file_id=str(wav), duration=duration_s) + + assert "audio" in out + tensor, out_sr = out["audio"] + assert out_sr == SAMPLE_RATE + assert tensor.shape == (2, int(sr * duration_s)) + + +def test_audio_source_missing_file_is_silent(tmp_path: Path) -> None: + """Missing ``file_id`` returns ``{}`` instead of raising.""" + node = AudioSourceNode(node_id="audio_src") + assert node.execute({}) == {} + assert node.execute({}, file_id=str(tmp_path / "does_not_exist.wav")) == {} + + +def test_audio_source_duration_change_retrims(tmp_path: Path) -> None: + """Changing only the duration must re-trim — the cache key includes it.""" + sr = SAMPLE_RATE + samples = np.random.uniform(-0.3, 0.3, (sr * 4, 2)).astype(np.float32) + wav = tmp_path / "clip.wav" + _write_pcm16(wav, samples, sr) + + node = AudioSourceNode(node_id="audio_src") + out_long = node.execute({}, file_id=str(wav), duration=4.0) + assert out_long["audio"][0].shape[1] == 4 * sr + + out_short = 
node.execute({}, file_id=str(wav), duration=1.0) + assert out_short["audio"][0].shape[1] == 1 * sr diff --git a/tests/test_graph_sink_modalities.py b/tests/test_graph_sink_modalities.py new file mode 100644 index 000000000..25eceee21 --- /dev/null +++ b/tests/test_graph_sink_modalities.py @@ -0,0 +1,142 @@ +"""Tests for ``GraphConfig.get_sink_modalities`` — the consolidated check +used by webrtc/mcp_router to decide what tracks a graph emits. +""" + +from __future__ import annotations + +from scope.server.graph_schema import GraphConfig + + +def _build(nodes: list[dict], edges: list[dict]) -> GraphConfig: + return GraphConfig.model_validate({"nodes": nodes, "edges": edges}) + + +def test_video_only_graph() -> None: + g = _build( + nodes=[ + {"id": "in", "type": "source"}, + {"id": "p", "type": "pipeline", "pipeline_id": "passthrough"}, + {"id": "out", "type": "sink"}, + ], + edges=[ + { + "from": "in", + "from_port": "video", + "to_node": "p", + "to_port": "video", + "kind": "stream", + }, + { + "from": "p", + "from_port": "video", + "to_node": "out", + "to_port": "video", + "kind": "stream", + }, + ], + ) + assert g.get_sink_modalities() == (True, False) + + +def test_audio_only_graph() -> None: + g = _build( + nodes=[ + {"id": "src", "type": "source"}, + {"id": "out", "type": "sink"}, + ], + edges=[ + { + "from": "src", + "from_port": "audio", + "to_node": "out", + "to_port": "audio", + "kind": "stream", + } + ], + ) + assert g.get_sink_modalities() == (False, True) + + +def test_mixed_video_and_audio_sinks() -> None: + g = _build( + nodes=[ + {"id": "src", "type": "source"}, + {"id": "video_out", "type": "sink"}, + {"id": "audio_out", "type": "sink"}, + ], + edges=[ + { + "from": "src", + "from_port": "video", + "to_node": "video_out", + "to_port": "video", + "kind": "stream", + }, + { + "from": "src", + "from_port": "audio", + "to_node": "audio_out", + "to_port": "audio", + "kind": "stream", + }, + ], + ) + assert g.get_sink_modalities() == (True, True) + + 
+def test_no_sink_returns_false_false() -> None:
+    """A graph with a lone source node and no edges reports neither modality."""
+    nodes = [{"id": "src", "type": "source"}]
+    graph = _build(nodes=nodes, edges=[])
+
+    modalities = graph.get_sink_modalities()
+    assert modalities == (False, False)
+
+
+def test_non_stream_edges_ignored() -> None:
+    """A "parameter" edge into a sink must not count toward either modality.
+
+    Only "stream" edges count toward sink modalities; "parameter" edges
+    carry control data and shouldn't promote a sink to video/audio.
+    """
+    nodes = [
+        {"id": "src", "type": "source"},
+        {"id": "out", "type": "sink"},
+    ]
+    # Shaped like a video edge into the sink, but of kind "parameter".
+    parameter_edge = {
+        "from": "src",
+        "from_port": "video",
+        "to_node": "out",
+        "to_port": "video",
+        "kind": "parameter",
+    }
+    graph = _build(nodes=nodes, edges=[parameter_edge])
+
+    modalities = graph.get_sink_modalities()
+    assert modalities == (False, False)
+
+
+def test_edges_not_targeting_sink_ignored() -> None:
+    """Only edges that end at a sink decide the reported modalities.
+
+    An audio edge between two non-sink nodes must not mark the graph as
+    audio-producing; here the only edge into the sink carries video.
+    """
+    nodes = [
+        {"id": "src", "type": "source"},
+        {"id": "encoder", "type": "node", "node_type_id": "audio.Encode"},
+        {"id": "out", "type": "sink"},
+    ]
+    # Audio stream that stops at the intermediate node, never reaching a sink.
+    audio_into_encoder = {
+        "from": "src",
+        "from_port": "audio",
+        "to_node": "encoder",
+        "to_port": "audio",
+        "kind": "stream",
+    }
+    # The single edge that terminates at the sink is a video stream.
+    video_into_sink = {
+        "from": "encoder",
+        "from_port": "video",
+        "to_node": "out",
+        "to_port": "video",
+        "kind": "stream",
+    }
+    graph = _build(nodes=nodes, edges=[audio_into_encoder, video_into_sink])
+
+    modalities = graph.get_sink_modalities()
+    assert modalities == (True, False)