add efficient AV decoding #38
base: develop
@@ -0,0 +1,352 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import io
import re
from collections.abc import Collection, Iterator
from fractions import Fraction

import av
import numpy as np
import torch
import torchaudio

from megatron.energon.flavors.webdataset.fastseek import Fastseek
from megatron.energon.flavors.webdataset.video_data import VideoData

DEFAULT_AUDIO_FRAME_SHIFT_MS = 10  # in milliseconds


class AVDecoder:
Reviewer comment: Maybe we should have an alternative decoder as well, which returns the decoder itself, so the user can decide in user code (=encode_sample) which frames to read? Like this:

# This function is to be registered as decoder
def read_av_data(key: str, data: bytes):
    if key in ("mp3", ...):
        return AVData(data)


# This class is now passed to the user's `encode_sample` function (i.e. the raw video
# bytes are essentially passed through). This allows the user to decide on the
# parameters on the fly (e.g. for open-sora).
class AVData:
    def __init__(self, raw: bytes):
        ...

    def get_frames(
        self,
        audio_convert_to_melspec: bool = False,
        audio_clip_duration: int = 1,
        audio_num_clips: int = -1,
        audio_target_rate: int = 16000,
        video_decode_audio: bool = False,
        video_num_frames: int = 64,
        video_out_frame_size: tuple = (224, 224),
    ) -> AudioVideoData:
        ...

WDYT?
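For illustration, a user's encode_sample could then pick decoding parameters per sample roughly like this (a sketch against the AVData proposal above; the sample keys and parameter values are made up):

# Hypothetical user-side encode_sample; AVData/get_frames exist only in the
# reviewer's sketch above, and the sample key "mp4" is illustrative.
def encode_sample(sample: dict) -> dict:
    av_data = sample["mp4"]  # AVData passed through by read_av_data
    # Decide the decoding parameters on the fly, e.g. per training task
    av_frames = av_data.get_frames(
        video_num_frames=32,
        video_out_frame_size=(256, 256),
        video_decode_audio=False,
    )
    sample["frames"] = av_frames
    return sample
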
    def __init__(
        self,
        audio_convert_to_melspec,
Reviewer comment: Generally, we have all parameters statically typed. Also, all class variables are typically typed.
        audio_clip_duration,
        audio_num_clips,
        audio_target_rate,
        video_decode_audio,
        video_num_frames,
        video_out_frame_size,
    ):
        self.audio_convert_to_melspec = audio_convert_to_melspec
        self.audio_clip_duration = audio_clip_duration
        self.audio_num_clips = audio_num_clips
        self.audio_target_rate = audio_target_rate
        self.video_decode_audio = video_decode_audio
        self.video_num_frames = video_num_frames
        self.video_out_frame_size = video_out_frame_size

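A minimal sketch of the statically typed __init__ the reviewer asks for (types are inferred from how the parameters are used elsewhere in this file; this sketch is not part of the diff):

    # Sketch only: same parameters as above, with static type annotations
    def __init__(
        self,
        audio_convert_to_melspec: bool,
        audio_clip_duration: int,
        audio_num_clips: int,
        audio_target_rate: int,
        video_decode_audio: bool,
        video_num_frames: int,
        video_out_frame_size: tuple,
    ) -> None:
        ...
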
    def __call__(self, key, data):
        """
        Extract the video or audio data from default media extensions.

        Args:
            key: media file extension
            data: raw media bytes
        """
        extension = re.sub(r".*[.]", "", key)

Reviewer comment: Otherwise, we usually compile regexes beforehand.
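A minimal sketch of the precompiled variant (the module-level name is made up):

# Compiled once at import time rather than on every __call__
_EXTENSION_RE = re.compile(r".*[.]")

# ... then inside AVDecoder.__call__:
extension = _EXTENSION_RE.sub("", key)
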
        # TODO(jbarker): we should add a debug log here
        if extension in "mov mp4 webm mkv".split():
            # TODO(jbarker): make the magic numbers configurable
            media = decode_video_frames(
                data,
                num_frames=self.video_num_frames,
                out_frame_size=self.video_out_frame_size,
                decode_audio=self.video_decode_audio,
            )
        elif extension in "flac mp3".split():
            # TODO(jbarker): make the magic numbers configurable
            media = decode_audio_samples(
                data,
                convert_to_melspec=self.audio_convert_to_melspec,
                num_clips=self.audio_num_clips,
                clip_duration=self.audio_clip_duration,
                target_rate=self.audio_target_rate,
            )
        else:
            return None
        if media is not None:
            frames, aframes, info = media
            return VideoData(
                # Audio-only inputs have no video frames to permute
                frames=frames.permute((0, 3, 1, 2)) if frames is not None else None,
                aframes=aframes,
                info=info,
            )
        return None

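For orientation, direct use of the decoder on raw bytes could look roughly like this (the file path is made up; how the decoder is registered with energon's webdataset loaders is not shown in this diff):

# Hypothetical standalone call of AVDecoder.__call__ on bytes read from disk
decoder = AVDecoder(
    audio_convert_to_melspec=False,
    audio_clip_duration=1,
    audio_num_clips=-1,
    audio_target_rate=16000,
    video_decode_audio=True,
    video_num_frames=64,
    video_out_frame_size=(224, 224),
)
with open("example.mp4", "rb") as f:
    video_data = decoder("example.mp4", f.read())
# video_data.frames has shape [num_frames, channels, height, width]
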
def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
Reviewer comment: I feel the functions below may have their own file, and also reside in the
    # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
    waveform -= waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=sample_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=num_mel_bins,
        dither=0.0,
        frame_length=25,
        frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
    )
    # Convert to [mel_bins, num_frames] shape
    fbank = fbank.transpose(0, 1)
    # Pad to target_length
    n_frames = fbank.size(1)
    p = target_length - n_frames
    # cut and pad
    if p > 0:
        fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
    elif p < 0:
        fbank = fbank[:, 0:target_length]
    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
    # channel image
    fbank = fbank.unsqueeze(0)
    return fbank


def frame_to_ts(frame: int, average_rate: Fraction, time_base: Fraction) -> int:
    return int(frame / average_rate / time_base)


def ts_to_frame(ts: int, average_rate: Fraction, time_base: Fraction) -> int:
    return int(ts * time_base * average_rate)


def get_frame_batch(
    video_file: io.BytesIO,
    frame_indices: Collection[int],
    out_frame_size: tuple = None,
) -> tuple[torch.Tensor, dict]:
    """Gets a batch of frames at the given indices from a video file."""
    seeker: Fastseek = Fastseek(video_file)
    # Reset the video stream so that pyav can read the entire container
    video_file.seek(0)

    with av.open(video_file) as input_container:
        # Grab video & audio streams
        video_stream = input_container.streams.video[0]
        audio_stream = input_container.streams.audio[0]

        # enable multi-threaded decode for video
        video_stream.thread_type = 3

        # Collect metadata
        video_fps = float(video_stream.average_rate) if video_stream.average_rate else 0.0
        audio_fps = audio_stream.sample_rate or 0
        metadata = {"video_fps": video_fps, "audio_fps": audio_fps}

        # Pre-calculate timing info for video
        average_rate: Fraction = video_stream.average_rate
        time_base: Fraction = video_stream.time_base
        average_frame_duration: int = int(1 / average_rate / time_base)

        frame_iterator: Iterator[av.VideoFrame] = input_container.decode(video=0)
        previous_frame_number: int = 0

        # Decode requested video frames
        frames: list[torch.Tensor] = []
        for target_frame_number in frame_indices:
            if seeker.mime in ["video/x-matroska", "video/webm"]:
                # Matroska uses time rather than frame number
                prev_frame_ts = frame_to_ts(
                    previous_frame_number, average_rate, seeker.container_time_base
                )
                target_frame_ts = frame_to_ts(
                    target_frame_number, average_rate, seeker.container_time_base
                )
            else:
                prev_frame_ts = previous_frame_number
                target_frame_ts = target_frame_number

            target_pts = frame_to_ts(target_frame_number, average_rate, time_base)

            if seeker.should_seek(prev_frame_ts, target_frame_ts):
                input_container.seek(target_pts, stream=video_stream)

            for frame in frame_iterator:
                if (
                    frame.pts
                    <= target_pts + (average_frame_duration / 2)
                    <= frame.pts + average_frame_duration
                ):
                    if out_frame_size is not None:
                        frame = frame.reformat(
                            width=out_frame_size[0],
                            height=out_frame_size[1],
                            format="rgb24",
                            interpolation="BILINEAR",
                        )
                    else:
                        frame = frame.reformat(format="rgb24")
                    frames.append(torch.from_numpy(frame.to_ndarray()))
                    break

            previous_frame_number = target_frame_number

        # Stack video frames along dim=0 => [num_frames, height, width, channels]
        # (the caller permutes to [num_frames, channels, height, width])
        video_tensor = torch.stack(frames)

    return video_tensor, metadata


def decode_video_frames(
    data: bytes,
    num_frames: int = -1,
    out_frame_size: tuple = None,
    decode_audio: bool = False,
    num_clips: int = 1,
    clip_duration: int = 1,
    target_rate: int = 16000,
    convert_to_melspec: bool = False,
):
    byte_stream = io.BytesIO(data)

    # --- First, decode video frames ---
    with av.open(byte_stream) as input_container:
        if input_container.streams.video[0].frames != 0:
            frame_count = input_container.streams.video[0].frames
        else:
            frame_count = len([p for p in input_container.demux(video=0) if p.pts is not None])

    if num_frames == -1:
        num_frames = frame_count

    # Pick which video frames to extract
    frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int).tolist()
    video_tensor, metadata = get_frame_batch(byte_stream, frame_indices, out_frame_size)

    # --- Then, if requested, decode audio using the same clip logic as decode_audio_samples ---
    audio_tensor = torch.empty(0)
    if decode_audio:
        # Open the container again to get sample_count and sampling_rate
        with av.open(io.BytesIO(data)) as input_container:
            audio_stream = input_container.streams.audio[0]
            sample_count = audio_stream.duration
            sampling_rate = audio_stream.rate

        if num_clips == -1:
            # Single clip from the entire audio
            clip_indices = [[0, sample_count - 1]]
        else:
            clip_indices = get_clip_indices(
                sampling_rate, sample_count, num_clips, clip_duration
            )

        # Actually read the audio clips
        audio_tensor, audio_metadata = get_audio_batch(
            io.BytesIO(data),
            clip_indices,
            target_rate=target_rate,
            convert_to_melspec=convert_to_melspec,
        )
        # Merge any extra audio metadata
        metadata.update(audio_metadata)

    return video_tensor, audio_tensor, metadata


def get_audio_batch(
    audio_file: io.BytesIO,
    clip_indices: list[list[int]],
    target_rate: int = 16000,
    convert_to_melspec: bool = False,
) -> tuple[torch.Tensor, dict]:
    """
    Gets a batch of audio samples at the given indices from an audio file,
    resampled to target_rate. Indices correspond to the original sample rate.
    """
    audio_file.seek(0)

    with av.open(audio_file) as input_container:
        audio_stream = input_container.streams.audio[0]
        orig_rate = audio_stream.sample_rate
        duration_per_sample = 1 / orig_rate
        metadata = {"audio_fps": orig_rate}

        # Initialize resampler to convert each frame to target_rate
        if target_rate != orig_rate:
            resampler = av.audio.resampler.AudioResampler(
                format=audio_stream.format,
                layout=audio_stream.layout,
                rate=target_rate,
            )

        clips = []

        for indices in clip_indices:
            start_time = indices[0] * duration_per_sample
            end_time = indices[-1] * duration_per_sample

            # Seek near start time (convert to microseconds per PyAV docs)
            input_container.seek(int(start_time * av.time_base))

            decoded_samples = []
            for frame in input_container.decode(audio=0):
                frame_start = frame.pts * frame.time_base
                # Stop decoding if we've passed the end
                if frame_start >= end_time:
                    break

                # Resample this frame to target_rate if necessary
                if target_rate != orig_rate:
                    frame = resampler.resample(frame)[0]
                frame_nd = frame.to_ndarray()  # (channels, samples)
                decoded_samples.append(frame_nd)

            if decoded_samples:
                # Combine all channels/samples into one array
                clip_all = np.concatenate(decoded_samples, axis=-1)  # (channels, total_samples)

                # Figure out how many samples in the target rate we want
                clip_duration_s = (indices[-1] - indices[0] + 1) / orig_rate
                needed_samples = int(round(clip_duration_s * target_rate))

                # Keep the first channel and trim to the expected number of samples
                clip_all = clip_all[0, :needed_samples]

                # Convert to torch
                clip_tensor = torch.from_numpy(clip_all)
                if convert_to_melspec:
                    clip_tensor = waveform2melspec(clip_tensor.float()[None, :], 16000, 128, 204)[0]
                clips.append(clip_tensor)

    return torch.stack(clips), metadata


def get_clip_indices(sampling_rate, total_samples, num_clips, clip_duration_sec):
    clip_samples = int(sampling_rate * clip_duration_sec)
    assert clip_samples <= total_samples, \
        "Requested clip duration exceeds total samples."

    if num_clips == 1:
        return [np.arange(0, clip_samples)]

    # If total length can accommodate all clips without overlap, space them out evenly
    if num_clips * clip_samples <= total_samples:
        spacing = total_samples // num_clips
    else:
        # Overlap: distribute clips so first starts at 0 and last ends at total_samples - clip_samples
        spacing = (total_samples - clip_samples) // (num_clips - 1)

    start_indices = [i * spacing for i in range(num_clips)]
    return [np.arange(start, start + clip_samples) for start in start_indices]


def decode_audio_samples(
    data: bytes,
    num_clips: int = 1,
    clip_duration: int = 1,
    target_rate: int = 16000,
    convert_to_melspec: bool = False,
):
    byte_stream = io.BytesIO(data)

    with av.open(byte_stream) as input_container:
        sample_count = input_container.streams.audio[0].duration
        sampling_rate = input_container.streams.audio[0].rate

    if num_clips == -1:
        num_clips = 1
        clip_indices = [[0, sample_count - 1]]
    else:
        clip_indices = get_clip_indices(sampling_rate, sample_count, num_clips, clip_duration)

    audio_tensor, metadata = get_audio_batch(byte_stream, clip_indices, target_rate, convert_to_melspec)

    return None, audio_tensor, metadata
Reviewer comment: Should we add these as extras to be installed like pip install megatron-energon[video]? Would require the imports in the Python files to be allowed to fail, though.
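A rough sketch of how the optional imports could be allowed to fail (the helper name and error message are made up; the extra name follows the comment above):

try:
    import av
    import torchaudio
except ImportError as _e:
    av = None
    torchaudio = None
    _AV_IMPORT_ERROR = _e


def _require_av():
    # Raise only when AV decoding is actually used, so that importing the
    # package without the [video] extra still works.
    if av is None:
        raise ImportError(
            "AV decoding requires optional dependencies; "
            "install them with `pip install megatron-energon[video]`"
        ) from _AV_IMPORT_ERROR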