add efficient AV decoding #38

Draft · wants to merge 15 commits into base: develop
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -31,6 +31,11 @@ dependencies = [
"tqdm",
"webdataset",
"s3fs",
"bitstring>=4.2.3",
Collaborator:

Should we add these as extras, to be installed like pip install megatron-energon[video]?
That would require the imports in the Python files to be allowed to fail, though.
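For illustration, one way this could look (a sketch, not part of the PR): the four new packages would move from dependencies into an optional group under [project.optional-dependencies], and the modules that need them would import them defensively, roughly like this:

# Sketch of a guarded import, assuming the AV packages become an optional "video" extra.
try:
    import av
    import bitstring
    import ebmlite
    import filetype
except ImportError:
    av = bitstring = ebmlite = filetype = None


def _require_av_support() -> None:
    # Hypothetical helper, called at the top of the AV decoding entry points.
    if av is None:
        raise ImportError(
            "AV decoding needs the optional dependencies; "
            "install them with: pip install megatron-energon[video]"
        )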

"sortedcontainers>=2.4.0",
"filetype>=1.2.0",
"ebmlite>=3.3.1",
"av>=13.0.0",
]

[project.optional-dependencies]
352 changes: 352 additions & 0 deletions src/megatron/energon/flavors/webdataset/decode_av_frames.py
@@ -0,0 +1,352 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import io
from collections.abc import Collection, Iterator
from fractions import Fraction
import re

import av
import numpy as np
import torch
import torchaudio

from megatron.energon.flavors.webdataset.fastseek import Fastseek
from megatron.energon.flavors.webdataset.video_data import VideoData

DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds

class AVDecoder:
Collaborator:

Maybe we should have an alternative decoder as well, which returns the decoder itself, so the user can decide in user code (=encode_sample) which frames to read?

Like this:

# This function is to be registered as decoder
def read_av_data(key: str, data: bytes):
    if key in ("mp3", ...):
        return AVData(data)


# This class is now passed to the user's `encode_sample` function (i.e. the raw video
# bytes are essentially passed through). This allows the user to decide on the 
# parameters on the fly (e.g. for open-sora).
class AVData:
    def __init__(self, raw: bytes):
        ...

    def get_frames(
            self,
            audio_convert_to_melspec: bool = False,
            audio_clip_duration: int = 1,
            audio_num_clips: int = -1,
            audio_target_rate: int = 16000,
            video_decode_audio: bool = False,
            video_num_frames: int = 64,
            video_out_frame_size: tuple = (224, 224),
        ) -> AudioVideoData:
            ...

WDYT?
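For illustration, user code could then pick the decode parameters per sample, roughly like this (a sketch against the proposed interface above; the parameter values and the sample field name are made up):

# Hypothetical encode_sample using the proposed AVData pass-through.
def encode_sample(sample: dict) -> dict:
    av_data: AVData = sample["mp4"]     # returned by read_av_data for .mp4 entries
    av_frames = av_data.get_frames(     # AudioVideoData, per the proposal above
        video_num_frames=32,            # chosen per task, on the fly
        video_out_frame_size=(256, 256),
        video_decode_audio=True,
    )
    return {"media": av_frames}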

    def __init__(
        self,
        audio_convert_to_melspec,
Collaborator:

Generally, we have all parameters statically typed. Also all class variables are typically typed.

        audio_clip_duration,
        audio_num_clips,
        audio_target_rate,
        video_decode_audio,
        video_num_frames,
        video_out_frame_size,
    ):
        self.audio_convert_to_melspec = audio_convert_to_melspec
        self.audio_clip_duration = audio_clip_duration
        self.audio_num_clips = audio_num_clips
        self.audio_target_rate = audio_target_rate
        self.video_decode_audio = video_decode_audio
        self.video_num_frames = video_num_frames
        self.video_out_frame_size = video_out_frame_size
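Picking up the typing comment above, a fully typed constructor could look roughly like this (a sketch; the types are inferred from the defaults used elsewhere in this file):

    # Sketch only: types inferred from how the values are used below.
    def __init__(
        self,
        audio_convert_to_melspec: bool,
        audio_clip_duration: int,
        audio_num_clips: int,
        audio_target_rate: int,
        video_decode_audio: bool,
        video_num_frames: int,
        video_out_frame_size: tuple[int, int],
    ) -> None:
        ...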

    def __call__(self, key, data):
        """
        Extract the video or audio data from default media extensions.

        Args:
            key: media file extension
            data: raw media bytes
        """
        extension = re.sub(r".*[.]", "", key)
Collaborator:

Suggested change:
-        extension = re.sub(r".*[.]", "", key)
+        extension = key.rsplit('.', 1)[-1]

Otherwise, we usually compile regexes beforehand.
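If the regex is kept, precompiling it once at module level could look like this (sketch only; the constant name is made up):

# Module level, next to the other constants (hypothetical name).
EXTENSION_RE = re.compile(r".*[.]")

# ... then inside __call__:
        extension = EXTENSION_RE.sub("", key)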

        # TODO(jbarker): we should add a debug log here
        if extension in "mov mp4 webm mkv".split():
Collaborator:

Suggested change:
-        if extension in "mov mp4 webm mkv".split():
+        if extension in ("mov", "mp4", "webm", "mkv"):

            # TODO(jbarker): make the magic numbers configurable
            media = decode_video_frames(
                data,
                num_frames=self.video_num_frames,
                out_frame_size=self.video_out_frame_size,
                decode_audio=self.video_decode_audio,
            )
        elif extension in "flac mp3".split():
Collaborator:

Suggested change:
-        elif extension in "flac mp3".split():
+        elif extension in ("flac", "mp3"):

            # TODO(jbarker): make the magic numbers configurable
            media = decode_audio_samples(
                data,
                convert_to_melspec=self.audio_convert_to_melspec,
                num_clips=self.audio_num_clips,
                clip_duration=self.audio_clip_duration,
                target_rate=self.audio_target_rate,
            )
        else:
            return None
        if media is not None:
            return VideoData(
                frames=media[0].permute((0, 3, 1, 2)) if media[0] is not None else None,
                aframes=media[1],
                info=media[2],
            )
        return None
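For reference, a minimal sketch of using the decoder directly, outside the webdataset pipeline (the file name and parameter values are made up):

# Sketch: decode 16 frames from a local file with the decoder above.
decoder = AVDecoder(
    audio_convert_to_melspec=False,
    audio_clip_duration=1,
    audio_num_clips=-1,
    audio_target_rate=16000,
    video_decode_audio=False,
    video_num_frames=16,
    video_out_frame_size=(224, 224),
)
with open("example.mp4", "rb") as f:        # hypothetical input file
    video_data = decoder("example.mp4", f.read())
# video_data.frames is expected to be [16, 3, 224, 224] after the permute above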

def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
Collaborator:

I feel the functions below may deserve their own file, and could also reside in the fastseek package?

    # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
    waveform -= waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=sample_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=num_mel_bins,
        dither=0.0,
        frame_length=25,
        frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
    )
    # Convert to [mel_bins, num_frames] shape
    fbank = fbank.transpose(0, 1)
    # Pad to target_length
    n_frames = fbank.size(1)
    p = target_length - n_frames
    # cut and pad
    if p > 0:
        fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
    elif p < 0:
        fbank = fbank[:, 0:target_length]
    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
    # channel image
    fbank = fbank.unsqueeze(0)
    return fbank
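As an illustration of the expected shapes (assumed input, not from the PR): with a 2-second 16 kHz waveform laid out as (channels, samples), the 10 ms frame shift yields roughly 200 frames, which are then padded or cut to target_length:

# Illustrative only.
wav = torch.randn(1, 32000)                    # 2 s of 16 kHz audio, (channels, samples)
spec = waveform2melspec(wav, 16000, 128, 204)
print(spec.shape)                              # expected: torch.Size([1, 128, 204])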

def frame_to_ts(frame: int, average_rate: Fraction, time_base: Fraction) -> int:
    return int(frame / average_rate / time_base)


def ts_to_frame(ts: int, average_rate: Fraction, time_base: Fraction) -> int:
    return int(ts * time_base * average_rate)
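To make the conversion concrete, a small self-check with assumed stream parameters (the numbers are illustrative, not from the PR):

# Illustrative only: a 30 fps stream with a 1/15360 time base.
assert frame_to_ts(10, Fraction(30, 1), Fraction(1, 15360)) == 5120
assert ts_to_frame(5120, Fraction(30, 1), Fraction(1, 15360)) == 10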


def get_frame_batch(
    video_file: io.BytesIO,
    frame_indices: Collection[int],
    out_frame_size: tuple = None,
) -> tuple[torch.Tensor, dict]:
    """Gets a batch of frames at the given indices from a video file."""
    seeker: Fastseek = Fastseek(video_file)
    # Reset the video stream so that pyav can read the entire container
    video_file.seek(0)

    with av.open(video_file) as input_container:
        # Grab video & audio streams
        video_stream = input_container.streams.video[0]
        audio_stream = input_container.streams.audio[0]

        # Enable multi-threaded decode for video
        video_stream.thread_type = 3

        # Collect metadata
        video_fps = float(video_stream.average_rate) if video_stream.average_rate else 0.0
        audio_fps = audio_stream.sample_rate or 0
        metadata = {"video_fps": video_fps, "audio_fps": audio_fps}

        # Pre-calculate timing info for video
        average_rate: Fraction = video_stream.average_rate
        time_base: Fraction = video_stream.time_base
        average_frame_duration: int = int(1 / average_rate / time_base)

        frame_iterator: Iterator[av.VideoFrame] = input_container.decode(video=0)
        previous_frame_number: int = 0

        # Decode requested video frames
        frames: list[torch.Tensor] = []
        for target_frame_number in frame_indices:
            if seeker.mime in ["video/x-matroska", "video/webm"]:
                # Matroska uses time rather than frame number
                prev_frame_ts = frame_to_ts(
                    previous_frame_number, average_rate, seeker.container_time_base
                )
                target_frame_ts = frame_to_ts(
                    target_frame_number, average_rate, seeker.container_time_base
                )
            else:
                prev_frame_ts = previous_frame_number
                target_frame_ts = target_frame_number

            target_pts = frame_to_ts(target_frame_number, average_rate, time_base)

            if seeker.should_seek(prev_frame_ts, target_frame_ts):
                input_container.seek(target_pts, stream=video_stream)

            for frame in frame_iterator:
                if (
                    frame.pts
                    <= target_pts + (average_frame_duration / 2)
                    <= frame.pts + average_frame_duration
                ):
                    if out_frame_size is not None:
                        frame = frame.reformat(
                            width=out_frame_size[0],
                            height=out_frame_size[1],
                            format="rgb24",
                            interpolation="BILINEAR",
                        )
                    else:
                        frame = frame.reformat(format="rgb24")
                    frames.append(torch.from_numpy(frame.to_ndarray()))
                    break

            previous_frame_number = target_frame_number

    # Stack video frames along dim=0 => [num_frames, height, width, channels]
    video_tensor = torch.stack(frames)

    return video_tensor, metadata


def decode_video_frames(
    data: bytes,
    num_frames: int = -1,
    out_frame_size: tuple = None,
    decode_audio: bool = False,
    num_clips: int = 1,
    clip_duration: int = 1,
    target_rate: int = 16000,
    convert_to_melspec: bool = False,
):
    byte_stream = io.BytesIO(data)

    # --- First, decode video frames ---
    with av.open(byte_stream) as input_container:
        if input_container.streams.video[0].frames != 0:
            frame_count = input_container.streams.video[0].frames
        else:
            frame_count = len([p for p in input_container.demux(video=0) if p.pts is not None])

    if num_frames == -1:
        num_frames = frame_count

    # Pick which video frames to extract
    frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int).tolist()
    video_tensor, metadata = get_frame_batch(byte_stream, frame_indices, out_frame_size)

    # --- Then, if requested, decode audio using the same clip logic as decode_audio_samples ---
    audio_tensor = torch.empty(0)
    if decode_audio:
        # Open the container again to get sample_count and sampling_rate
        with av.open(io.BytesIO(data)) as input_container:
            audio_stream = input_container.streams.audio[0]
            sample_count = audio_stream.duration
            sampling_rate = audio_stream.rate

        if num_clips == -1:
            # Single clip from the entire audio
            clip_indices = [[0, sample_count - 1]]
        else:
            clip_indices = get_clip_indices(
                sampling_rate, sample_count, num_clips, clip_duration
            )

        # Actually read the audio clips
        audio_tensor, audio_metadata = get_audio_batch(
            io.BytesIO(data),
            clip_indices,
            target_rate=target_rate,
            convert_to_melspec=convert_to_melspec,
        )
        # Merge any extra audio metadata
        metadata.update(audio_metadata)

    return video_tensor, audio_tensor, metadata
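A minimal standalone usage sketch (the file name is made up; the shapes follow from num_frames and out_frame_size):

# Sketch: decode 8 frames at 224x224 from a local file, without the webdataset pipeline.
with open("example.mp4", "rb") as f:        # hypothetical input file
    raw = f.read()

frames, aframes, info = decode_video_frames(raw, num_frames=8, out_frame_size=(224, 224))
print(frames.shape)   # expected: torch.Size([8, 224, 224, 3]) -- HWC frames before any permute
print(info)           # {"video_fps": ..., "audio_fps": ...}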


def get_audio_batch(
    audio_file: io.BytesIO,
    clip_indices: list[list[int]],
    target_rate: int = 16000,
    convert_to_melspec: bool = False,
) -> tuple[torch.Tensor, dict]:
    """
    Gets a batch of audio samples at the given indices from an audio file,
    resampled to target_rate. Indices correspond to the original sample rate.
    """
    audio_file.seek(0)

    with av.open(audio_file) as input_container:
        audio_stream = input_container.streams.audio[0]
        orig_rate = audio_stream.sample_rate
        duration_per_sample = 1 / orig_rate
        metadata = {"audio_fps": orig_rate}

        # Initialize resampler to convert each frame to target_rate
        if target_rate != orig_rate:
            resampler = av.audio.resampler.AudioResampler(
                format=audio_stream.format,
                layout=audio_stream.layout,
                rate=target_rate,
            )

        clips = []

        for indices in clip_indices:
            start_time = indices[0] * duration_per_sample
            end_time = indices[-1] * duration_per_sample

            # Seek near start time (convert to microseconds per PyAV docs)
            input_container.seek(int(start_time * av.time_base))

            decoded_samples = []
            for frame in input_container.decode(audio=0):
                frame_start = frame.pts * frame.time_base
                # Stop decoding if we've passed the end
                if frame_start >= end_time:
                    break

                # Resample this frame to target_rate if necessary
                if target_rate != orig_rate:
                    frame = resampler.resample(frame)[0]
                frame_nd = frame.to_ndarray()  # (channels, samples)
                decoded_samples.append(frame_nd)

            if decoded_samples:
                # Combine all channels/samples into one array
                clip_all = np.concatenate(decoded_samples, axis=-1)  # (channels, total_samples)

                # Figure out how many samples in the target rate we want
                clip_duration_s = (indices[-1] - indices[0] + 1) / orig_rate
                needed_samples = int(round(clip_duration_s * target_rate))

                # Trim or pad as needed
                clip_all = clip_all[0, :needed_samples]

                # Convert to torch
                clip_tensor = torch.from_numpy(clip_all)
                if convert_to_melspec:
                    clip_tensor = waveform2melspec(clip_tensor.float()[None, :], 16000, 128, 204)[0]
                clips.append(clip_tensor)

    return torch.stack(clips), metadata


def get_clip_indices(sampling_rate, total_samples, num_clips, clip_duration_sec):
    clip_samples = int(sampling_rate * clip_duration_sec)
    assert clip_samples <= total_samples, "Requested clip duration exceeds total samples."

    if num_clips == 1:
        return [np.arange(0, clip_samples)]

    # If total length can accommodate all clips without overlap, space them out evenly
    if num_clips * clip_samples <= total_samples:
        spacing = total_samples // num_clips
    else:
        # Overlap: distribute clips so first starts at 0 and last ends at total_samples - clip_samples
        spacing = (total_samples - clip_samples) // (num_clips - 1)

    start_indices = [i * spacing for i in range(num_clips)]
    return [np.arange(start, start + clip_samples) for start in start_indices]
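A quick worked example (assumed numbers): for 10 seconds of 16 kHz audio and three 1-second clips, everything fits without overlap, so the clips are spaced at total_samples // num_clips:

# Illustrative only.
clips = get_clip_indices(sampling_rate=16000, total_samples=160000, num_clips=3, clip_duration_sec=1)
# clip_samples = 16000 and 3 * 16000 <= 160000, so spacing = 160000 // 3 = 53333
# -> three index ranges starting at samples 0, 53333 and 106666, each 16000 samples long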


def decode_audio_samples(
    data: bytes,
    num_clips: int = 1,
    clip_duration: int = 1,
    target_rate: int = 16000,
    convert_to_melspec: bool = False,
):
    byte_stream = io.BytesIO(data)

    with av.open(byte_stream) as input_container:
        sample_count = input_container.streams.audio[0].duration
        sampling_rate = input_container.streams.audio[0].rate

    if num_clips == -1:
        num_clips = 1
        clip_indices = [[0, sample_count - 1]]
    else:
        clip_indices = get_clip_indices(sampling_rate, sample_count, num_clips, clip_duration)

    audio_tensor, metadata = get_audio_batch(byte_stream, clip_indices, target_rate, convert_to_melspec)

    return None, audio_tensor, metadata
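And a matching sketch for the audio path (file name assumed; note that the first element of the returned tuple is always None here):

# Sketch: decode two 1-second clips from an mp3 file, resampled to 16 kHz.
with open("speech.mp3", "rb") as f:        # hypothetical input file
    _, audio, info = decode_audio_samples(f.read(), num_clips=2, clip_duration=1, target_rate=16000)
print(audio.shape)   # expected: torch.Size([2, 16000])
print(info)          # {"audio_fps": <original sample rate>}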