
Commit 373d1c5

Towards a clearer logic for determining output height and width (#332)
1 parent ab79e67 commit 373d1c5

File tree

5 files changed, +135 -40 lines changed


src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

+1 -1

@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamDecoderOptions& options,
-    AVCodecContext* codecContext,
+    const VideoDecoder::StreamMetadata& metadata,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {

src/torchcodec/decoders/_core/CudaDevice.cpp

+5 -16

@@ -154,18 +154,6 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
 #endif
 }
 
-torch::Tensor allocateDeviceTensor(
-    at::IntArrayRef shape,
-    torch::Device device,
-    const torch::Dtype dtype = torch::kUInt8) {
-  return torch::empty(
-      shape,
-      torch::TensorOptions()
-          .dtype(dtype)
-          .layout(torch::kStrided)
-          .device(device));
-}
-
 void throwErrorIfNonCudaDevice(const torch::Device& device) {
   TORCH_CHECK(
       device.type() != torch::kCPU,

@@ -199,7 +187,7 @@ void initializeContextOnCuda(
 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamDecoderOptions& options,
-    AVCodecContext* codecContext,
+    const VideoDecoder::StreamMetadata& metadata,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {

@@ -209,8 +197,9 @@ void convertAVFrameToDecodedOutputOnCuda(
       src->format == AV_PIX_FMT_CUDA,
       "Expected format to be AV_PIX_FMT_CUDA, got " +
           std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
-  int width = options.width.value_or(codecContext->width);
-  int height = options.height.value_or(codecContext->height);
+  auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, metadata);
+  int height = frameDims.height;
+  int width = frameDims.width;
   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {src->data[0], src->data[1]};
   torch::Tensor& dst = output.frame;

@@ -227,7 +216,7 @@ void convertAVFrameToDecodedOutputOnCuda(
         "x3, got ",
         shape);
   } else {
-    dst = allocateDeviceTensor({height, width, 3}, options.device);
+    dst = allocateEmptyHWCTensor(height, width, options.device);
   }
 
   // Use the user-requested GPU for running the NPP kernel.
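The GPU path above now resolves output dimensions with the same options-then-metadata fallback used elsewhere in the commit. Here is a minimal, self-contained sketch of that fallback; Options and Metadata are hypothetical stand-ins for the real VideoStreamDecoderOptions and StreamMetadata types, reduced to the fields the logic actually reads:

#include <cassert>
#include <optional>

// Hypothetical stand-ins, not the decoder's real types.
struct Options {
  std::optional<int> height;
  std::optional<int> width;
};
struct Metadata {
  std::optional<int> height;
  std::optional<int> width;
};

struct FrameDims {
  int height;
  int width;
  FrameDims(int h, int w) : height(h), width(w) {}
};

// Same shape as the commit's helper: user-provided options win,
// otherwise fall back to the stream metadata.
FrameDims getHeightAndWidthFromOptionsOrMetadata(
    const Options& options, const Metadata& metadata) {
  return FrameDims(
      options.height.value_or(*metadata.height),
      options.width.value_or(*metadata.width));
}

int main() {
  Metadata metadata;
  metadata.height = 1080;
  metadata.width = 1920;

  // No user override: dimensions come from the metadata.
  assert(getHeightAndWidthFromOptionsOrMetadata({}, metadata).height == 1080);

  // User-requested dimensions take precedence.
  Options options;
  options.height = 270;
  options.width = 480;
  assert(getHeightAndWidthFromOptionsOrMetadata(options, metadata).width == 480);
  return 0;
}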

src/torchcodec/decoders/_core/DeviceInterface.h

+1 -1

@@ -35,7 +35,7 @@ void initializeContextOnCuda(
 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamDecoderOptions& options,
-    AVCodecContext* codecContext,
+    const VideoDecoder::StreamMetadata& metadata,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

src/torchcodec/decoders/_core/VideoDecoder.cpp

+64 -22

@@ -195,14 +195,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
     int64_t numFrames,
     const VideoStreamDecoderOptions& options,
     const StreamMetadata& metadata)
-    : frames(torch::empty(
-          {numFrames,
-           options.height.value_or(*metadata.height),
-           options.width.value_or(*metadata.width),
-           3},
-          at::TensorOptions(options.device).dtype(torch::kUInt8))),
-      ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
+    : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
+      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
+  auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, metadata);
+  int height = frameDims.height;
+  int width = frameDims.width;
+  frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
+}
 
 VideoDecoder::VideoDecoder() {}
 

@@ -364,12 +363,11 @@ void VideoDecoder::initializeFilterGraphForStream(
   inputs->pad_idx = 0;
   inputs->next = nullptr;
   char description[512];
-  int width = activeStream.codecContext->width;
-  int height = activeStream.codecContext->height;
-  if (options.height.has_value() && options.width.has_value()) {
-    width = *options.width;
-    height = *options.height;
-  }
+  auto frameDims = getHeightAndWidthFromOptionsOrMetadata(
+      options, containerMetadata_.streams[streamIndex]);
+  int height = frameDims.height;
+  int width = frameDims.width;
+
   std::snprintf(
       description,
       sizeof(description),

@@ -869,7 +867,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
     convertAVFrameToDecodedOutputOnCuda(
         streamInfo.options.device,
         streamInfo.options,
-        streamInfo.codecContext.get(),
+        containerMetadata_.streams[streamIndex],
        rawOutput,
        output,
        preAllocatedOutputTensor);

@@ -899,8 +897,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   torch::Tensor tensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      int width = streamInfo.options.width.value_or(frame->width);
-      int height = streamInfo.options.height.value_or(frame->height);
+      auto frameDims =
+          getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
+      int height = frameDims.height;
+      int width = frameDims.width;
       if (preAllocatedOutputTensor.has_value()) {
         tensor = preAllocatedOutputTensor.value();
         auto shape = tensor.sizes();

@@ -914,8 +914,7 @@
             "x3, got ",
             shape);
       } else {
-        tensor = torch::empty(
-            {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
+        tensor = allocateEmptyHWCTensor(height, width, torch::kCPU);
       }
       rawOutput.data = tensor.data_ptr<uint8_t>();
       convertFrameToBufferUsingSwsScale(rawOutput);

@@ -1315,8 +1314,10 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(frame->format);
   StreamInfo& activeStream = streams_[streamIndex];
-  int outputWidth = activeStream.options.width.value_or(frame->width);
-  int outputHeight = activeStream.options.height.value_or(frame->height);
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, *frame);
+  int outputHeight = frameDims.height;
+  int outputWidth = frameDims.width;
   if (activeStream.swsContext.get() == nullptr) {
     SwsContext* swsContext = sws_getContext(
         frame->width,

@@ -1382,7 +1383,11 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
-  std::vector<int64_t> shape = {filteredFrame->height, filteredFrame->width, 3};
+  auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
+      streams_[streamIndex].options, *filteredFrame.get());
+  int height = frameDims.height;
+  int width = frameDims.width;
+  std::vector<int64_t> shape = {height, width, 3};
   std::vector<int64_t> strides = {filteredFrame->linesize[0], 3, 1};
   AVFrame* filteredFramePtr = filteredFrame.release();
   auto deleter = [filteredFramePtr](void*) {

@@ -1406,6 +1411,43 @@ VideoDecoder::~VideoDecoder() {
   }
 }
 
+FrameDims getHeightAndWidthFromOptionsOrMetadata(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    const VideoDecoder::StreamMetadata& metadata) {
+  return FrameDims(
+      options.height.value_or(*metadata.height),
+      options.width.value_or(*metadata.width));
+}
+
+FrameDims getHeightAndWidthFromOptionsOrAVFrame(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    const AVFrame& avFrame) {
+  return FrameDims(
+      options.height.value_or(avFrame.height),
+      options.width.value_or(avFrame.width));
+}
+
+torch::Tensor allocateEmptyHWCTensor(
+    int height,
+    int width,
+    torch::Device device,
+    std::optional<int> numFrames) {
+  auto tensorOptions = torch::TensorOptions()
+                           .dtype(torch::kUInt8)
+                           .layout(torch::kStrided)
+                           .device(device);
+  TORCH_CHECK(height > 0, "height must be > 0, got: ", height);
+  TORCH_CHECK(width > 0, "width must be > 0, got: ", width);
+  if (numFrames.has_value()) {
+    auto numFramesValue = numFrames.value();
+    TORCH_CHECK(
+        numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue);
+    return torch::empty({numFramesValue, height, width, 3}, tensorOptions);
+  } else {
+    return torch::empty({height, width, 3}, tensorOptions);
+  }
+}
+
 std::ostream& operator<<(
     std::ostream& os,
     const VideoDecoder::DecodeStats& stats) {
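As a quick illustration of the new allocation helper defined above, here is a minimal sketch, assuming a standalone libtorch setup; the helper body is taken from this commit (minus the numFrames validation) and the shape checks in main() are just illustrative:

#include <optional>
#include <torch/torch.h>

// Empty uint8 HWC (or NHWC, when numFrames is given) tensor on the
// requested device, as defined by this commit.
torch::Tensor allocateEmptyHWCTensor(
    int height,
    int width,
    torch::Device device,
    std::optional<int> numFrames = std::nullopt) {
  auto tensorOptions = torch::TensorOptions()
                           .dtype(torch::kUInt8)
                           .layout(torch::kStrided)
                           .device(device);
  TORCH_CHECK(height > 0, "height must be > 0, got: ", height);
  TORCH_CHECK(width > 0, "width must be > 0, got: ", width);
  if (numFrames.has_value()) {
    return torch::empty({numFrames.value(), height, width, 3}, tensorOptions);
  }
  return torch::empty({height, width, 3}, tensorOptions);
}

int main() {
  // Single frame: HWC layout.
  auto frame = allocateEmptyHWCTensor(720, 1280, torch::kCPU);
  // frame.sizes() == [720, 1280, 3]
  TORCH_CHECK(frame.dim() == 3 && frame.size(2) == 3);

  // Batch: NHWC layout, with numFrames as the leading dimension.
  auto batch = allocateEmptyHWCTensor(720, 1280, torch::kCPU, 8);
  // batch.sizes() == [8, 720, 1280, 3]
  TORCH_CHECK(batch.dim() == 4 && batch.size(0) == 8);
  return 0;
}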

src/torchcodec/decoders/_core/VideoDecoder.h

+64 -0

@@ -243,6 +243,7 @@ class VideoDecoder {
       const VideoStreamDecoderOptions& options,
       const StreamMetadata& metadata);
   };
+
   // Returns frames at the given indices for a given stream as a single stacked
   // Tensor.
   BatchDecodedOutput getFramesAtIndices(

@@ -413,6 +414,69 @@ class VideoDecoder {
   bool scanned_all_streams_ = false;
 };
 
+// --------------------------------------------------------------------------
+// FRAME TENSOR ALLOCATION APIs
+// --------------------------------------------------------------------------
+
+// Note [Frame Tensor allocation and height and width]
+//
+// We always allocate [N]HWC tensors. The low-level decoding functions all
+// assume HWC tensors, since this is what FFmpeg natively handles. It's up to
+// the high-level decoding entry-points to permute that back to CHW, by calling
+// MaybePermuteHWC2CHW().
+//
+// Also, importantly, the way we figure out the height and width of the
+// output frame varies and depends on the decoding entry-point:
+// - In all cases, if the user requested specific height and width from the
+//   options, we honor that. Otherwise we fall into one of the categories below.
+// - In batch decoding APIs (e.g. getFramesAtIndices), we get height and width
+//   from the stream metadata, which itself got its value from the CodecContext
+//   when the stream was added.
+// - In single-frame APIs:
+//   - On CPU, we get height and width from the AVFrame.
+//   - On GPU, we get height and width from the metadata (same as batch APIs).
+//
+// These 2 strategies are encapsulated within
+// getHeightAndWidthFromOptionsOrMetadata() and
+// getHeightAndWidthFromOptionsOrAVFrame(). The reason they exist is to make it
+// very obvious which logic is used in which place, and they allow for `git
+// grep`ing.
+//
+// The source of truth for height and width really is the AVFrame: it's the
+// decoded output from FFmpeg. The info from the metadata (i.e. from the
+// CodecContext) may not be as accurate. However, the AVFrame is only available
+// late in the call stack, when the frame is decoded, while the CodecContext is
+// available early, when a stream is added. This is why we use the CodecContext
+// for pre-allocating batched output tensors (we could pre-allocate those only
+// once we decode the first frame to get the info from the AVFrame, but that's
+// a more complex logic).
+//
+// Because the sources for height and width may disagree, we may end up with
+// conflicts: e.g. we pre-allocate a batch output tensor based on the
+// metadata info, but the decoded AVFrame has a different height and width.
+// It is very important to check the height and width assumptions where the
+// tensors' memory is used/filled, in order to avoid segfaults.
+
+struct FrameDims {
+  int height;
+  int width;
+  FrameDims(int h, int w) : height(h), width(w) {}
+};
+
+FrameDims getHeightAndWidthFromOptionsOrMetadata(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    const VideoDecoder::StreamMetadata& metadata);
+
+FrameDims getHeightAndWidthFromOptionsOrAVFrame(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    const AVFrame& avFrame);
+
+torch::Tensor allocateEmptyHWCTensor(
+    int height,
+    int width,
+    torch::Device device,
+    std::optional<int> numFrames = std::nullopt);
+
 // Prints the VideoDecoder::DecodeStats to the ostream.
 std::ostream& operator<<(
     std::ostream& os,
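Since the metadata-derived dimensions used for pre-allocation can disagree with the decoded AVFrame, the note above calls for explicit shape checks before filling pre-allocated memory. Below is a minimal sketch of such a guard; the helper name is hypothetical, modeled on the TORCH_CHECKs this commit applies to pre-allocated tensors:

#include <torch/torch.h>

// Hypothetical guard (not part of this commit): verify a pre-allocated HWC
// tensor matches the resolved height and width before its memory is filled,
// to avoid out-of-bounds writes and segfaults.
void checkPreAllocatedTensorShape(
    const torch::Tensor& tensor, int height, int width) {
  auto shape = tensor.sizes();
  TORCH_CHECK(
      (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
          (shape[2] == 3),
      "Expected pre-allocated tensor of shape ", height, "x", width,
      "x3, got ", shape);
}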
