Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
dd24dfa
Decoder-native resize public implementation
scotts Oct 27, 2025
3a2df84
Lint
scotts Oct 27, 2025
5344ab4
Merge branch 'main' of github.com:pytorch/torchcodec into transform_api
scotts Nov 6, 2025
98cf81b
Implement decoder native transforms API
scotts Nov 7, 2025
65c4ad7
Correct merge
scotts Nov 7, 2025
f300c70
Actually add new file
scotts Nov 7, 2025
2c3b7f0
Lint
scotts Nov 7, 2025
80e84b5
Better assert
scotts Nov 7, 2025
5ac60d8
Better comment
scotts Nov 7, 2025
531b40f
Top level transforms import
scotts Nov 7, 2025
cc333ac
Add the init file. Sigh.
scotts Nov 7, 2025
238a8ff
Linter now needs torchvision in the environment
scotts Nov 7, 2025
55d362c
Avoid missing import errors
scotts Nov 7, 2025
0d2492e
Better names, better docs
scotts Nov 8, 2025
a2da767
More testing, docstring editing
scotts Nov 10, 2025
2cd3f65
Changes
scotts Nov 11, 2025
4ff0186
Reference docs
scotts Nov 12, 2025
0f9eb62
Better docs
scotts Nov 12, 2025
8081298
Make make params private
scotts Nov 12, 2025
39ed9ac
Links to TorchVision.
scotts Nov 12, 2025
6e6815c
Rename conversion function
scotts Nov 12, 2025
363e688
Add no-torchvision job
scotts Nov 12, 2025
463674d
On second thought, let's not
scotts Nov 12, 2025
c20914c
Lists are not covariant?
scotts Nov 12, 2025
254641a
Just use an explicit type
scotts Nov 12, 2025
9b4186a
Pull tv2 inspection logic into decoder transform
scotts Nov 13, 2025
105c77f
Update conversion arg comment
scotts Nov 13, 2025
70b5976
Better importing, better docs
scotts Nov 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchcodec/_core/Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
// Build the FFmpeg filtergraph description for a CPU resize, e.g.
// "scale=640:480:flags=bilinear".
//
// Note: inside a filtergraph the scale filter's interpolation option is
// named "flags"; "sws_flags" is only valid as a global/graph-level option.
std::string ResizeTransform::getFilterGraphCpu() const {
  return "scale=" + std::to_string(outputDims_.width) + ":" +
      std::to_string(outputDims_.height) +
      ":flags=" + toFilterGraphInterpolation(interpolationMode_);
}

std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
Expand Down
22 changes: 21 additions & 1 deletion src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
import numbers
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
from typing import Any, List, Literal, Optional, Tuple, Union

import torch
from torch import device as torch_device, Tensor
Expand Down Expand Up @@ -103,6 +103,7 @@ def __init__(
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
num_ffmpeg_threads: int = 1,
device: Optional[Union[str, torch_device]] = "cpu",
transforms: List[Any] = [], # TRANSFORMS TODO: what is the user-facing type?
seek_mode: Literal["exact", "approximate"] = "exact",
custom_frame_mappings: Optional[
Union[str, bytes, io.RawIOBase, io.BufferedReader]
Expand Down Expand Up @@ -148,13 +149,16 @@ def __init__(

device_variant = _get_cuda_backend()

transform_specs = make_transform_specs(transforms)

core.add_video_stream(
self._decoder,
stream_index=stream_index,
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
device=device,
device_variant=device_variant,
transform_specs=transform_specs,
custom_frame_mappings=custom_frame_mappings_data,
)

Expand Down Expand Up @@ -431,6 +435,22 @@ def _get_and_validate_stream_metadata(
num_frames,
)

def make_transform_specs(transforms: List[Any]) -> str:
    """Convert a list of TorchVision v2 transforms into the core spec string.

    Each supported transform is encoded as "name, arg1, arg2"; multiple
    transforms are joined with ";" (the format expected by
    core.add_video_stream's transform_specs argument).

    Args:
        transforms: TorchVision v2 transform objects. Currently only
            v2.Resize with an explicit (height, width) size is supported.

    Returns:
        The ";"-joined spec string, or "" if there are no transforms.

    Raises:
        ValueError: if a transform is unsupported or a Resize does not
            carry a (height, width) pair.
    """
    # Early out before importing torchvision: the default is no transforms,
    # and users who never pass any should not need torchvision installed.
    if not transforms:
        return ""

    from torchvision.transforms import v2

    transform_specs = []
    for transform in transforms:
        if isinstance(transform, v2.Resize):
            # v2.Resize normalizes size to a list; a single int (scale the
            # short side) is ambiguous for the decoder, so reject it.
            if len(transform.size) != 2:
                raise ValueError(
                    f"Resize transform must have a (height, width) pair for the size, got {transform.size}."
                )
            transform_specs.append(f"resize, {transform.size[0]}, {transform.size[1]}")
        else:
            raise ValueError(
                f"Unsupported transform {transform}."
            )
    return ";".join(transform_specs)

def _read_custom_frame_mappings(
custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
Expand Down
10 changes: 10 additions & 0 deletions test/generate_reference_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,16 @@ def generate_nasa_13013_references():
NASA_VIDEO, frame_index=frame, stream_index=3, filters=crop_filter
)

frames = [17, 230, 389]
# Note that the resize algorithm passed to flags is exposed to users,
# but bilinear is the default we use.
resize_filter = "scale=240:135:flags=bilinear"
for frame in frames:
generate_frame_by_index(
NASA_VIDEO, frame_index=frame, stream_index=3, filters=resize_filter
)



def generate_h265_video_references():
# This video was generated by running the following:
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
58 changes: 56 additions & 2 deletions test/test_transform_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,69 @@
get_json_metadata,
get_next_frame,
)
from torchcodec.decoders import VideoDecoder

from torchvision.transforms import v2

from .utils import assert_frames_equal, NASA_VIDEO, needs_cuda
from .utils import assert_frames_equal, NASA_VIDEO, needs_cuda, psnr

torch._dynamo.config.capture_dynamic_output_shape_ops = True


class TestVideoDecoderTransformOps:
class TestPublicVideoDecoderTransformOps:
    # Tests for the public `transforms` argument of VideoDecoder: resizing
    # via torchvision's v2.Resize, checked against both a TorchVision
    # reference resize and FFmpeg-generated reference frames.

    @pytest.mark.parametrize(
        "height_scaling_factor, width_scaling_factor",
        ((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0)),
    )
    def test_resize_torchvision(self, height_scaling_factor, width_scaling_factor):
        # Decoder-native resize should closely match TorchVision's resize of
        # the full-size decoded frame. Scaling factors cover upscaling,
        # downscaling, mixed, and identity cases.
        height = int(NASA_VIDEO.get_height() * height_scaling_factor)
        width = int(NASA_VIDEO.get_width() * width_scaling_factor)

        decoder_resize = VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(height, width))])
        decoder_full = VideoDecoder(NASA_VIDEO.path)
        for frame_index in [0, 10, 17, 100, 230, 389]:
            expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width)
            frame_resize = decoder_resize[frame_index]

            # Reference: decode at full resolution, then resize with
            # TorchVision directly.
            frame_full = decoder_full[frame_index]
            frame_tv = v2.functional.resize(frame_full, size=(height, width))

            assert frame_resize.shape == expected_shape
            assert frame_tv.shape == expected_shape

            # Copied from PR #992; not sure if it's the best way to check
            assert psnr(frame_resize, frame_tv) > 25

    def test_resize_ffmpeg(self):
        # Decoder-native resize should match frames produced by FFmpeg's
        # scale filter exactly (same filtergraph, bilinear interpolation).
        height = 135
        width = 240
        expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width)
        resize_filtergraph = f"scale={width}:{height}:flags=bilinear"
        decoder_resize = VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(height, width))])
        for frame_index in [17, 230, 389]:
            frame_resize = decoder_resize[frame_index]
            frame_ref = NASA_VIDEO.get_frame_data_by_index(frame_index, filters=resize_filtergraph)

            assert frame_resize.shape == expected_shape
            assert frame_ref.shape == expected_shape
            assert_frames_equal(frame_resize, frame_ref)


    def test_resize_fails(self):
        # A single-int size (note: (100) is just the int 100) is rejected;
        # the decoder requires an explicit (height, width) pair.
        with pytest.raises(
            ValueError,
            match=r"must have a \(height, width\) pair for the size",
        ):
            VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))])

    def test_transform_fails(self):
        # Transforms other than v2.Resize are not supported yet.
        with pytest.raises(
            ValueError,
            match="Unsupported transform",
        ):
            VideoDecoder(NASA_VIDEO.path, transforms=[v2.RandomHorizontalFlip(p=1.0)])

class TestCoreVideoDecoderTransformOps:
# We choose arbitrary values for width and height scaling to get better
# test coverage. Some pairs upscale the image while others downscale it.
@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def empty_chw_tensor(self) -> torch.Tensor:
[0, self.num_color_channels, self.height, self.width], dtype=torch.uint8
)

def get_width(self, *, stream_index: Optional[int]) -> int:
def get_width(self, *, stream_index: Optional[int] = None) -> int:
if stream_index is None:
stream_index = self.default_stream_index

Expand Down
Loading