
Commit 288b99c

Move CpuDeviceInterface to .cpp files where they belong
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent 533ba4e commit 288b99c

File tree

3 files changed (+337, -343 lines)


src/torchcodec/_core/CpuDeviceInterface.cpp

+336
@@ -6,6 +6,11 @@
 
 #include "src/torchcodec/_core/CpuDeviceInterface.h"
 
+extern "C" {
+#include <libavfilter/buffersink.h>
+#include <libavfilter/buffersrc.h>
+}
+
 namespace facebook::torchcodec {
 namespace {
 
@@ -17,6 +22,20 @@ bool g_cpu = registerDeviceInterface(
 
 } // namespace
 
+bool CpuDeviceInterface::DecodedFrameContext::operator==(
+    const CpuDeviceInterface::DecodedFrameContext& other) {
+  return decodedWidth == other.decodedWidth &&
+      decodedHeight == other.decodedHeight &&
+      decodedFormat == other.decodedFormat &&
+      expectedWidth == other.expectedWidth &&
+      expectedHeight == other.expectedHeight;
+}
+
+bool CpuDeviceInterface::DecodedFrameContext::operator!=(
+    const CpuDeviceInterface::DecodedFrameContext& other) {
+  return !(*this == other);
+}
+
 CpuDeviceInterface::CpuDeviceInterface(
     const torch::Device& device,
     const AVRational& timeBase)
@@ -26,4 +45,321 @@ CpuDeviceInterface::CpuDeviceInterface(
   }
 }
 
+// Note [preAllocatedOutputTensor with swscale and filtergraph]:
+// Callers may pass a pre-allocated tensor, where the output.data tensor will
+// be stored. This parameter is honored in any case, but it only leads to a
+// speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
+// decoded frame directly into `preAllocatedTensor.data_ptr()`. We haven't yet
+// found a way to do that with filtergraph.
+// TODO: Figure out whether that's possible!
+// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
+// `dimension_order` parameter. It's up to callers to re-shape it if needed.
+void CpuDeviceInterface::convertAVFrameToFrameOutput(
+    const VideoStreamOptions& videoStreamOptions,
+    UniqueAVFrame& avFrame,
+    FrameOutput& frameOutput,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (preAllocatedOutputTensor.has_value()) {
+    auto shape = preAllocatedOutputTensor.value().sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected pre-allocated tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+  }
+
+  torch::Tensor outputTensor;
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+  auto frameContext = DecodedFrameContext{
+      avFrame->width,
+      avFrame->height,
+      frameFormat,
+      avFrame->sample_aspect_ratio,
+      expectedOutputWidth,
+      expectedOutputHeight};
+
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested from the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability
+  // to choose a color conversion library publicly; we only use this ability
+  // internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
+  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
+      ? ColorConversionLibrary::SWSCALE
+      : ColorConversionLibrary::FILTERGRAPH;
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+
+    if (!swsContext_ || prevFrameContext_ != frameContext) {
+      createSwsContext(frameContext, avFrame->colorspace);
+      prevFrameContext_ = frameContext;
+    }
+    int resultHeight =
+        convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    frameOutput.data = outputTensor;
+  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
+      createFilterGraph(frameContext, videoStreamOptions);
+      prevFrameContext_ = frameContext;
+    }
+    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      frameOutput.data = preAllocatedOutputTensor.value();
+    } else {
+      frameOutput.data = outputTensor;
+    }
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(colorConversionLibrary)));
+  }
+}
+
+int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
+    const UniqueAVFrame& avFrame,
+    torch::Tensor& outputTensor) {
+  uint8_t* pointers[4] = {
+      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+  int expectedOutputWidth = outputTensor.sizes()[1];
+  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
+  int resultHeight = sws_scale(
+      swsContext_.get(),
+      avFrame->data,
+      avFrame->linesize,
+      0,
+      avFrame->height,
+      pointers,
+      linesizes);
+  return resultHeight;
+}
+
+torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
+    const UniqueAVFrame& avFrame) {
+  int status = av_buffersrc_write_frame(
+      filterGraphContext_.sourceContext, avFrame.get());
+  if (status < AVSUCCESS) {
+    throw std::runtime_error("Failed to add frame to buffer source context");
+  }
+
+  UniqueAVFrame filteredAVFrame(av_frame_alloc());
+  status = av_buffersink_get_frame(
+      filterGraphContext_.sinkContext, filteredAVFrame.get());
+  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
+
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
+  int height = frameDims.height;
+  int width = frameDims.width;
+  std::vector<int64_t> shape = {height, width, 3};
+  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
+  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
+  auto deleter = [filteredAVFramePtr](void*) {
+    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
+  };
+  return torch::from_blob(
+      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
+}
+
+void CpuDeviceInterface::createFilterGraph(
+    const DecodedFrameContext& frameContext,
+    const VideoStreamOptions& videoStreamOptions) {
+  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
+  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
+
+  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
+    filterGraphContext_.filterGraph->nb_threads =
+        videoStreamOptions.ffmpegThreadCount.value();
+  }
+
+  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
+  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
+
+  std::stringstream filterArgs;
+  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
+             << frameContext.decodedHeight;
+  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
+  filterArgs << ":time_base=" << timeBase_.num << "/" << timeBase_.den;
+  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
+             << frameContext.decodedAspectRatio.den;
+
+  int status = avfilter_graph_create_filter(
+      &filterGraphContext_.sourceContext,
+      buffersrc,
+      "in",
+      filterArgs.str().c_str(),
+      nullptr,
+      filterGraphContext_.filterGraph.get());
+  if (status < 0) {
+    throw std::runtime_error(
+        std::string("Failed to create filter graph: ") + filterArgs.str() +
+        ": " + getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  status = avfilter_graph_create_filter(
+      &filterGraphContext_.sinkContext,
+      buffersink,
+      "out",
+      nullptr,
+      nullptr,
+      filterGraphContext_.filterGraph.get());
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to create filter graph: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
+
+  status = av_opt_set_int_list(
+      filterGraphContext_.sinkContext,
+      "pix_fmts",
+      pix_fmts,
+      AV_PIX_FMT_NONE,
+      AV_OPT_SEARCH_CHILDREN);
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to set output pixel formats: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
+  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
+
+  outputs->name = av_strdup("in");
+  outputs->filter_ctx = filterGraphContext_.sourceContext;
+  outputs->pad_idx = 0;
+  outputs->next = nullptr;
+  inputs->name = av_strdup("out");
+  inputs->filter_ctx = filterGraphContext_.sinkContext;
+  inputs->pad_idx = 0;
+  inputs->next = nullptr;
+
+  std::stringstream description;
+  description << "scale=" << frameContext.expectedWidth << ":"
+              << frameContext.expectedHeight;
+  description << ":sws_flags=bilinear";
+
+  AVFilterInOut* outputsTmp = outputs.release();
+  AVFilterInOut* inputsTmp = inputs.release();
+  status = avfilter_graph_parse_ptr(
+      filterGraphContext_.filterGraph.get(),
+      description.str().c_str(),
+      &inputsTmp,
+      &outputsTmp,
+      nullptr);
+  outputs.reset(outputsTmp);
+  inputs.reset(inputsTmp);
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to parse filter description: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  status =
+      avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to configure filter graph: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+}
+
+void CpuDeviceInterface::createSwsContext(
+    const DecodedFrameContext& frameContext,
+    const enum AVColorSpace colorspace) {
+  SwsContext* swsContext = sws_getContext(
+      frameContext.decodedWidth,
+      frameContext.decodedHeight,
+      frameContext.decodedFormat,
+      frameContext.expectedWidth,
+      frameContext.expectedHeight,
+      AV_PIX_FMT_RGB24,
+      SWS_BILINEAR,
+      nullptr,
+      nullptr,
+      nullptr);
+  TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
+
+  int* invTable = nullptr;
+  int* table = nullptr;
+  int srcRange, dstRange, brightness, contrast, saturation;
+  int ret = sws_getColorspaceDetails(
+      swsContext,
+      &invTable,
+      &srcRange,
+      &table,
+      &dstRange,
+      &brightness,
+      &contrast,
+      &saturation);
+  TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
+
+  const int* colorspaceTable = sws_getCoefficients(colorspace);
+  ret = sws_setColorspaceDetails(
+      swsContext,
+      colorspaceTable,
+      srcRange,
+      colorspaceTable,
+      dstRange,
+      brightness,
+      contrast,
+      saturation);
+  TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
+
+  swsContext_.reset(swsContext);
+}
+
 } // namespace facebook::torchcodec
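
A minimal caller-side sketch of the contract spelled out in Note [preAllocatedOutputTensor with swscale and filtergraph]: pre-allocate an HWC uint8 tensor of the expected size, pass it in, and re-shape afterwards. The decodeToCHW helper and its parameters below are hypothetical illustrations, not part of this commit.

// Hypothetical helper (not part of this commit); assumes the torchcodec
// headers above are included and a decoded avFrame is already in hand.
torch::Tensor decodeToCHW(
    CpuDeviceInterface& cpuInterface,
    const VideoStreamOptions& options,
    UniqueAVFrame& avFrame,
    int expectedHeight,
    int expectedWidth) {
  // The tensor must be HWC uint8 and match the expected output size,
  // or the TORCH_CHECK on its shape above fires.
  torch::Tensor preAllocated =
      allocateEmptyHWCTensor(expectedHeight, expectedWidth, torch::kCPU);

  // Leaving options.colorConversionLibrary unset picks the default:
  // swscale when the width is a multiple of 32, filtergraph otherwise.
  FrameOutput frameOutput;
  cpuInterface.convertAVFrameToFrameOutput(
      options, avFrame, frameOutput, preAllocated);

  // Output is always HWC; re-shaping (e.g. to CHW) is up to the caller.
  return frameOutput.data.permute({2, 0, 1});
}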
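
The torch::from_blob call in convertAVFrameToTensorUsingFilterGraph ties the filtered AVFrame's lifetime to the returned tensor: the frame is released to a raw pointer, and the capturing deleter re-wraps it in a UniqueAVFrame, so the pixel data is freed only when the tensor's storage dies. Here is the same pattern in isolation, a sketch with a plain heap buffer standing in for the AVFrame:

#include <torch/torch.h>

#include <cstdint>
#include <vector>

// Sketch of the from_blob + capturing-deleter pattern used above: the
// tensor borrows external memory, and the deleter frees that memory only
// when the tensor's storage is destroyed.
torch::Tensor wrapBuffer(int height, int width) {
  auto* buffer = new std::vector<uint8_t>(height * width * 3);
  auto deleter = [buffer](void*) { delete buffer; };
  return torch::from_blob(
      buffer->data(),
      {height, width, 3}, // HWC shape, as in the filtergraph path
      {width * 3, 3, 1}, // the row stride plays the role of linesize[0] above
      deleter,
      {torch::kUInt8});
}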

src/torchcodec/_core/CpuDeviceInterface.h

+1
@@ -7,6 +7,7 @@
 #pragma once
 
 #include "src/torchcodec/_core/DeviceInterface.h"
+#include "src/torchcodec/_core/FFMPEGCommon.h"
 
 namespace facebook::torchcodec {
 
