Skip to content

Commit 60d59d8

Browse files
authored
Cleanup PcmData constructors and populate participant in events (#192)
* cleanup/simplify methods from PcmData * pass participant inside pcm_data
1 parent ece3982 commit 60d59d8

File tree

4 files changed

+173
-106
lines changed

4 files changed

+173
-106
lines changed

getstream/video/rtc/pc.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,14 @@ async def on_track(track: aiortc.mediastreams.MediaStreamTrack):
155155
self.track_map[track.id] = (relay, track)
156156

157157
if track.kind == "audio":
158-
# Add a new subscriber for AudioTrackHandler
159-
handler = AudioTrackHandler(
160-
relay.subscribe(track), lambda pcm: self.emit("audio", pcm, user)
161-
)
158+
from getstream.video.rtc import PcmData
159+
160+
# Add a new subscriber for AudioTrackHandler and attach the participant to the pcm object
161+
def _emit_pcm(pcm: PcmData):
162+
pcm.participant = user
163+
self.emit("audio", pcm)
164+
165+
handler = AudioTrackHandler(relay.subscribe(track), _emit_pcm)
162166
asyncio.create_task(handler.start())
163167

164168
self.emit("track_added", relay.subscribe(track), user)

getstream/video/rtc/peer_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ async def setup_subscriber(self):
4343
)
4444

4545
@self.subscriber_pc.on("audio")
46-
async def on_audio(pcm_data, user):
47-
self.connection_manager.emit("audio", pcm_data, user)
46+
async def on_audio(pcm_data):
47+
self.connection_manager.emit("audio", pcm_data)
4848

4949
@self.subscriber_pc.on("track_added")
5050
async def on_track_added(track, user):

getstream/video/rtc/track_util.py

Lines changed: 141 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def __init__(
161161
f"Dtype mismatch: format='{format}' requires samples with dtype={expected_dtype_name}, "
162162
f"but got dtype={actual_dtype_name}. "
163163
f"To fix: use .to_float32() for f32 format, or ensure samples match the declared format. "
164-
f"For automatic conversion, use PcmData.from_data() instead."
164+
f"For automatic conversion, use PcmData.from_numpy() instead."
165165
)
166166

167167
self.samples: NDArray = samples
@@ -358,77 +358,75 @@ def from_bytes(
358358
)
359359

360360
@classmethod
361-
def from_data(
361+
def from_numpy(
362362
cls,
363-
data: Union[bytes, bytearray, memoryview, NDArray],
363+
array: NDArray,
364364
sample_rate: int = 16000,
365365
format: AudioFormatType = AudioFormat.S16,
366366
channels: int = 1,
367367
) -> "PcmData":
368-
"""Build from bytes or numpy arrays.
368+
"""Build from numpy arrays with automatic dtype/shape conversion.
369369
370370
Args:
371-
data: Input audio data (bytes or numpy array)
371+
array: Input audio data as numpy array
372372
sample_rate: Sample rate in Hz (default: 16000)
373373
format: Audio format (default: AudioFormat.S16)
374374
channels: Number of channels (default: 1 for mono)
375375
376376
Example:
377377
>>> import numpy as np
378-
>>> PcmData.from_data(np.array([1, 2], np.int16), sample_rate=16000, format=AudioFormat.S16, channels=1).channels
378+
>>> PcmData.from_numpy(np.array([1, 2], np.int16), sample_rate=16000, format=AudioFormat.S16, channels=1).channels
379379
1
380380
"""
381381
# Validate format
382382
AudioFormat.validate(format)
383-
if isinstance(data, (bytes, bytearray, memoryview)):
384-
return cls.from_bytes(
385-
bytes(data), sample_rate=sample_rate, format=format, channels=channels
386-
)
387383

388-
if isinstance(data, np.ndarray):
389-
arr = data
390-
# Ensure dtype aligns with format
391-
if format == "s16" and arr.dtype != np.int16:
392-
arr = arr.astype(np.int16)
393-
elif format == "f32" and arr.dtype != np.float32:
394-
arr = arr.astype(np.float32)
384+
if not isinstance(array, np.ndarray):
385+
raise TypeError(
386+
f"from_numpy() expects a numpy array, got {type(array).__name__}. "
387+
f"Use from_bytes() for bytes or from_response() for API responses."
388+
)
395389

396-
# Normalize shape to (channels, samples) for multi-channel
397-
if arr.ndim == 2:
398-
if arr.shape[0] == channels:
399-
samples_arr = arr
400-
elif arr.shape[1] == channels:
401-
samples_arr = arr.T
402-
else:
403-
# Assume first dimension is channels if ambiguous
404-
samples_arr = arr
405-
elif arr.ndim == 1:
406-
if channels > 1:
407-
try:
408-
frames = arr.reshape(-1, channels)
409-
samples_arr = frames.T
410-
except Exception:
411-
logger.warning(
412-
f"Could not reshape 1D array to {channels} channels; keeping mono"
413-
)
414-
channels = 1
415-
samples_arr = arr
416-
else:
390+
arr = array
391+
# Ensure dtype aligns with format
392+
if format == "s16" and arr.dtype != np.int16:
393+
arr = arr.astype(np.int16)
394+
elif format == "f32" and arr.dtype != np.float32:
395+
arr = arr.astype(np.float32)
396+
397+
# Normalize shape to (channels, samples) for multi-channel
398+
if arr.ndim == 2:
399+
if arr.shape[0] == channels:
400+
samples_arr = arr
401+
elif arr.shape[1] == channels:
402+
samples_arr = arr.T
403+
else:
404+
# Assume first dimension is channels if ambiguous
405+
samples_arr = arr
406+
elif arr.ndim == 1:
407+
if channels > 1:
408+
try:
409+
frames = arr.reshape(-1, channels)
410+
samples_arr = frames.T
411+
except Exception:
412+
logger.warning(
413+
f"Could not reshape 1D array to {channels} channels; keeping mono"
414+
)
415+
channels = 1
417416
samples_arr = arr
418417
else:
419-
# Fallback
420-
samples_arr = arr.reshape(-1)
421-
channels = 1
422-
423-
return cls(
424-
samples=samples_arr,
425-
sample_rate=sample_rate,
426-
format=format,
427-
channels=channels,
428-
)
418+
samples_arr = arr
419+
else:
420+
# Fallback
421+
samples_arr = arr.reshape(-1)
422+
channels = 1
429423

430-
# Unsupported type
431-
raise TypeError(f"Unsupported data type for PcmData: {type(data)}")
424+
return cls(
425+
samples=samples_arr,
426+
sample_rate=sample_rate,
427+
format=format,
428+
channels=channels,
429+
)
432430

433431
@classmethod
434432
def from_av_frame(cls, frame: "av.AudioFrame") -> "PcmData":
@@ -885,6 +883,69 @@ def clear(self) -> None:
885883

886884
self.samples = np.array([], dtype=dtype)
887885

886+
@staticmethod
887+
def _calculate_sample_width(format: AudioFormatType) -> int:
888+
"""Calculate bytes per sample for a given format."""
889+
return 2 if format == "s16" else 4 if format == "f32" else 2
890+
891+
@classmethod
892+
def _process_iterator_chunk(
893+
cls,
894+
buf: bytearray,
895+
frame_width: int,
896+
sample_rate: int,
897+
channels: int,
898+
format: AudioFormatType,
899+
) -> tuple[Optional["PcmData"], bytearray]:
900+
"""
901+
Process buffered audio data and return aligned chunk.
902+
903+
Returns:
904+
Tuple of (PcmData chunk or None, remaining buffer)
905+
"""
906+
aligned = (len(buf) // frame_width) * frame_width
907+
if aligned:
908+
chunk = bytes(buf[:aligned])
909+
remaining = buf[aligned:]
910+
pcm = cls.from_bytes(
911+
chunk,
912+
sample_rate=sample_rate,
913+
channels=channels,
914+
format=format,
915+
)
916+
return pcm, bytearray(remaining)
917+
return None, buf
918+
919+
@classmethod
920+
def _finalize_iterator_buffer(
921+
cls,
922+
buf: bytearray,
923+
frame_width: int,
924+
sample_rate: int,
925+
channels: int,
926+
format: AudioFormatType,
927+
) -> Optional["PcmData"]:
928+
"""
929+
Process remaining buffer at end of iteration with padding if needed.
930+
931+
Returns:
932+
Final PcmData chunk or None if buffer is empty
933+
"""
934+
if not buf:
935+
return None
936+
937+
# Pad to frame boundary
938+
pad_len = (-len(buf)) % frame_width
939+
if pad_len:
940+
buf.extend(b"\x00" * pad_len)
941+
942+
return cls.from_bytes(
943+
bytes(buf),
944+
sample_rate=sample_rate,
945+
channels=channels,
946+
format=format,
947+
)
948+
888949
@classmethod
889950
def from_response(
890951
cls,
@@ -926,38 +987,32 @@ def from_response(
926987
if hasattr(response, "__aiter__"):
927988

928989
async def _agen():
929-
width = 2 if format == "s16" else 4 if format == "f32" else 2
990+
width = cls._calculate_sample_width(format)
930991
frame_width = width * max(1, channels)
931992
buf = bytearray()
993+
932994
async for item in response:
933995
if isinstance(item, PcmData):
934996
yield item
935997
continue
998+
936999
data = getattr(item, "data", item)
9371000
if not isinstance(data, (bytes, bytearray, memoryview)):
9381001
raise TypeError("Async iterator yielded unsupported item type")
1002+
9391003
buf.extend(bytes(data))
940-
aligned = (len(buf) // frame_width) * frame_width
941-
if aligned:
942-
chunk = bytes(buf[:aligned])
943-
del buf[:aligned]
944-
yield cls.from_bytes(
945-
chunk,
946-
sample_rate=sample_rate,
947-
channels=channels,
948-
format=format,
949-
)
950-
# pad remainder, if any
951-
if buf:
952-
pad_len = (-len(buf)) % frame_width
953-
if pad_len:
954-
buf.extend(b"\x00" * pad_len)
955-
yield cls.from_bytes(
956-
bytes(buf),
957-
sample_rate=sample_rate,
958-
channels=channels,
959-
format=format,
1004+
chunk, buf = cls._process_iterator_chunk(
1005+
buf, frame_width, sample_rate, channels, format
9601006
)
1007+
if chunk:
1008+
yield chunk
1009+
1010+
# Handle remainder
1011+
final_chunk = cls._finalize_iterator_buffer(
1012+
buf, frame_width, sample_rate, channels, format
1013+
)
1014+
if final_chunk:
1015+
yield final_chunk
9611016

9621017
return _agen()
9631018

@@ -967,37 +1022,32 @@ async def _agen():
9671022
):
9681023

9691024
def _gen():
970-
width = 2 if format == "s16" else 4 if format == "f32" else 2
1025+
width = cls._calculate_sample_width(format)
9711026
frame_width = width * max(1, channels)
9721027
buf = bytearray()
1028+
9731029
for item in response:
9741030
if isinstance(item, PcmData):
9751031
yield item
9761032
continue
1033+
9771034
data = getattr(item, "data", item)
9781035
if not isinstance(data, (bytes, bytearray, memoryview)):
9791036
raise TypeError("Iterator yielded unsupported item type")
1037+
9801038
buf.extend(bytes(data))
981-
aligned = (len(buf) // frame_width) * frame_width
982-
if aligned:
983-
chunk = bytes(buf[:aligned])
984-
del buf[:aligned]
985-
yield cls.from_bytes(
986-
chunk,
987-
sample_rate=sample_rate,
988-
channels=channels,
989-
format=format,
990-
)
991-
if buf:
992-
pad_len = (-len(buf)) % frame_width
993-
if pad_len:
994-
buf.extend(b"\x00" * pad_len)
995-
yield cls.from_bytes(
996-
bytes(buf),
997-
sample_rate=sample_rate,
998-
channels=channels,
999-
format=format,
1039+
chunk, buf = cls._process_iterator_chunk(
1040+
buf, frame_width, sample_rate, channels, format
10001041
)
1042+
if chunk:
1043+
yield chunk
1044+
1045+
# Handle remainder
1046+
final_chunk = cls._finalize_iterator_buffer(
1047+
buf, frame_width, sample_rate, channels, format
1048+
)
1049+
if final_chunk:
1050+
yield final_chunk
10011051

10021052
return _gen()
10031053

0 commit comments

Comments
 (0)