Commit 7ca3a8d (1 parent: 78add96)

Author: pytorchbot
Message: 2025-05-14 nightly release (0cced7c)

11 files changed: +575 −514 lines

src/torchcodec/_core/CMakeLists.txt

+2

@@ -67,7 +67,9 @@ function(make_torchcodec_libraries
         AVIOContextHolder.cpp
         AVIOBytesContext.cpp
         FFMPEGCommon.cpp
+        Frame.cpp
         DeviceInterface.cpp
+        CpuDeviceInterface.cpp
         SingleStreamDecoder.cpp
         # TODO: lib name should probably not be "*_decoder*" now that it also
         # contains an encoder
src/torchcodec/_core/CpuDeviceInterface.cpp

+364

@@ -0,0 +1,364 @@

// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "src/torchcodec/_core/CpuDeviceInterface.h"

extern "C" {
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
}

namespace facebook::torchcodec {
namespace {

static bool g_cpu = registerDeviceInterface(
    torch::kCPU,
    [](const torch::Device& device) { return new CpuDeviceInterface(device); });

} // namespace

bool CpuDeviceInterface::DecodedFrameContext::operator==(
    const CpuDeviceInterface::DecodedFrameContext& other) {
  return decodedWidth == other.decodedWidth &&
      decodedHeight == other.decodedHeight &&
      decodedFormat == other.decodedFormat &&
      expectedWidth == other.expectedWidth &&
      expectedHeight == other.expectedHeight;
}

bool CpuDeviceInterface::DecodedFrameContext::operator!=(
    const CpuDeviceInterface::DecodedFrameContext& other) {
  return !(*this == other);
}

CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
    : DeviceInterface(device) {
  TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
  if (device_.type() != torch::kCPU) {
    throw std::runtime_error("Unsupported device: " + device_.str());
  }
}
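Note: the `g_cpu` static above self-registers the CPU backend at library load time, before main() runs. As a rough sketch of the pattern (the map type and function body below are assumptions for illustration, not torchcodec's actual DeviceInterface.cpp):

// Hypothetical sketch of the registry behind registerDeviceInterface();
// assumes the DeviceInterface class declaration is in scope.
#include <functional>
#include <map>

using CreateDeviceInterfaceFn =
    std::function<DeviceInterface*(const torch::Device&)>;

static std::map<torch::DeviceType, CreateDeviceInterfaceFn>&
getDeviceInterfaceFactories() {
  // Meyers singleton: safe against static-initialization-order issues
  // when registration happens from other translation units.
  static std::map<torch::DeviceType, CreateDeviceInterfaceFn> factories;
  return factories;
}

bool registerDeviceInterface(
    torch::DeviceType deviceType,
    CreateDeviceInterfaceFn createFn) {
  getDeviceInterfaceFactories()[deviceType] = std::move(createFn);
  return true; // assigned to a static bool so this runs at load time
}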
// Note [preAllocatedOutputTensor with swscale and filtergraph]:
// Callers may pass a pre-allocated tensor, where the output.data tensor will
// be stored. This parameter is always honored, but it only leads to a
// speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
// decoded frame directly into `preAllocatedOutputTensor.data_ptr()`. We
// haven't yet found a way to do that with filtergraph.
// TODO: Figure out whether that's possible!
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
void CpuDeviceInterface::convertAVFrameToFrameOutput(
    const VideoStreamOptions& videoStreamOptions,
    const AVRational& timeBase,
    UniqueAVFrame& avFrame,
    FrameOutput& frameOutput,
    std::optional<torch::Tensor> preAllocatedOutputTensor) {
  auto frameDims =
      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
  int expectedOutputHeight = frameDims.height;
  int expectedOutputWidth = frameDims.width;

  if (preAllocatedOutputTensor.has_value()) {
    auto shape = preAllocatedOutputTensor.value().sizes();
    TORCH_CHECK(
        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
        "Expected pre-allocated tensor of shape ",
        expectedOutputHeight,
        "x",
        expectedOutputWidth,
        "x3, got ",
        shape);
  }

  torch::Tensor outputTensor;
  // We need to compare the current frame context with our previous frame
  // context. If they are different, then we need to re-create our colorspace
  // conversion objects. We create our colorspace conversion objects late so
  // that we don't have to depend on the unreliable metadata in the header.
  // And we sometimes re-create them because it's possible for frame
  // resolution to change mid-stream. Finally, we want to reuse the colorspace
  // conversion objects as much as possible for performance reasons.
  enum AVPixelFormat frameFormat =
      static_cast<enum AVPixelFormat>(avFrame->format);
  auto frameContext = DecodedFrameContext{
      avFrame->width,
      avFrame->height,
      frameFormat,
      avFrame->sample_aspect_ratio,
      expectedOutputWidth,
      expectedOutputHeight};

  // By default, we want to use swscale for color conversion because it is
  // faster. However, it has width requirements, so we may need to fall back
  // to filtergraph. We also need to respect what was requested from the
  // options; we respect the options unconditionally, so it's possible for
  // swscale's width requirements to be violated. We don't expose the ability
  // to choose a color conversion library publicly; we only use this ability
  // internally.

  // swscale requires widths to be multiples of 32:
  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
  // so we fall back to filtergraph if the width is not a multiple of 32.
  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
      ? ColorConversionLibrary::SWSCALE
      : ColorConversionLibrary::FILTERGRAPH;

  ColorConversionLibrary colorConversionLibrary =
      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);

  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
        expectedOutputHeight, expectedOutputWidth, torch::kCPU));

    if (!swsContext_ || prevFrameContext_ != frameContext) {
      createSwsContext(frameContext, avFrame->colorspace);
      prevFrameContext_ = frameContext;
    }
    int resultHeight =
        convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
    // If this check failed, it would mean that the frame wasn't reshaped to
    // the expected height.
    // TODO: Can we do the same check for width?
    TORCH_CHECK(
        resultHeight == expectedOutputHeight,
        "resultHeight != expectedOutputHeight: ",
        resultHeight,
        " != ",
        expectedOutputHeight);

    frameOutput.data = outputTensor;
  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
      createFilterGraph(frameContext, videoStreamOptions, timeBase);
      prevFrameContext_ = frameContext;
    }
    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);

    // Similarly to above, if this check fails it means the frame wasn't
    // reshaped to its expected dimensions by filtergraph.
    auto shape = outputTensor.sizes();
    TORCH_CHECK(
        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
        "Expected output tensor of shape ",
        expectedOutputHeight,
        "x",
        expectedOutputWidth,
        "x3, got ",
        shape);

    if (preAllocatedOutputTensor.has_value()) {
      // We have already validated that preAllocatedOutputTensor and
      // outputTensor have the same shape.
      preAllocatedOutputTensor.value().copy_(outputTensor);
      frameOutput.data = preAllocatedOutputTensor.value();
    } else {
      frameOutput.data = outputTensor;
    }
  } else {
    throw std::runtime_error(
        "Invalid color conversion library: " +
        std::to_string(static_cast<int>(colorConversionLibrary)));
  }
}
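A minimal caller sketch for the pre-allocated path described in Note [preAllocatedOutputTensor with swscale and filtergraph]; the decoder plumbing (how avFrame, videoStreamOptions, and timeBase are obtained) is assumed to be in scope, and the dimensions are made up:

// Hypothetical usage; reuses one HWC uint8 buffer across decodes.
FrameOutput frameOutput;
torch::Tensor preAllocated =
    allocateEmptyHWCTensor(/*height=*/480, /*width=*/640, torch::kCPU);
cpuInterface.convertAVFrameToFrameOutput(
    videoStreamOptions, timeBase, avFrame, frameOutput, preAllocated);
// On the swscale path frameOutput.data aliases preAllocated's storage, so
// no extra copy happens; on the filtergraph path the result is copied into
// preAllocated instead.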
int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
    const UniqueAVFrame& avFrame,
    torch::Tensor& outputTensor) {
  uint8_t* pointers[4] = {
      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
  int expectedOutputWidth = outputTensor.sizes()[1];
  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
  int resultHeight = sws_scale(
      swsContext_.get(),
      avFrame->data,
      avFrame->linesize,
      0,
      avFrame->height,
      pointers,
      linesizes);
  return resultHeight;
}
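Only the first entry of pointers/linesizes is populated because the destination format is packed RGB24: a single plane whose stride is width * 3 bytes. That single-plane layout is what makes it possible to hand sws_scale the tensor's own data pointer and skip a copy.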
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
    const UniqueAVFrame& avFrame) {
  int status = av_buffersrc_write_frame(
      filterGraphContext_.sourceContext, avFrame.get());
  if (status < AVSUCCESS) {
    throw std::runtime_error("Failed to add frame to buffer source context");
  }

  UniqueAVFrame filteredAVFrame(av_frame_alloc());
  status = av_buffersink_get_frame(
      filterGraphContext_.sinkContext, filteredAVFrame.get());
  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
  int height = frameDims.height;
  int width = frameDims.width;
  std::vector<int64_t> shape = {height, width, 3};
  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
  auto deleter = [filteredAVFramePtr](void*) {
    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
  };
  return torch::from_blob(
      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
}
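The release-then-recapture of the AVFrame above is a general zero-copy pattern: the tensor's custom deleter re-wraps the raw pointer, so FFmpeg's frame is freed only when the last tensor reference dies. A self-contained sketch of the same idea with a plain malloc'd buffer (illustrative only, not torchcodec code):

#include <torch/torch.h>
#include <cstdlib>

// Wrap an externally owned buffer in a tensor; the deleter frees the
// buffer when the last tensor reference goes away.
torch::Tensor wrapBuffer(int height, int width) {
  auto* data = static_cast<uint8_t*>(std::malloc(height * width * 3));
  auto deleter = [](void* p) { std::free(p); };
  return torch::from_blob(
      data,
      {height, width, 3},
      {width * 3, 3, 1}, // strides: row, pixel, channel
      deleter,
      torch::kUInt8);
}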
void CpuDeviceInterface::createFilterGraph(
    const DecodedFrameContext& frameContext,
    const VideoStreamOptions& videoStreamOptions,
    const AVRational& timeBase) {
  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);

  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
    filterGraphContext_.filterGraph->nb_threads =
        videoStreamOptions.ffmpegThreadCount.value();
  }

  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
  const AVFilter* buffersink = avfilter_get_by_name("buffersink");

  std::stringstream filterArgs;
  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
             << frameContext.decodedHeight;
  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
  filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
             << frameContext.decodedAspectRatio.den;

  int status = avfilter_graph_create_filter(
      &filterGraphContext_.sourceContext,
      buffersrc,
      "in",
      filterArgs.str().c_str(),
      nullptr,
      filterGraphContext_.filterGraph.get());
  if (status < 0) {
    throw std::runtime_error(
        std::string("Failed to create filter graph: ") + filterArgs.str() +
        ": " + getFFMPEGErrorStringFromErrorCode(status));
  }

  status = avfilter_graph_create_filter(
      &filterGraphContext_.sinkContext,
      buffersink,
      "out",
      nullptr,
      nullptr,
      filterGraphContext_.filterGraph.get());
  if (status < 0) {
    throw std::runtime_error(
        "Failed to create filter graph: " +
        getFFMPEGErrorStringFromErrorCode(status));
  }

  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};

  status = av_opt_set_int_list(
      filterGraphContext_.sinkContext,
      "pix_fmts",
      pix_fmts,
      AV_PIX_FMT_NONE,
      AV_OPT_SEARCH_CHILDREN);
  if (status < 0) {
    throw std::runtime_error(
        "Failed to set output pixel formats: " +
        getFFMPEGErrorStringFromErrorCode(status));
  }

  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
  UniqueAVFilterInOut inputs(avfilter_inout_alloc());

  outputs->name = av_strdup("in");
  outputs->filter_ctx = filterGraphContext_.sourceContext;
  outputs->pad_idx = 0;
  outputs->next = nullptr;
  inputs->name = av_strdup("out");
  inputs->filter_ctx = filterGraphContext_.sinkContext;
  inputs->pad_idx = 0;
  inputs->next = nullptr;

  std::stringstream description;
  description << "scale=" << frameContext.expectedWidth << ":"
              << frameContext.expectedHeight;
  description << ":sws_flags=bilinear";

  AVFilterInOut* outputsTmp = outputs.release();
  AVFilterInOut* inputsTmp = inputs.release();
  status = avfilter_graph_parse_ptr(
      filterGraphContext_.filterGraph.get(),
      description.str().c_str(),
      &inputsTmp,
      &outputsTmp,
      nullptr);
  outputs.reset(outputsTmp);
  inputs.reset(inputsTmp);
  if (status < 0) {
    throw std::runtime_error(
        "Failed to parse filter description: " +
        getFFMPEGErrorStringFromErrorCode(status));
  }

  status =
      avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
  if (status < 0) {
    throw std::runtime_error(
        "Failed to configure filter graph: " +
        getFFMPEGErrorStringFromErrorCode(status));
  }
}
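For intuition, here is what the two strings assembled above could look like for a hypothetical 1920x1080 yuv420p stream (pixel format enum value 0) with a 1/30000 time base and square pixels, scaled to 640x360; all numbers are made up for illustration. Note that pix_fmt renders as an integer because decodedFormat is an enum inserted into the stringstream:

  filterArgs:  video_size=1920x1080:pix_fmt=0:time_base=1/30000:pixel_aspect=1/1
  description: scale=640:360:sws_flags=bilinear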
void CpuDeviceInterface::createSwsContext(
    const DecodedFrameContext& frameContext,
    const enum AVColorSpace colorspace) {
  SwsContext* swsContext = sws_getContext(
      frameContext.decodedWidth,
      frameContext.decodedHeight,
      frameContext.decodedFormat,
      frameContext.expectedWidth,
      frameContext.expectedHeight,
      AV_PIX_FMT_RGB24,
      SWS_BILINEAR,
      nullptr,
      nullptr,
      nullptr);
  TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");

  int* invTable = nullptr;
  int* table = nullptr;
  int srcRange, dstRange, brightness, contrast, saturation;
  int ret = sws_getColorspaceDetails(
      swsContext,
      &invTable,
      &srcRange,
      &table,
      &dstRange,
      &brightness,
      &contrast,
      &saturation);
  TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");

  const int* colorspaceTable = sws_getCoefficients(colorspace);
  ret = sws_setColorspaceDetails(
      swsContext,
      colorspaceTable,
      srcRange,
      colorspaceTable,
      dstRange,
      brightness,
      contrast,
      saturation);
  TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");

  swsContext_.reset(swsContext);
}

} // namespace facebook::torchcodec
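Note: sws_getCoefficients() maps the decoder-reported AVColorSpace (e.g. BT.601 vs. BT.709) to the matching YUV-to-RGB coefficient table, and sws_setColorspaceDetails() installs it while preserving the ranges, brightness, contrast, and saturation queried just before. This keeps the conversion matrix in sync with the frame's actual colorspace rather than whatever table swscale would pick by default.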
