Skip to content

Commit daf3631

Browse files
authored
Refac: Straightforward output shape permutation (#317)
1 parent ddf3e54 commit daf3631

File tree

7 files changed

+86
-69
lines changed

7 files changed

+86
-69
lines changed

benchmarks/decoders/BenchmarkDecodersMain.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void runNDecodeIterations(
6363
decoder->addVideoStreamDecoder(-1);
6464
for (double pts : ptsList) {
6565
decoder->setCursorPtsInSeconds(pts);
66-
torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
66+
torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
6767
}
6868
if (i + 1 == warmupIterations) {
6969
start = std::chrono::high_resolution_clock::now();
@@ -95,7 +95,7 @@ void runNdecodeIterationsGrabbingConsecutiveFrames(
9595
VideoDecoder::createFromFilePath(videoPath);
9696
decoder->addVideoStreamDecoder(-1);
9797
for (int j = 0; j < consecutiveFrameCount; ++j) {
98-
torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
98+
torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
9999
}
100100
if (i + 1 == warmupIterations) {
101101
start = std::chrono::high_resolution_clock::now();

packaging/check_glibcxx.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@
4646
all_symbols.add(match.group(0))
4747

4848
if not all_symbols:
49-
raise ValueError(f"No GLIBCXX symbols found in {symbol_matches}. Something is wrong.")
49+
raise ValueError(
50+
f"No GLIBCXX symbols found in {symbol_matches}. Something is wrong."
51+
)
5052

5153
all_versions = (symbol.split("_")[1].split(".") for symbol in all_symbols)
5254
all_versions = (tuple(int(v) for v in version) for version in all_versions)

src/torchcodec/decoders/_core/CudaDevice.cpp

-5
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,6 @@ void convertAVFrameToDecodedOutputOnCuda(
240240
std::chrono::duration<double, std::micro> duration = end - start;
241241
VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
242242
<< " took: " << duration.count() << "us" << std::endl;
243-
if (options.dimensionOrder == "NCHW") {
244-
// The docs guaranty this to return a view:
245-
// https://pytorch.org/docs/stable/generated/torch.permute.html
246-
dst = dst.permute({2, 0, 1});
247-
}
248243
}
249244

250245
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/VideoDecoder.cpp

+56-44
Original file line numberDiff line numberDiff line change
@@ -34,31 +34,6 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
3434
return ptsToSeconds(pts, timeBase.den);
3535
}
3636

37-
// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
38-
// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
39-
// or 4D.
40-
// Calling permute() is guaranteed to return a view as per the docs:
41-
// https://pytorch.org/docs/stable/generated/torch.permute.html
42-
torch::Tensor MaybePermuteHWC2CHW(
43-
const VideoDecoder::VideoStreamDecoderOptions& options,
44-
torch::Tensor& hwcTensor) {
45-
if (options.dimensionOrder == "NHWC") {
46-
return hwcTensor;
47-
}
48-
auto numDimensions = hwcTensor.dim();
49-
auto shape = hwcTensor.sizes();
50-
if (numDimensions == 3) {
51-
TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
52-
return hwcTensor.permute({2, 0, 1});
53-
} else if (numDimensions == 4) {
54-
TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
55-
return hwcTensor.permute({0, 3, 1, 2});
56-
} else {
57-
TORCH_CHECK(
58-
false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
59-
}
60-
}
61-
6237
struct AVInput {
6338
UniqueAVFormatContext formatContext;
6439
std::unique_ptr<AVIOBytesContext> ioBytesContext;
@@ -136,6 +111,31 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
136111

137112
} // namespace
138113

114+
// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
115+
// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
116+
// or 4D.
117+
// Calling permute() is guaranteed to return a view as per the docs:
118+
// https://pytorch.org/docs/stable/generated/torch.permute.html
119+
torch::Tensor VideoDecoder::MaybePermuteHWC2CHW(
120+
int streamIndex,
121+
torch::Tensor& hwcTensor) {
122+
if (streams_[streamIndex].options.dimensionOrder == "NHWC") {
123+
return hwcTensor;
124+
}
125+
auto numDimensions = hwcTensor.dim();
126+
auto shape = hwcTensor.sizes();
127+
if (numDimensions == 3) {
128+
TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
129+
return hwcTensor.permute({2, 0, 1});
130+
} else if (numDimensions == 4) {
131+
TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
132+
return hwcTensor.permute({0, 3, 1, 2});
133+
} else {
134+
TORCH_CHECK(
135+
false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
136+
}
137+
}
138+
139139
VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
140140
const std::string& optionsString) {
141141
std::vector<std::string> tokens =
@@ -929,14 +929,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
929929
"Invalid color conversion library: " +
930930
std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
931931
}
932-
if (!preAllocatedOutputTensor.has_value()) {
933-
// We only convert to CHW if a pre-allocated tensor wasn't passed. When a
934-
// pre-allocated tensor is passed, it's up to the caller (typically a
935-
// batch API) to do the conversion. This is more efficient as it allows
936-
// batch NHWC tensors to be permuted only once, instead of permuting HWC
937-
// tensors N times.
938-
output.frame = MaybePermuteHWC2CHW(streamInfo.options, output.frame);
939-
}
940932

941933
} else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
942934
// TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
@@ -978,7 +970,9 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
978970
return seconds >= frameStartTime && seconds < frameEndTime;
979971
});
980972
// Convert the frame to tensor.
981-
return convertAVFrameToDecodedOutput(rawOutput);
973+
auto output = convertAVFrameToDecodedOutput(rawOutput);
974+
output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
975+
return output;
982976
}
983977

984978
void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) {
@@ -1015,6 +1009,16 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
10151009
int streamIndex,
10161010
int64_t frameIndex,
10171011
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1012+
auto output = getFrameAtIndexInternal(
1013+
streamIndex, frameIndex, preAllocatedOutputTensor);
1014+
output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
1015+
return output;
1016+
}
1017+
1018+
VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
1019+
int streamIndex,
1020+
int64_t frameIndex,
1021+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
10181022
validateUserProvidedStreamIndex(streamIndex);
10191023
validateScannedAllStreams("getFrameAtIndex");
10201024

@@ -1023,7 +1027,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
10231027

10241028
int64_t pts = stream.allFrames[frameIndex].pts;
10251029
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
1026-
return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
1030+
return getNextFrameOutputNoDemuxInternal(preAllocatedOutputTensor);
10271031
}
10281032

10291033
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -1073,14 +1077,14 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10731077
output.durationSeconds[indexInOutput] =
10741078
output.durationSeconds[previousIndexInOutput];
10751079
} else {
1076-
DecodedOutput singleOut = getFrameAtIndex(
1080+
DecodedOutput singleOut = getFrameAtIndexInternal(
10771081
streamIndex, indexInVideo, output.frames[indexInOutput]);
10781082
output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
10791083
output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
10801084
}
10811085
previousIndexInVideo = indexInVideo;
10821086
}
1083-
output.frames = MaybePermuteHWC2CHW(options, output.frames);
1087+
output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
10841088
return output;
10851089
}
10861090

@@ -1150,11 +1154,12 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
11501154
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
11511155

11521156
for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
1153-
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
1157+
DecodedOutput singleOut =
1158+
getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
11541159
output.ptsSeconds[f] = singleOut.ptsSeconds;
11551160
output.durationSeconds[f] = singleOut.durationSeconds;
11561161
}
1157-
output.frames = MaybePermuteHWC2CHW(options, output.frames);
1162+
output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
11581163
return output;
11591164
}
11601165

@@ -1207,7 +1212,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12071212
// need this special case below.
12081213
if (startSeconds == stopSeconds) {
12091214
BatchDecodedOutput output(0, options, streamMetadata);
1210-
output.frames = MaybePermuteHWC2CHW(options, output.frames);
1215+
output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
12111216
return output;
12121217
}
12131218

@@ -1243,11 +1248,12 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12431248
int64_t numFrames = stopFrameIndex - startFrameIndex;
12441249
BatchDecodedOutput output(numFrames, options, streamMetadata);
12451250
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
1246-
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
1251+
DecodedOutput singleOut =
1252+
getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
12471253
output.ptsSeconds[f] = singleOut.ptsSeconds;
12481254
output.durationSeconds[f] = singleOut.durationSeconds;
12491255
}
1250-
output.frames = MaybePermuteHWC2CHW(options, output.frames);
1256+
output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
12511257

12521258
return output;
12531259
}
@@ -1261,7 +1267,13 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
12611267
return rawOutput;
12621268
}
12631269

1264-
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
1270+
VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
1271+
auto output = getNextFrameOutputNoDemuxInternal();
1272+
output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
1273+
return output;
1274+
}
1275+
1276+
VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal(
12651277
std::optional<torch::Tensor> preAllocatedOutputTensor) {
12661278
auto rawOutput = getNextRawDecodedOutputNoDemux();
12671279
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
@@ -1283,7 +1295,7 @@ double VideoDecoder::getPtsSecondsForFrame(
12831295
int streamIndex,
12841296
int64_t frameIndex) {
12851297
validateUserProvidedStreamIndex(streamIndex);
1286-
validateScannedAllStreams("getFrameAtIndex");
1298+
validateScannedAllStreams("getPtsSecondsForFrame");
12871299

12881300
const auto& stream = streams_[streamIndex];
12891301
validateFrameIndex(stream, frameIndex);

src/torchcodec/decoders/_core/VideoDecoder.h

+12-4
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,12 @@ class VideoDecoder {
157157
int streamIndex,
158158
const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());
159159

160+
torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
161+
160162
// ---- SINGLE FRAME SEEK AND DECODING API ----
161163
// Places the cursor at the first frame on or after the position in seconds.
162-
// Calling getNextDecodedOutputNoDemux() will return the first frame at or
163-
// after this position.
164+
// Calling getNextFrameNoDemux() will return the first frame at
165+
// or after this position.
164166
void setCursorPtsInSeconds(double seconds);
165167
// This is an internal structure that is used to store the decoded output
166168
// from decoding a frame through color conversion. Example usage is:
@@ -214,8 +216,7 @@ class VideoDecoder {
214216
};
215217
// Decodes the frame where the current cursor position is. It also advances
216218
// the cursor to the next frame.
217-
DecodedOutput getNextDecodedOutputNoDemux(
218-
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
219+
DecodedOutput getNextFrameNoDemux();
219220
// Decodes the first frame in any added stream that is visible at a given
220221
// timestamp. Frames in the video have a presentation timestamp and a
221222
// duration. For example, if a frame has presentation timestamp of 5.0s and a
@@ -386,6 +387,13 @@ class VideoDecoder {
386387
DecodedOutput& output,
387388
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
388389

390+
DecodedOutput getFrameAtIndexInternal(
391+
int streamIndex,
392+
int64_t frameIndex,
393+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
394+
DecodedOutput getNextFrameOutputNoDemuxInternal(
395+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
396+
389397
DecoderOptions options_;
390398
ContainerMetadata containerMetadata_;
391399
UniqueAVFormatContext formatContext_;

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder) {
193193
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
194194
VideoDecoder::DecodedOutput result;
195195
try {
196-
result = videoDecoder->getNextDecodedOutputNoDemux();
196+
result = videoDecoder->getNextFrameNoDemux();
197197
} catch (const VideoDecoder::EndOfFileException& e) {
198198
C10_THROW_ERROR(IndexError, e.what());
199199
}

test/decoders/VideoDecoderTest.cpp

+12-12
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) {
148148
streamOptions.width = 100;
149149
streamOptions.height = 120;
150150
decoder->addVideoStreamDecoder(-1, streamOptions);
151-
torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
151+
torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
152152
EXPECT_EQ(tensor.sizes(), std::vector<long>({3, 120, 100}));
153153
}
154154

@@ -159,7 +159,7 @@ TEST(VideoDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) {
159159
VideoDecoder::VideoStreamDecoderOptions streamOptions;
160160
streamOptions.dimensionOrder = "NHWC";
161161
decoder->addVideoStreamDecoder(-1, streamOptions);
162-
torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
162+
torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
163163
EXPECT_EQ(tensor.sizes(), std::vector<long>({270, 480, 3}));
164164
}
165165

@@ -168,12 +168,12 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
168168
std::unique_ptr<VideoDecoder> ourDecoder =
169169
createDecoderFromPath(path, GetParam());
170170
ourDecoder->addVideoStreamDecoder(-1);
171-
auto output = ourDecoder->getNextDecodedOutputNoDemux();
171+
auto output = ourDecoder->getNextFrameNoDemux();
172172
torch::Tensor tensor0FromOurDecoder = output.frame;
173173
EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
174174
EXPECT_EQ(output.ptsSeconds, 0.0);
175175
EXPECT_EQ(output.pts, 0);
176-
output = ourDecoder->getNextDecodedOutputNoDemux();
176+
output = ourDecoder->getNextFrameNoDemux();
177177
torch::Tensor tensor1FromOurDecoder = output.frame;
178178
EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
179179
EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000);
@@ -254,11 +254,11 @@ TEST_P(VideoDecoderTest, SeeksCloseToEof) {
254254
createDecoderFromPath(path, GetParam());
255255
ourDecoder->addVideoStreamDecoder(-1);
256256
ourDecoder->setCursorPtsInSeconds(388388. / 30'000);
257-
auto output = ourDecoder->getNextDecodedOutputNoDemux();
257+
auto output = ourDecoder->getNextFrameNoDemux();
258258
EXPECT_EQ(output.ptsSeconds, 388'388. / 30'000);
259-
output = ourDecoder->getNextDecodedOutputNoDemux();
259+
output = ourDecoder->getNextFrameNoDemux();
260260
EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000);
261-
EXPECT_THROW(ourDecoder->getNextDecodedOutputNoDemux(), std::exception);
261+
EXPECT_THROW(ourDecoder->getNextFrameNoDemux(), std::exception);
262262
}
263263

264264
TEST_P(VideoDecoderTest, GetsFramePlayedAtTimestamp) {
@@ -298,7 +298,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
298298
createDecoderFromPath(path, GetParam());
299299
ourDecoder->addVideoStreamDecoder(-1);
300300
ourDecoder->setCursorPtsInSeconds(6.0);
301-
auto output = ourDecoder->getNextDecodedOutputNoDemux();
301+
auto output = ourDecoder->getNextFrameNoDemux();
302302
torch::Tensor tensor6FromOurDecoder = output.frame;
303303
EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000);
304304
torch::Tensor tensor6FromFFMPEG =
@@ -314,7 +314,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
314314
EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 180);
315315

316316
ourDecoder->setCursorPtsInSeconds(6.1);
317-
output = ourDecoder->getNextDecodedOutputNoDemux();
317+
output = ourDecoder->getNextFrameNoDemux();
318318
torch::Tensor tensor61FromOurDecoder = output.frame;
319319
EXPECT_EQ(output.ptsSeconds, 183'183. / 30'000);
320320
torch::Tensor tensor61FromFFMPEG =
@@ -334,7 +334,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
334334
EXPECT_LT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 10);
335335

336336
ourDecoder->setCursorPtsInSeconds(10.0);
337-
output = ourDecoder->getNextDecodedOutputNoDemux();
337+
output = ourDecoder->getNextFrameNoDemux();
338338
torch::Tensor tensor10FromOurDecoder = output.frame;
339339
EXPECT_EQ(output.ptsSeconds, 300'300. / 30'000);
340340
torch::Tensor tensor10FromFFMPEG =
@@ -351,7 +351,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
351351
EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 60);
352352

353353
ourDecoder->setCursorPtsInSeconds(6.0);
354-
output = ourDecoder->getNextDecodedOutputNoDemux();
354+
output = ourDecoder->getNextFrameNoDemux();
355355
tensor6FromOurDecoder = output.frame;
356356
EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000);
357357
EXPECT_TRUE(torch::equal(tensor6FromOurDecoder, tensor6FromFFMPEG));
@@ -366,7 +366,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
366366

367367
constexpr double kPtsOfLastFrameInVideoStream = 389'389. / 30'000; // ~12.9
368368
ourDecoder->setCursorPtsInSeconds(kPtsOfLastFrameInVideoStream);
369-
output = ourDecoder->getNextDecodedOutputNoDemux();
369+
output = ourDecoder->getNextFrameNoDemux();
370370
torch::Tensor tensor7FromOurDecoder = output.frame;
371371
EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000);
372372
torch::Tensor tensor7FromFFMPEG =

0 commit comments

Comments
 (0)