2 changes: 1 addition & 1 deletion src/torchcodec/_core/Transform.cpp
@@ -42,7 +42,7 @@ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
std::string ResizeTransform::getFilterGraphCpu() const {
return "scale=" + std::to_string(outputDims_.width) + ":" +
std::to_string(outputDims_.height) +
":sws_flags=" + toFilterGraphInterpolation(interpolationMode_);
":flags=" + toFilterGraphInterpolation(interpolationMode_);

@scotts (Author) commented on Oct 27, 2025:

From the FFmpeg docs:

Libavfilter will automatically insert scale filters where format conversion is required. It is possible to specify swscale flags for those automatically inserted scalers by prepending sws_flags=flags; to the filtergraph description.

flags, by contrast, is the option specific to the scale filter. The two end up being semantically equivalent, but it's clearer to use the scale filter's own option here.
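
For illustration, here is roughly how the three spellings compare for a 240x135 bilinear resize (a sketch; only the last string is what getFilterGraphCpu() now produces):

```python
# Graph-level swscale flags, as described in the FFmpeg docs quoted above:
# they apply to scalers that libavfilter inserts automatically.
graph_level = "sws_flags=bilinear;scale=240:135"

# What this function built before the change: sws_flags passed as a scale option.
old_form = "scale=240:135:sws_flags=bilinear"

# What it builds now: flags is the scale filter's own option.
new_form = "scale=240:135:flags=bilinear"
```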

}

std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
22 changes: 21 additions & 1 deletion src/torchcodec/decoders/_video_decoder.py
@@ -8,7 +8,7 @@
import json
import numbers
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
from typing import Any, List, Literal, Optional, Tuple, Union

import torch
from torch import device as torch_device, Tensor
@@ -103,6 +103,7 @@ def __init__(
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
num_ffmpeg_threads: int = 1,
device: Optional[Union[str, torch_device]] = "cpu",
transforms: List[Any] = [], # TRANSFORMS TODO: what is the user-facing type?

@scotts (Author) commented:

Discussion point 1: If we accept TorchVision transforms, and we want to lazily load TorchVision, what type do we advertise here? We can easily explain that we accept a TorchVision transform in the docs, but what should we put in the type annotation?

@NicolasHug (Contributor) commented on Oct 28, 2025:

It should probably be either Any or nn.Module, which is the base class of all torchvision v2 transforms and something users are already familiar with, since it's the core building block of any PyTorch model.

@scotts (Author) replied:

Oh, that solves the problem nicely: it can definitely be nn.Module.
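
For illustration, a minimal sketch of that annotation, assuming we keep torchvision out of the import path (the helper name below is hypothetical, not part of the PR):

```python
from typing import List, Optional

from torch import nn


# torchvision v2 transforms subclass nn.Module, so this annotation describes
# them without importing torchvision at module-load time.
def _normalize_transforms(transforms: Optional[List[nn.Module]] = None) -> List[nn.Module]:
    return list(transforms) if transforms is not None else []
```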

seek_mode: Literal["exact", "approximate"] = "exact",
custom_frame_mappings: Optional[
Union[str, bytes, io.RawIOBase, io.BufferedReader]
@@ -148,13 +149,16 @@ def __init__(

device_variant = _get_cuda_backend()

transform_specs = make_transform_specs(transforms)

core.add_video_stream(
self._decoder,
stream_index=stream_index,
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
device=device,
device_variant=device_variant,
transform_specs=transform_specs,
custom_frame_mappings=custom_frame_mappings_data,
)

@@ -432,6 +436,22 @@ def _get_and_validate_stream_metadata(
)


def make_transform_specs(transforms: List[Any]) -> str:
from torchvision.transforms import v2

transform_specs = []
for transform in transforms:
if isinstance(transform, v2.Resize):
if len(transform.size) != 2:
raise ValueError(
f"Resize transform must have a (height, width) pair for the size, got {transform.size}."
)
transform_specs.append(f"resize, {transform.size[0]}, {transform.size[1]}")
else:
raise ValueError(f"Unsupported transform {transform}.")
return ";".join(transform_specs)
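
For reference, a quick check of the spec string this produces (assuming torchvision is installed; the height comes first, matching the (height, width) order of v2.Resize's size argument):

```python
from torchcodec.decoders._video_decoder import make_transform_specs
from torchvision.transforms import v2

# A 135x240 (height x width) resize maps to a single decoder-native spec entry.
print(make_transform_specs([v2.Resize(size=(135, 240))]))  # "resize, 135, 240"
```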

@scotts (Author) commented on Oct 27, 2025:

Discussion point 2: This is what we'll have to do with TorchVision transforms at the moment. We'll need special handling for each transform, looking into its internals to get what we need and enforce decoder-native limitations.

In the future, we can change TorchVision transforms to have an API so that we can get what we need in a generic way. But for now, we'll need to do something like this.
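
Purely as a hypothetical sketch of what such a generic hand-off could look like (no such API exists in torchvision today; every name below is made up):

```python
from typing import Iterable, Protocol


class SupportsDecoderSpec(Protocol):
    # A transform that knows how to describe itself to the decoder core.
    def _decoder_native_spec(self) -> str: ...


def make_transform_specs_generic(transforms: Iterable[SupportsDecoderSpec]) -> str:
    # No per-transform special casing: each transform emits its own spec string.
    return ";".join(t._decoder_native_spec() for t in transforms)
```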

@NicolasHug (Contributor) commented:

I'm still undecided on whether we should accept TV transforms or not (ironic, I know), but I think this is totally OK.

And I think we'll need that level of coupling anyway, even if we were to write our own TC transforms. Echoing what you wrote:

If we were to [...] create a TorchCodec-specific user specification API, we'd want to make sure that its semantics match those of TorchVision. That is, if we had torchcodec.transforms.Resize(height=x, width=y), we'd want to make sure its semantics matched torchvision.transforms.v2.Resize(size=(x, y)). In that specific example, we'd want to make sure that both default to bilinear interpolation. Extrapolating that specific example across all transforms we want to support, we'd basically be creating a mirror version of what TorchVision has. That seems silly, since it's more for users to understand and more for us to maintain.

Basically, that coupling between TC and TV will have to exist either in the code (as in this PR), or in our heads as API designers.


Side note, slightly related: if we're going to have our own TC transforms, I think we'll want their API to exactly match (or be a strict subset of) the TV transforms. E.g. we'd have torchcodec.transforms.Resize(size=...) instead of torchcodec.transforms.Resize(height=..., width=...) ?

@scotts (Author) commented on Oct 28, 2025:

@NicolasHug, I came to the same conclusion as:

Side note, slightly related: if we're going to have our own TC transforms, I think we'll want their API to exactly match (or be a strict subset of) the TV transforms. E.g. we'd have torchcodec.transforms.Resize(size=...) instead of torchcodec.transforms.Resize(height=..., width=...) ?

At which point, I don't think we've really gained anything by having them separate. And users will probably also start asking: hey, can you just accept the TorchVision ones? I also just realized a new counter-point, which I'll put up in the summary as counter-point 3.


def _read_custom_frame_mappings(
custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
) -> tuple[Tensor, Tensor, Tensor]:
9 changes: 9 additions & 0 deletions test/generate_reference_resources.py
@@ -123,6 +123,15 @@ def generate_nasa_13013_references():
NASA_VIDEO, frame_index=frame, stream_index=3, filters=crop_filter
)

frames = [17, 230, 389]
# Note that the resize algorithm passed to flags is exposed to users,
# but bilinear is the default we use.
resize_filter = "scale=240:135:flags=bilinear"
for frame in frames:
generate_frame_by_index(
NASA_VIDEO, frame_index=frame, stream_index=3, filters=resize_filter
)


def generate_h265_video_references():
# This video was generated by running the following:
(3 binary files changed; contents not shown)
64 changes: 62 additions & 2 deletions test/test_transform_ops.py
@@ -22,15 +22,75 @@
get_json_metadata,
get_next_frame,
)
from torchcodec.decoders import VideoDecoder

from torchvision.transforms import v2

from .utils import assert_frames_equal, NASA_VIDEO, needs_cuda
from .utils import assert_frames_equal, NASA_VIDEO, needs_cuda, psnr

torch._dynamo.config.capture_dynamic_output_shape_ops = True


class TestVideoDecoderTransformOps:
class TestPublicVideoDecoderTransformOps:
@pytest.mark.parametrize(
"height_scaling_factor, width_scaling_factor",
((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0)),
)
def test_resize_torchvision(self, height_scaling_factor, width_scaling_factor):
height = int(NASA_VIDEO.get_height() * height_scaling_factor)
width = int(NASA_VIDEO.get_width() * width_scaling_factor)

decoder_resize = VideoDecoder(
NASA_VIDEO.path, transforms=[v2.Resize(size=(height, width))]
)
decoder_full = VideoDecoder(NASA_VIDEO.path)
for frame_index in [0, 10, 17, 100, 230, 389]:
expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width)
frame_resize = decoder_resize[frame_index]

frame_full = decoder_full[frame_index]
frame_tv = v2.functional.resize(frame_full, size=(height, width))

assert frame_resize.shape == expected_shape
assert frame_tv.shape == expected_shape

# Copied from PR #992; not sure if it's the best way to check
assert psnr(frame_resize, frame_tv) > 25
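
For context, psnr is the helper imported from test/utils.py above; presumably it computes the standard peak signal-to-noise ratio, roughly like this sketch (assuming uint8 frames with a peak value of 255):

```python
import torch


def psnr_sketch(a: torch.Tensor, b: torch.Tensor, max_val: float = 255.0) -> float:
    # PSNR = 10 * log10(MAX^2 / MSE); higher means the two frames are more similar.
    mse = torch.mean((a.float() - b.float()) ** 2)
    return (10 * torch.log10(max_val**2 / mse)).item()
```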

def test_resize_ffmpeg(self):
height = 135
width = 240
expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width)
resize_filtergraph = f"scale={width}:{height}:flags=bilinear"
decoder_resize = VideoDecoder(
NASA_VIDEO.path, transforms=[v2.Resize(size=(height, width))]
)
for frame_index in [17, 230, 389]:
frame_resize = decoder_resize[frame_index]
frame_ref = NASA_VIDEO.get_frame_data_by_index(
frame_index, filters=resize_filtergraph
)

assert frame_resize.shape == expected_shape
assert frame_ref.shape == expected_shape
assert_frames_equal(frame_resize, frame_ref)

@scotts (Author) commented:

Currently fails. Still investigating.


def test_resize_fails(self):
with pytest.raises(
ValueError,
match=r"must have a \(height, width\) pair for the size",
):
VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))])

def test_transform_fails(self):
with pytest.raises(
ValueError,
match="Unsupported transform",
):
VideoDecoder(NASA_VIDEO.path, transforms=[v2.RandomHorizontalFlip(p=1.0)])


class TestCoreVideoDecoderTransformOps:
# We choose arbitrary values for width and height scaling to get better
# test coverage. Some pairs upscale the image while others downscale it.
@pytest.mark.parametrize(
2 changes: 1 addition & 1 deletion test/utils.py
@@ -430,7 +430,7 @@ def empty_chw_tensor(self) -> torch.Tensor:
[0, self.num_color_channels, self.height, self.width], dtype=torch.uint8
)

def get_width(self, *, stream_index: Optional[int]) -> int:
def get_width(self, *, stream_index: Optional[int] = None) -> int:
if stream_index is None:
stream_index = self.default_stream_index
