#include "src/torchcodec/_core/CpuDeviceInterface.h"

+extern "C" {
+#include <libavfilter/buffersink.h>
+#include <libavfilter/buffersrc.h>
+}
+
namespace facebook::torchcodec {
namespace {
@@ -17,6 +22,20 @@ bool g_cpu = registerDeviceInterface(
} // namespace

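+// Two DecodedFrameContexts compare equal when the decoded frame's dimensions
+// and pixel format and the expected output dimensions all match; the cached
+// colorspace conversion objects only need to be re-created when one of these
+// changes.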
+bool CpuDeviceInterface::DecodedFrameContext::operator==(
+    const CpuDeviceInterface::DecodedFrameContext& other) {
+  return decodedWidth == other.decodedWidth &&
+      decodedHeight == other.decodedHeight &&
+      decodedFormat == other.decodedFormat &&
+      expectedWidth == other.expectedWidth &&
+      expectedHeight == other.expectedHeight;
+}
+
+bool CpuDeviceInterface::DecodedFrameContext::operator!=(
+    const CpuDeviceInterface::DecodedFrameContext& other) {
+  return !(*this == other);
+}
+
CpuDeviceInterface::CpuDeviceInterface(
    const torch::Device& device,
    const AVRational& timeBase)
@@ -26,4 +45,321 @@ CpuDeviceInterface::CpuDeviceInterface(
}
}

+// Note [preAllocatedOutputTensor with swscale and filtergraph]:
+// Callers may pass a pre-allocated tensor, where the output.data tensor will
+// be stored. This parameter is honored in any case, but it only leads to a
+// speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
+// decoded frame directly into `preAllocatedOutputTensor.data_ptr()`. We
+// haven't yet found a way to do that with filtergraph.
+// TODO: Figure out whether that's possible!
+// The dimension order of the preAllocatedOutputTensor must be HWC, regardless
+// of the `dimension_order` parameter. It's up to callers to re-shape it if
+// needed.
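+//
+// A hypothetical caller-side sketch (not part of this change), assuming the
+// caller already knows the expected output dimensions:
+//   torch::Tensor preAllocated =
+//       allocateEmptyHWCTensor(expectedHeight, expectedWidth, torch::kCPU);
+//   convertAVFrameToFrameOutput(options, avFrame, frameOutput, preAllocated);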
+void CpuDeviceInterface::convertAVFrameToFrameOutput(
+    const VideoStreamOptions& videoStreamOptions,
+    UniqueAVFrame& avFrame,
+    FrameOutput& frameOutput,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (preAllocatedOutputTensor.has_value()) {
+    auto shape = preAllocatedOutputTensor.value().sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected pre-allocated tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+  }
+
+  torch::Tensor outputTensor;
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for the frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+  auto frameContext = DecodedFrameContext{
+      avFrame->width,
+      avFrame->height,
+      frameFormat,
+      avFrame->sample_aspect_ratio,
+      expectedOutputWidth,
+      expectedOutputHeight};
+
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested from the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability
+  // to choose a color conversion library publicly; we only use this ability
+  // internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
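+  // For example, a 1920-pixel-wide output can use swscale (1920 % 32 == 0),
+  // while a 1918-pixel-wide output falls back to filtergraph.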
+  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
+      ? ColorConversionLibrary::SWSCALE
+      : ColorConversionLibrary::FILTERGRAPH;
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
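+    // With swscale we can write straight into the pre-allocated tensor if one
+    // was provided; otherwise we allocate a fresh HWC output tensor.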
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+
+    if (!swsContext_ || prevFrameContext_ != frameContext) {
+      createSwsContext(frameContext, avFrame->colorspace);
+      prevFrameContext_ = frameContext;
+    }
+    int resultHeight =
+        convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    frameOutput.data = outputTensor;
+  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
+      createFilterGraph(frameContext, videoStreamOptions);
+      prevFrameContext_ = frameContext;
+    }
+    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      frameOutput.data = preAllocatedOutputTensor.value();
+    } else {
+      frameOutput.data = outputTensor;
+    }
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(colorConversionLibrary)));
+  }
+}
+
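+// Scales and color-converts `avFrame` into `outputTensor` with swscale. The
+// output is packed RGB24, so only the first destination plane pointer is set
+// and its stride is width * 3 bytes; the remaining entries stay null/zero.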
+int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
+    const UniqueAVFrame& avFrame,
+    torch::Tensor& outputTensor) {
+  uint8_t* pointers[4] = {
+      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+  int expectedOutputWidth = outputTensor.sizes()[1];
+  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
+  int resultHeight = sws_scale(
+      swsContext_.get(),
+      avFrame->data,
+      avFrame->linesize,
+      0,
+      avFrame->height,
+      pointers,
+      linesizes);
+  return resultHeight;
+}
+
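+// Pushes `avFrame` through the filtergraph and wraps the filtered RGB24 frame
+// in a tensor without copying: torch::from_blob points at the AVFrame's data,
+// and the custom deleter keeps that AVFrame alive until the tensor is freed.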
+torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
+    const UniqueAVFrame& avFrame) {
+  int status = av_buffersrc_write_frame(
+      filterGraphContext_.sourceContext, avFrame.get());
+  if (status < AVSUCCESS) {
+    throw std::runtime_error("Failed to add frame to buffer source context");
+  }
+
+  UniqueAVFrame filteredAVFrame(av_frame_alloc());
+  status = av_buffersink_get_frame(
+      filterGraphContext_.sinkContext, filteredAVFrame.get());
+  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
+
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
+  int height = frameDims.height;
+  int width = frameDims.width;
+  std::vector<int64_t> shape = {height, width, 3};
+  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
+  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
+  auto deleter = [filteredAVFramePtr](void*) {
+    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
+  };
+  return torch::from_blob(
+      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
+}
+
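+// Builds the buffer -> scale -> buffersink graph used above. The buffer
+// source is configured with the decoded frame's size, pixel format, time base
+// and pixel aspect ratio; the sink is constrained to RGB24, and the scale
+// filter resizes to the expected output dimensions with bilinear
+// interpolation.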
+void CpuDeviceInterface::createFilterGraph(
+    const DecodedFrameContext& frameContext,
+    const VideoStreamOptions& videoStreamOptions) {
+  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
+  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
+
+  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
+    filterGraphContext_.filterGraph->nb_threads =
+        videoStreamOptions.ffmpegThreadCount.value();
+  }
+
+  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
+  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
+
+  std::stringstream filterArgs;
+  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
+             << frameContext.decodedHeight;
+  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
+  filterArgs << ":time_base=" << timeBase_.num << "/" << timeBase_.den;
+  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
+             << frameContext.decodedAspectRatio.den;
+
+  int status = avfilter_graph_create_filter(
+      &filterGraphContext_.sourceContext,
+      buffersrc,
+      "in",
+      filterArgs.str().c_str(),
+      nullptr,
+      filterGraphContext_.filterGraph.get());
+  if (status < 0) {
+    throw std::runtime_error(
+        std::string("Failed to create filter graph: ") + filterArgs.str() +
+        ": " + getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  status = avfilter_graph_create_filter(
+      &filterGraphContext_.sinkContext,
+      buffersink,
+      "out",
+      nullptr,
+      nullptr,
+      filterGraphContext_.filterGraph.get());
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to create filter graph: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
+
+  status = av_opt_set_int_list(
+      filterGraphContext_.sinkContext,
+      "pix_fmts",
+      pix_fmts,
+      AV_PIX_FMT_NONE,
+      AV_OPT_SEARCH_CHILDREN);
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to set output pixel formats: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
+  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
+
+  outputs->name = av_strdup("in");
+  outputs->filter_ctx = filterGraphContext_.sourceContext;
+  outputs->pad_idx = 0;
+  outputs->next = nullptr;
+  inputs->name = av_strdup("out");
+  inputs->filter_ctx = filterGraphContext_.sinkContext;
+  inputs->pad_idx = 0;
+  inputs->next = nullptr;
+
+  std::stringstream description;
+  description << "scale=" << frameContext.expectedWidth << ":"
+              << frameContext.expectedHeight;
+  description << ":sws_flags=bilinear";
+
+  AVFilterInOut* outputsTmp = outputs.release();
+  AVFilterInOut* inputsTmp = inputs.release();
+  status = avfilter_graph_parse_ptr(
+      filterGraphContext_.filterGraph.get(),
+      description.str().c_str(),
+      &inputsTmp,
+      &outputsTmp,
+      nullptr);
+  outputs.reset(outputsTmp);
+  inputs.reset(inputsTmp);
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to parse filter description: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  status =
+      avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to configure filter graph: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+}
+
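+// Creates the swscale context that converts the decoded frame to RGB24 at the
+// expected output size, then swaps in the coefficient table matching the
+// frame's colorspace (e.g. BT.601 vs. BT.709) while keeping the ranges,
+// brightness, contrast and saturation that swscale picked by default.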
+void CpuDeviceInterface::createSwsContext(
+    const DecodedFrameContext& frameContext,
+    const enum AVColorSpace colorspace) {
+  SwsContext* swsContext = sws_getContext(
+      frameContext.decodedWidth,
+      frameContext.decodedHeight,
+      frameContext.decodedFormat,
+      frameContext.expectedWidth,
+      frameContext.expectedHeight,
+      AV_PIX_FMT_RGB24,
+      SWS_BILINEAR,
+      nullptr,
+      nullptr,
+      nullptr);
+  TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
+
+  int* invTable = nullptr;
+  int* table = nullptr;
+  int srcRange, dstRange, brightness, contrast, saturation;
+  int ret = sws_getColorspaceDetails(
+      swsContext,
+      &invTable,
+      &srcRange,
+      &table,
+      &dstRange,
+      &brightness,
+      &contrast,
+      &saturation);
+  TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
+
+  const int* colorspaceTable = sws_getCoefficients(colorspace);
+  ret = sws_setColorspaceDetails(
+      swsContext,
+      colorspaceTable,
+      srcRange,
+      colorspaceTable,
+      dstRange,
+      brightness,
+      contrast,
+      saturation);
+  TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
+
+  swsContext_.reset(swsContext);
+}
+
} // namespace facebook::torchcodec