Skip to content

Commit 12c8694

Browse files
authored
New audio track (#188)
* new audio track
* remove useless test
1 parent 83fb804 commit 12c8694

File tree

2 files changed

+433
-173
lines changed

2 files changed

+433
-173
lines changed

getstream/video/rtc/audio_track.py

Lines changed: 118 additions & 173 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,10 @@
1414

1515
class AudioStreamTrack(aiortc.mediastreams.MediaStreamTrack):
1616
"""
17-
Audio stream track that accepts PcmData objects directly from a queue.
17+
Audio stream track that accepts PcmData objects and buffers them as bytes.
1818
1919
Works with PcmData objects instead of raw bytes, avoiding format conversion issues.
20+
Internally buffers as bytes for efficient memory usage.
2021
2122
Usage:
2223
track = AudioStreamTrack(sample_rate=48000, channels=2)
@@ -34,7 +35,7 @@ def __init__(
3435
sample_rate: int = 48000,
3536
channels: int = 1,
3637
format: str = "s16",
37-
max_queue_size: int = 100,
38+
audio_buffer_size_ms: int = 30000, # 30 seconds default
3839
):
3940
"""
4041
Initialize an AudioStreamTrack that accepts PcmData objects.
@@ -43,92 +44,119 @@ def __init__(
4344
sample_rate: Target sample rate in Hz (default: 48000)
4445
channels: Number of channels - 1=mono, 2=stereo (default: 1)
4546
format: Audio format - "s16" or "f32" (default: "s16")
46-
max_queue_size: Maximum number of PcmData objects in queue (default: 100)
47+
audio_buffer_size_ms: Maximum buffer size in milliseconds (default: 30000ms = 30s)
4748
"""
4849
super().__init__()
4950
self.sample_rate = sample_rate
5051
self.channels = channels
5152
self.format = format
52-
self.max_queue_size = max_queue_size
53+
self.audio_buffer_size_ms = audio_buffer_size_ms
5354

5455
logger.debug(
5556
"Initialized AudioStreamTrack",
5657
extra={
5758
"sample_rate": sample_rate,
5859
"channels": channels,
5960
"format": format,
60-
"max_queue_size": max_queue_size,
61+
"audio_buffer_size_ms": audio_buffer_size_ms,
6162
},
6263
)
6364

64-
# Create async queue for PcmData objects
65-
self._queue = asyncio.Queue()
65+
# Internal bytearray buffer for audio data
66+
self._buffer = bytearray()
67+
self._buffer_lock = asyncio.Lock()
68+
69+
# Timing for frame pacing
6670
self._start = None
6771
self._timestamp = None
68-
69-
# Buffer for chunks smaller than 20ms
70-
self._buffer = None
72+
self._last_frame_time = None
73+
74+
# Calculate bytes per sample based on format
75+
self._bytes_per_sample = 2 if format == "s16" else 4 # s16=2 bytes, f32=4 bytes
76+
self._bytes_per_frame = int(
77+
aiortc.mediastreams.AUDIO_PTIME
78+
* self.sample_rate
79+
* self.channels
80+
* self._bytes_per_sample
81+
)
7182

7283
async def write(self, pcm: PcmData):
7384
"""
74-
Add PcmData to the queue.
85+
Add PcmData to the buffer.
7586
7687
The PcmData will be automatically resampled/converted to match
77-
the track's configured sample_rate, channels, and format.
88+
the track's configured sample_rate, channels, and format,
89+
then converted to bytes and stored in the buffer.
7890
7991
Args:
8092
pcm: PcmData object with audio data
8193
"""
82-
# Check if queue is getting too large and trim if necessary
83-
if self._queue.qsize() >= self.max_queue_size:
84-
dropped_items = 0
85-
while self._queue.qsize() >= self.max_queue_size:
86-
try:
87-
self._queue.get_nowait()
88-
self._queue.task_done()
89-
dropped_items += 1
90-
except asyncio.QueueEmpty:
91-
break
92-
93-
logger.warning(
94-
"Audio queue overflow, dropped items max is %d. pcm duration %s ms",
95-
self.max_queue_size,
96-
pcm.duration_ms,
97-
extra={
98-
"dropped_items": dropped_items,
99-
"queue_size": self._queue.qsize(),
100-
},
94+
# Normalize the PCM data to target format immediately
95+
pcm_normalized = self._normalize_pcm(pcm)
96+
97+
# Convert to bytes
98+
audio_bytes = pcm_normalized.to_bytes()
99+
100+
async with self._buffer_lock:
101+
# Check buffer size before adding
102+
max_buffer_bytes = int(
103+
(self.audio_buffer_size_ms / 1000)
104+
* self.sample_rate
105+
* self.channels
106+
* self._bytes_per_sample
101107
)
102108

103-
await self._queue.put(pcm)
109+
# Add new data to buffer first
110+
self._buffer.extend(audio_bytes)
111+
112+
# Check if we exceeded the limit
113+
if len(self._buffer) > max_buffer_bytes:
114+
# Calculate how many bytes to drop from the beginning
115+
bytes_to_drop = len(self._buffer) - max_buffer_bytes
116+
dropped_ms = (
117+
bytes_to_drop
118+
/ (self.sample_rate * self.channels * self._bytes_per_sample)
119+
) * 1000
120+
121+
# TODO: do not perform logging inside critical section, move this code outside the lock
122+
logger.warning(
123+
"Audio buffer overflow, dropping %.1fms of audio. Buffer max is %dms",
124+
dropped_ms,
125+
self.audio_buffer_size_ms,
126+
extra={
127+
"buffer_size_bytes": len(self._buffer),
128+
"incoming_bytes": len(audio_bytes),
129+
"dropped_bytes": bytes_to_drop,
130+
},
131+
)
132+
133+
# Drop from the beginning of the buffer to keep latest data
134+
self._buffer = self._buffer[bytes_to_drop:]
135+
136+
buffer_duration_ms = (
137+
len(self._buffer)
138+
/ (self.sample_rate * self.channels * self._bytes_per_sample)
139+
) * 1000
140+
104141
logger.debug(
105-
"Added PcmData to queue",
142+
"Added audio to buffer",
106143
extra={
107-
"pcm_samples": len(pcm.samples)
108-
if pcm.samples.ndim == 1
109-
else pcm.samples.shape,
110-
"pcm_sample_rate": pcm.sample_rate,
111-
"pcm_channels": pcm.channels,
112-
"queue_size": self._queue.qsize(),
144+
"pcm_duration_ms": pcm.duration_ms,
145+
"buffer_duration_ms": buffer_duration_ms,
146+
"buffer_size_bytes": len(self._buffer),
113147
},
114148
)
115149

116150
async def flush(self) -> None:
117151
"""
118-
Clear any pending audio from the queue and buffer.
152+
Clear any pending audio from the buffer.
119153
Playback stops immediately.
120154
"""
121-
cleared = 0
122-
while not self._queue.empty():
123-
try:
124-
self._queue.get_nowait()
125-
self._queue.task_done()
126-
cleared += 1
127-
except asyncio.QueueEmpty:
128-
break
129-
130-
self._buffer = None
131-
logger.debug("Flushed audio queue", extra={"cleared_items": cleared})
155+
async with self._buffer_lock:
156+
bytes_cleared = len(self._buffer)
157+
self._buffer.clear()
158+
159+
logger.debug("Flushed audio buffer", extra={"cleared_bytes": bytes_cleared})
132160

133161
async def recv(self) -> Frame:
134162
"""
@@ -142,23 +170,56 @@ async def recv(self) -> Frame:
142170

143171
# Calculate samples needed for 20ms frame
144172
samples_per_frame = int(aiortc.mediastreams.AUDIO_PTIME * self.sample_rate)
173+
wakeup_time = 0
145174

146175
# Initialize timestamp if not already done
147176
if self._timestamp is None:
148177
self._start = time.time()
149178
self._timestamp = 0
179+
self._last_frame_time = time.time()
150180
else:
181+
# Use timestamp-based pacing to avoid drift over time
182+
# This ensures we stay synchronized with the expected audio rate
183+
# even if individual frames have slight timing variations
151184
self._timestamp += samples_per_frame
152185
start_ts = self._start or time.time()
153186
wait = start_ts + (self._timestamp / self.sample_rate) - time.time()
154187
if wait > 0:
188+
wakeup_time += time.time() + wait
155189
await asyncio.sleep(wait)
190+
if time.time() - wakeup_time > 0.1:
191+
logger.warning(
192+
f"scheduled sleep took {time.time() - wakeup_time} instead of {wait}, this can happen if there is something blocking the main event-loop, you can enable --debug mode to see blocking traces"
193+
)
156194

157-
# Get or accumulate PcmData to fill a 20ms frame
158-
pcm_for_frame = await self._get_pcm_for_frame(samples_per_frame)
195+
self._last_frame_time = time.time()
196+
197+
# Get 20ms of audio data from buffer
198+
async with self._buffer_lock:
199+
if len(self._buffer) >= self._bytes_per_frame:
200+
# We have enough data
201+
audio_bytes = bytes(self._buffer[: self._bytes_per_frame])
202+
self._buffer = self._buffer[self._bytes_per_frame :]
203+
elif len(self._buffer) > 0:
204+
# We have some data but not enough - pad with silence
205+
audio_bytes = bytes(self._buffer)
206+
padding_needed = self._bytes_per_frame - len(audio_bytes)
207+
audio_bytes += bytes(padding_needed) # Pad with zeros (silence)
208+
self._buffer.clear()
209+
210+
logger.debug(
211+
"Padded audio frame with silence",
212+
extra={
213+
"available_bytes": len(audio_bytes) - padding_needed,
214+
"required_bytes": self._bytes_per_frame,
215+
"padding_bytes": padding_needed,
216+
},
217+
)
218+
else:
219+
# No data at all - emit silence
220+
audio_bytes = bytes(self._bytes_per_frame)
159221

160222
# Create AudioFrame
161-
# Determine layout and format
162223
layout = "stereo" if self.channels == 2 else "mono"
163224

164225
# Convert format name: "s16" -> "s16", "f32" -> "flt"
@@ -172,20 +233,7 @@ async def recv(self) -> Frame:
172233
frame = AudioFrame(format=av_format, layout=layout, samples=samples_per_frame)
173234

174235
# Fill frame with data
175-
if pcm_for_frame is not None:
176-
audio_bytes = pcm_for_frame.to_bytes()
177-
178-
# Write to the single plane (packed format has 1 plane)
179-
if len(audio_bytes) >= frame.planes[0].buffer_size:
180-
frame.planes[0].update(audio_bytes[: frame.planes[0].buffer_size])
181-
else:
182-
# Pad with silence if not enough data
183-
padding = bytes(frame.planes[0].buffer_size - len(audio_bytes))
184-
frame.planes[0].update(audio_bytes + padding)
185-
else:
186-
# No data available, return silence
187-
for plane in frame.planes:
188-
plane.update(bytes(plane.buffer_size))
236+
frame.planes[0].update(audio_bytes)
189237

190238
# Set frame properties
191239
frame.pts = self._timestamp
@@ -194,108 +242,6 @@ async def recv(self) -> Frame:
194242

195243
return frame
196244

197-
async def _get_pcm_for_frame(self, samples_needed: int) -> PcmData | None:
198-
"""
199-
Get or accumulate PcmData to fill exactly samples_needed samples.
200-
201-
This method handles:
202-
- Buffering partial chunks
203-
- Resampling to target sample rate
204-
- Converting to target channels
205-
- Converting to target format
206-
- Chunking to exact frame size
207-
208-
Args:
209-
samples_needed: Number of samples needed for the frame
210-
211-
Returns:
212-
PcmData with exactly samples_needed samples, or None if no data available
213-
"""
214-
# Start with buffered data if any
215-
if self._buffer is not None:
216-
pcm_accumulated = self._buffer
217-
self._buffer = None
218-
else:
219-
pcm_accumulated = None
220-
221-
# Try to get data from queue
222-
try:
223-
# Don't wait too long - if no data, return silence
224-
while True:
225-
# Check if we have enough samples
226-
if pcm_accumulated is not None:
227-
current_samples = (
228-
len(pcm_accumulated.samples)
229-
if pcm_accumulated.samples.ndim == 1
230-
else pcm_accumulated.samples.shape[1]
231-
)
232-
if current_samples >= samples_needed:
233-
break
234-
235-
# Try to get more data
236-
if self._queue.empty():
237-
# No more data available
238-
break
239-
240-
pcm_chunk = await asyncio.wait_for(self._queue.get(), timeout=0.01)
241-
self._queue.task_done()
242-
243-
# Resample/convert to target format
244-
pcm_chunk = self._normalize_pcm(pcm_chunk)
245-
246-
# Accumulate
247-
if pcm_accumulated is None:
248-
pcm_accumulated = pcm_chunk
249-
else:
250-
pcm_accumulated = pcm_accumulated.append(pcm_chunk)
251-
252-
except asyncio.TimeoutError:
253-
pass
254-
255-
# If no data at all, return None (will produce silence)
256-
if pcm_accumulated is None:
257-
return None
258-
259-
# Get the number of samples we have
260-
current_samples = (
261-
len(pcm_accumulated.samples)
262-
if pcm_accumulated.samples.ndim == 1
263-
else pcm_accumulated.samples.shape[1]
264-
)
265-
266-
# If we have exactly the right amount, return it
267-
if current_samples == samples_needed:
268-
return pcm_accumulated
269-
270-
# If we have more than needed, split it
271-
if current_samples > samples_needed:
272-
# Calculate duration needed in seconds
273-
duration_needed_s = samples_needed / self.sample_rate
274-
275-
# Use head() to get exactly what we need
276-
pcm_for_frame = pcm_accumulated.head(
277-
duration_s=duration_needed_s, pad=False, pad_at="end"
278-
)
279-
280-
# Calculate what's left in seconds
281-
duration_used_s = (
282-
len(pcm_for_frame.samples)
283-
if pcm_for_frame.samples.ndim == 1
284-
else pcm_for_frame.samples.shape[1]
285-
) / self.sample_rate
286-
287-
# Buffer the rest
288-
self._buffer = pcm_accumulated.tail(
289-
duration_s=pcm_accumulated.duration - duration_used_s,
290-
pad=False,
291-
pad_at="start",
292-
)
293-
294-
return pcm_for_frame
295-
296-
# If we have less than needed, return what we have (will be padded with silence)
297-
return pcm_accumulated
298-
299245
def _normalize_pcm(self, pcm: PcmData) -> PcmData:
300246
"""
301247
Normalize PcmData to match the track's target format.
@@ -306,9 +252,8 @@ def _normalize_pcm(self, pcm: PcmData) -> PcmData:
306252
Returns:
307253
PcmData resampled/converted to target sample_rate, channels, and format
308254
"""
309-
# Resample to target sample rate and channels if needed
310-
if pcm.sample_rate != self.sample_rate or pcm.channels != self.channels:
311-
pcm = pcm.resample(self.sample_rate, target_channels=self.channels)
255+
256+
pcm = pcm.resample(self.sample_rate, target_channels=self.channels)
312257

313258
# Convert format if needed
314259
if self.format == "s16" and pcm.format != "s16":

0 commit comments

Comments (0)