// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "src/torchcodec/_core/CpuDeviceInterface.h"

extern "C" {
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
}

namespace facebook::torchcodec {
namespace {

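// Self-registration: at static-initialization time we register a factory for
// torch::kCPU with the device-interface registry. The returned flag is only
// used to assert, in the constructor below, that registration actually ran.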
static bool g_cpu = registerDeviceInterface(
    torch::kCPU,
    [](const torch::Device& device) { return new CpuDeviceInterface(device); });

} // namespace

bool CpuDeviceInterface::DecodedFrameContext::operator==(
    const CpuDeviceInterface::DecodedFrameContext& other) {
  return decodedWidth == other.decodedWidth &&
      decodedHeight == other.decodedHeight &&
      decodedFormat == other.decodedFormat &&
      expectedWidth == other.expectedWidth &&
      expectedHeight == other.expectedHeight;
}

bool CpuDeviceInterface::DecodedFrameContext::operator!=(
    const CpuDeviceInterface::DecodedFrameContext& other) {
  return !(*this == other);
}

CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
    : DeviceInterface(device) {
  TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
  if (device_.type() != torch::kCPU) {
    throw std::runtime_error("Unsupported device: " + device_.str());
  }
}

// Note [preAllocatedOutputTensor with swscale and filtergraph]:
// Callers may pass a pre-allocated tensor in which the output frame data will
// be stored. This parameter is always honored, but it only leads to a
// speed-up when swscale is used. With swscale, we can tell FFmpeg to place
// the decoded frame directly into `preAllocatedOutputTensor`'s data pointer.
// We haven't yet found a way to do that with filtergraph.
// TODO: Figure out whether that's possible!
// The dimension order of preAllocatedOutputTensor must be HWC, regardless of
// the `dimension_order` parameter. It's up to callers to re-shape it if
// needed.
void CpuDeviceInterface::convertAVFrameToFrameOutput(
    const VideoStreamOptions& videoStreamOptions,
    const AVRational& timeBase,
    UniqueAVFrame& avFrame,
    FrameOutput& frameOutput,
    std::optional<torch::Tensor> preAllocatedOutputTensor) {
  auto frameDims =
      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
  int expectedOutputHeight = frameDims.height;
  int expectedOutputWidth = frameDims.width;

  if (preAllocatedOutputTensor.has_value()) {
    auto shape = preAllocatedOutputTensor.value().sizes();
    TORCH_CHECK(
        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
        "Expected pre-allocated tensor of shape ",
        expectedOutputHeight,
        "x",
        expectedOutputWidth,
        "x3, got ",
        shape);
  }

  torch::Tensor outputTensor;
  // We need to compare the current frame context with our previous frame
  // context. If they are different, then we need to re-create our colorspace
  // conversion objects. We create our colorspace conversion objects late so
  // that we don't have to depend on the unreliable metadata in the header.
  // And we sometimes re-create them because it's possible for frame
  // resolution to change mid-stream. Finally, we want to reuse the colorspace
  // conversion objects as much as possible for performance reasons.
  enum AVPixelFormat frameFormat =
      static_cast<enum AVPixelFormat>(avFrame->format);
  auto frameContext = DecodedFrameContext{
      avFrame->width,
      avFrame->height,
      frameFormat,
      avFrame->sample_aspect_ratio,
      expectedOutputWidth,
      expectedOutputHeight};

  // By default, we want to use swscale for color conversion because it is
  // faster. However, it has width requirements, so we may need to fall back
  // to filtergraph. We also need to respect what was requested in the
  // options; the options are honored unconditionally, so it's possible for
  // swscale's width requirement to be violated. We don't expose the ability
  // to choose the color conversion library publicly; we only use this
  // ability internally.

  // swscale requires widths to be multiples of 32:
  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
  // so we fall back to filtergraph if the width is not a multiple of 32.
  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
      ? ColorConversionLibrary::SWSCALE
      : ColorConversionLibrary::FILTERGRAPH;

  ColorConversionLibrary colorConversionLibrary =
      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);

  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
        expectedOutputHeight, expectedOutputWidth, torch::kCPU));

    if (!swsContext_ || prevFrameContext_ != frameContext) {
      createSwsContext(frameContext, avFrame->colorspace);
      prevFrameContext_ = frameContext;
    }
    int resultHeight =
        convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
    // If this check failed, it would mean that the frame wasn't reshaped to
    // the expected height.
    // TODO: Can we do the same check for width?
    TORCH_CHECK(
        resultHeight == expectedOutputHeight,
        "resultHeight != expectedOutputHeight: ",
        resultHeight,
        " != ",
        expectedOutputHeight);

    frameOutput.data = outputTensor;
  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
      createFilterGraph(frameContext, videoStreamOptions, timeBase);
      prevFrameContext_ = frameContext;
    }
    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);

    // Similarly to above, if this check fails it means the frame wasn't
    // reshaped to its expected dimensions by filtergraph.
    auto shape = outputTensor.sizes();
    TORCH_CHECK(
        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
        "Expected output tensor of shape ",
        expectedOutputHeight,
        "x",
        expectedOutputWidth,
        "x3, got ",
        shape);

    if (preAllocatedOutputTensor.has_value()) {
      // We have already validated that preAllocatedOutputTensor and
      // outputTensor have the same shape.
      preAllocatedOutputTensor.value().copy_(outputTensor);
      frameOutput.data = preAllocatedOutputTensor.value();
    } else {
      frameOutput.data = outputTensor;
    }
  } else {
    throw std::runtime_error(
        "Invalid color conversion library: " +
        std::to_string(static_cast<int>(colorConversionLibrary)));
  }
}

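// Converts avFrame into outputTensor in place: sws_scale() writes the
// converted RGB24 pixels directly into the tensor's memory. The destination
// is described as a single packed plane with a row stride of width * 3 bytes.
// Returns the number of output rows that swscale produced.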
int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
    const UniqueAVFrame& avFrame,
    torch::Tensor& outputTensor) {
  uint8_t* pointers[4] = {
      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
  int expectedOutputWidth = outputTensor.sizes()[1];
  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
  int resultHeight = sws_scale(
      swsContext_.get(),
      avFrame->data,
      avFrame->linesize,
      0,
      avFrame->height,
      pointers,
      linesizes);
  return resultHeight;
}

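// Runs the decoded frame through the filter graph: the frame is pushed into
// the buffer source, the filtered RGB24 frame is pulled from the buffer sink,
// and the result is wrapped in a tensor without copying the pixel data.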
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
    const UniqueAVFrame& avFrame) {
  int status = av_buffersrc_write_frame(
      filterGraphContext_.sourceContext, avFrame.get());
  if (status < AVSUCCESS) {
    throw std::runtime_error("Failed to add frame to buffer source context");
  }

  UniqueAVFrame filteredAVFrame(av_frame_alloc());
  status = av_buffersink_get_frame(
      filterGraphContext_.sinkContext, filteredAVFrame.get());
  if (status < AVSUCCESS) {
    throw std::runtime_error("Failed to get frame from buffer sink context");
  }
  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
  int height = frameDims.height;
  int width = frameDims.width;
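  // Wrap the filtered frame's pixel data in a tensor without copying. The row
  // stride comes from linesize[0] because FFmpeg may pad rows beyond
  // width * 3 bytes. Ownership of the AVFrame is transferred to the tensor's
  // deleter, which frees the frame once the tensor is destroyed.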
  std::vector<int64_t> shape = {height, width, 3};
  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
  auto deleter = [filteredAVFramePtr](void*) {
    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
  };
  return torch::from_blob(
      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
}

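// Builds the filter graph used for filtergraph-based color conversion:
// a "buffer" source fed with the decoded frames, a "scale" filter that
// resizes to the expected output dimensions, and a "buffersink" constrained
// to RGB24 so the graph performs the pixel-format conversion.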
void CpuDeviceInterface::createFilterGraph(
    const DecodedFrameContext& frameContext,
    const VideoStreamOptions& videoStreamOptions,
    const AVRational& timeBase) {
  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);

  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
    filterGraphContext_.filterGraph->nb_threads =
        videoStreamOptions.ffmpegThreadCount.value();
  }

  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
  const AVFilter* buffersink = avfilter_get_by_name("buffersink");

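  // The buffer source must be told everything about the incoming frames:
  // their dimensions, pixel format, time base, and sample (pixel) aspect
  // ratio. These come from the decoded frame context and the stream.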
  std::stringstream filterArgs;
  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
             << frameContext.decodedHeight;
  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
  filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
             << frameContext.decodedAspectRatio.den;

  int status = avfilter_graph_create_filter(
      &filterGraphContext_.sourceContext,
      buffersrc,
      "in",
      filterArgs.str().c_str(),
      nullptr,
      filterGraphContext_.filterGraph.get());
  if (status < 0) {
    throw std::runtime_error(
        std::string("Failed to create filter graph: ") + filterArgs.str() +
        ": " + getFFMPEGErrorStringFromErrorCode(status));
  }

  status = avfilter_graph_create_filter(
      &filterGraphContext_.sinkContext,
      buffersink,
      "out",
      nullptr,
      nullptr,
      filterGraphContext_.filterGraph.get());
  if (status < 0) {
    throw std::runtime_error(
        "Failed to create filter graph: " +
        getFFMPEGErrorStringFromErrorCode(status));
  }

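  // Constrain the sink to packed RGB24. The graph will automatically insert
  // whatever pixel-format conversion is needed to satisfy this constraint.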
  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};

  status = av_opt_set_int_list(
      filterGraphContext_.sinkContext,
      "pix_fmts",
      pix_fmts,
      AV_PIX_FMT_NONE,
      AV_OPT_SEARCH_CHILDREN);
  if (status < 0) {
    throw std::runtime_error(
        "Failed to set output pixel formats: " +
        getFFMPEGErrorStringFromErrorCode(status));
  }

  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
  UniqueAVFilterInOut inputs(avfilter_inout_alloc());

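  // Per FFmpeg's filter-graph parsing convention, `outputs` describes the
  // free output pad of the already-created source filter (labelled "in"),
  // and `inputs` describes the free input pad of the sink (labelled "out").
  // The parsed filter description below is spliced between these two pads.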
  outputs->name = av_strdup("in");
  outputs->filter_ctx = filterGraphContext_.sourceContext;
  outputs->pad_idx = 0;
  outputs->next = nullptr;
  inputs->name = av_strdup("out");
  inputs->filter_ctx = filterGraphContext_.sinkContext;
  inputs->pad_idx = 0;
  inputs->next = nullptr;

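  // The filter description inserted between source and sink: scale to the
  // expected output size using bilinear interpolation.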
  std::stringstream description;
  description << "scale=" << frameContext.expectedWidth << ":"
              << frameContext.expectedHeight;
  description << ":sws_flags=bilinear";

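  // avfilter_graph_parse_ptr() takes raw in/out pointers and may update them
  // to point at whatever pads remain unlinked, so we temporarily release
  // ownership from the unique_ptrs and re-acquire the (possibly updated)
  // pointers afterwards.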
  AVFilterInOut* outputsTmp = outputs.release();
  AVFilterInOut* inputsTmp = inputs.release();
  status = avfilter_graph_parse_ptr(
      filterGraphContext_.filterGraph.get(),
      description.str().c_str(),
      &inputsTmp,
      &outputsTmp,
      nullptr);
  outputs.reset(outputsTmp);
  inputs.reset(inputsTmp);
  if (status < 0) {
    throw std::runtime_error(
        "Failed to parse filter description: " +
        getFFMPEGErrorStringFromErrorCode(status));
  }

  status =
      avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
  if (status < 0) {
    throw std::runtime_error(
        "Failed to configure filter graph: " +
        getFFMPEGErrorStringFromErrorCode(status));
  }
}

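// Creates the swscale context used for swscale-based color conversion: it
// converts from the decoded frame's format and size to packed RGB24 at the
// expected output size, using bilinear scaling.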
void CpuDeviceInterface::createSwsContext(
    const DecodedFrameContext& frameContext,
    const enum AVColorSpace colorspace) {
  SwsContext* swsContext = sws_getContext(
      frameContext.decodedWidth,
      frameContext.decodedHeight,
      frameContext.decodedFormat,
      frameContext.expectedWidth,
      frameContext.expectedHeight,
      AV_PIX_FMT_RGB24,
      SWS_BILINEAR,
      nullptr,
      nullptr,
      nullptr);
  TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");

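  // sws_getContext() picks default YUV-to-RGB coefficients. Query the current
  // colorspace details so we can re-apply the same ranges, brightness,
  // contrast, and saturation, but with the coefficient table that matches the
  // frame's actual colorspace (e.g. BT.601 vs. BT.709).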
  int* invTable = nullptr;
  int* table = nullptr;
  int srcRange, dstRange, brightness, contrast, saturation;
  int ret = sws_getColorspaceDetails(
      swsContext,
      &invTable,
      &srcRange,
      &table,
      &dstRange,
      &brightness,
      &contrast,
      &saturation);
  TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");

  const int* colorspaceTable = sws_getCoefficients(colorspace);
  ret = sws_setColorspaceDetails(
      swsContext,
      colorspaceTable,
      srcRange,
      colorspaceTable,
      dstRange,
      brightness,
      contrast,
      saturation);
  TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");

  swsContext_.reset(swsContext);
}

} // namespace facebook::torchcodec