Skip to content

Commit b2a90f9

Browse files
mthrokfacebook-github-bot
authored andcommitted
Add YUV444P support to StreamReader (#2516)
Summary: This commit add support for `"yuv444p"` type as output format of StreamReader. Pull Request resolved: #2516 Reviewed By: hwangjeff Differential Revision: D37659715 Pulled By: mthrok fbshipit-source-id: eae9b5590d8f138a6ebf3808c08adfe068f11a2b
1 parent 10ac6d2 commit b2a90f9

File tree

4 files changed

+127
-24
lines changed

4 files changed

+127
-24
lines changed

test/torchaudio_unittest/common_utils/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from .data_utils import get_asset_path, get_sinusoid, get_spectrogram, get_whitenoise
2121
from .func_utils import torch_script
22-
from .image_utils import get_image, save_image
22+
from .image_utils import get_image, rgb_to_gray, rgb_to_yuv_ccir, save_image
2323
from .parameterized_utils import load_params, nested_params
2424
from .wav_utils import get_wav_data, load_wav, normalize_wav, save_wav
2525

@@ -55,4 +55,6 @@
5555
"torch_script",
5656
"save_image",
5757
"get_image",
58+
"rgb_to_gray",
59+
"rgb_to_yuv_ccir",
5860
]

test/torchaudio_unittest/common_utils/image_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,46 @@ def get_image(width, height, grayscale=False):
2727
img = torch.arange(numel, dtype=torch.int64) % 256
2828
img = img.reshape(channels, height, width).to(torch.uint8)
2929
return img
30+
31+
32+
def rgb_to_yuv_ccir(img):
33+
"""rgb to yuv conversion ported from ffmpeg
34+
35+
The input image is expected to be (..., channel, height, width).
36+
"""
37+
assert img.dtype == torch.uint8
38+
img = img.to(torch.float32)
39+
40+
r, g, b = torch.split(img, 1, dim=-3)
41+
42+
# https://github.com/FFmpeg/FFmpeg/blob/870bfe16a12bf09dca3a4ae27ef6f81a2de80c40/libavutil/colorspace.h#L98
43+
y = 263 * r + 516 * g + 100 * b + 512 + 16384
44+
y /= 1024
45+
46+
# https://github.com/FFmpeg/FFmpeg/blob/870bfe16a12bf09dca3a4ae27ef6f81a2de80c40/libavutil/colorspace.h#L102
47+
# shift == 0
48+
u = -152 * r - 298 * g + 450 * b + 512 - 1
49+
u /= 1024
50+
u += 128
51+
52+
# https://github.com/FFmpeg/FFmpeg/blob/870bfe16a12bf09dca3a4ae27ef6f81a2de80c40/libavutil/colorspace.h#L106
53+
# shift == 0
54+
v = 450 * r - 377 * g - 73 * b + 512 - 1
55+
v /= 1024
56+
v += 128
57+
58+
return torch.cat([y, u, v], -3).to(torch.uint8)
59+
60+
61+
def rgb_to_gray(img):
62+
"""rgb to gray conversion
63+
64+
The input image is expected to be (..., channel, height, width).
65+
"""
66+
assert img.dtype == torch.uint8
67+
img = img.to(torch.float32)
68+
69+
r, g, b = torch.split(img, 1, dim=-3)
70+
71+
gray = 0.299 * r + 0.587 * g + 0.114 * b
72+
return gray.to(torch.uint8)

test/torchaudio_unittest/io/stream_reader_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
get_wav_data,
88
is_ffmpeg_available,
99
nested_params,
10+
rgb_to_gray,
11+
rgb_to_yuv_ccir,
1012
save_image,
1113
save_wav,
1214
skipIfNoFFmpeg,
@@ -614,3 +616,29 @@ def test_png_effect(self, filter_desc, index):
614616
print("expected", expected)
615617
print("output", output)
616618
self.assertEqual(expected, output)
619+
620+
def test_png_yuv_read_out(self):
621+
"""Providing format prpoerly change the color space"""
622+
rgb = torch.empty(1, 3, 256, 256, dtype=torch.uint8)
623+
rgb[0, 0] = torch.arange(256, dtype=torch.uint8).reshape([1, -1])
624+
rgb[0, 1] = torch.arange(256, dtype=torch.uint8).reshape([-1, 1])
625+
for i in range(256):
626+
rgb[0, 2] = i
627+
path = self.get_temp_path(f"ref_{i}.png")
628+
save_image(path, rgb[0], mode="RGB")
629+
630+
yuv = rgb_to_yuv_ccir(rgb)
631+
bgr = rgb[:, [2, 1, 0], :, :]
632+
gray = rgb_to_gray(rgb)
633+
634+
s = StreamReader(path)
635+
s.add_basic_video_stream(frames_per_chunk=-1, format="yuv444p")
636+
s.add_basic_video_stream(frames_per_chunk=-1, format="rgb24")
637+
s.add_basic_video_stream(frames_per_chunk=-1, format="bgr24")
638+
s.add_basic_video_stream(frames_per_chunk=-1, format="gray8")
639+
s.process_all_packets()
640+
output_yuv, output_rgb, output_bgr, output_gray = s.pop_chunks()
641+
self.assertEqual(yuv, output_yuv, atol=1, rtol=0)
642+
self.assertEqual(rgb, output_rgb, atol=0, rtol=0)
643+
self.assertEqual(bgr, output_bgr, atol=0, rtol=0)
644+
self.assertEqual(gray, output_gray, atol=1, rtol=0)

torchaudio/csrc/ffmpeg/buffer.cpp

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,56 @@ void AudioBuffer::push_frame(AVFrame* frame) {
171171
// Modifiers - Push Video
172172
////////////////////////////////////////////////////////////////////////////////
173173
namespace {
174+
torch::Tensor convert_interlaced_video(AVFrame* pFrame) {
175+
int width = pFrame->width;
176+
int height = pFrame->height;
177+
uint8_t* buf = pFrame->data[0];
178+
int linesize = pFrame->linesize[0];
179+
int channel = av_pix_fmt_desc_get(static_cast<AVPixelFormat>(pFrame->format))
180+
->nb_components;
181+
182+
auto options = torch::TensorOptions()
183+
.dtype(torch::kUInt8)
184+
.layout(torch::kStrided)
185+
.device(torch::kCPU);
186+
187+
torch::Tensor frame = torch::empty({1, height, width, channel}, options);
188+
auto ptr = frame.data_ptr<uint8_t>();
189+
int stride = width * channel;
190+
for (int i = 0; i < height; ++i) {
191+
memcpy(ptr, buf, stride);
192+
buf += linesize;
193+
ptr += stride;
194+
}
195+
return frame.permute({0, 3, 1, 2});
196+
}
197+
198+
torch::Tensor convert_planar_video(AVFrame* pFrame) {
199+
int width = pFrame->width;
200+
int height = pFrame->height;
201+
int num_planes =
202+
av_pix_fmt_count_planes(static_cast<AVPixelFormat>(pFrame->format));
203+
204+
auto options = torch::TensorOptions()
205+
.dtype(torch::kUInt8)
206+
.layout(torch::kStrided)
207+
.device(torch::kCPU);
208+
209+
torch::Tensor frame = torch::empty({1, num_planes, height, width}, options);
210+
for (int i = 0; i < num_planes; ++i) {
211+
torch::Tensor plane = frame.index({0, i});
212+
uint8_t* tgt = plane.data_ptr<uint8_t>();
213+
uint8_t* src = pFrame->data[i];
214+
int linesize = pFrame->linesize[i];
215+
for (int h = 0; h < height; ++h) {
216+
memcpy(tgt, src, width);
217+
tgt += width;
218+
src += linesize;
219+
}
220+
}
221+
return frame;
222+
}
223+
174224
torch::Tensor convert_yuv420p(AVFrame* pFrame) {
175225
int width = pFrame->width;
176226
int height = pFrame->height;
@@ -316,26 +366,17 @@ torch::Tensor convert_image_tensor(
316366
// https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
317367
// https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038
318368
AVPixelFormat format = static_cast<AVPixelFormat>(pFrame->format);
319-
int width = pFrame->width;
320-
int height = pFrame->height;
321-
uint8_t* buf = pFrame->data[0];
322-
int linesize = pFrame->linesize[0];
323-
324-
int channel;
325369
switch (format) {
326370
case AV_PIX_FMT_RGB24:
327371
case AV_PIX_FMT_BGR24:
328-
channel = 3;
329-
break;
330372
case AV_PIX_FMT_ARGB:
331373
case AV_PIX_FMT_RGBA:
332374
case AV_PIX_FMT_ABGR:
333375
case AV_PIX_FMT_BGRA:
334-
channel = 4;
335-
break;
336376
case AV_PIX_FMT_GRAY8:
337-
channel = 1;
338-
break;
377+
return convert_interlaced_video(pFrame);
378+
case AV_PIX_FMT_YUV444P:
379+
return convert_planar_video(pFrame);
339380
case AV_PIX_FMT_YUV420P:
340381
return convert_yuv420p(pFrame);
341382
case AV_PIX_FMT_NV12:
@@ -368,17 +409,6 @@ torch::Tensor convert_image_tensor(
368409
"Unexpected video format: " +
369410
std::string(av_get_pix_fmt_name(format)));
370411
}
371-
372-
torch::Tensor t;
373-
t = torch::empty({1, height, width, channel}, torch::kUInt8);
374-
auto ptr = t.data_ptr<uint8_t>();
375-
int stride = width * channel;
376-
for (int i = 0; i < height; ++i) {
377-
memcpy(ptr, buf, stride);
378-
buf += linesize;
379-
ptr += stride;
380-
}
381-
return t.permute({0, 3, 1, 2});
382412
}
383413
} // namespace
384414

0 commit comments

Comments
 (0)