@@ -34,31 +34,6 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
34
34
return ptsToSeconds (pts, timeBase.den );
35
35
}
36
36
37
- // Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
38
- // The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
39
- // or 4D.
40
- // Calling permute() is guaranteed to return a view as per the docs:
41
- // https://pytorch.org/docs/stable/generated/torch.permute.html
42
- torch::Tensor MaybePermuteHWC2CHW (
43
- const VideoDecoder::VideoStreamDecoderOptions& options,
44
- torch::Tensor& hwcTensor) {
45
- if (options.dimensionOrder == " NHWC" ) {
46
- return hwcTensor;
47
- }
48
- auto numDimensions = hwcTensor.dim ();
49
- auto shape = hwcTensor.sizes ();
50
- if (numDimensions == 3 ) {
51
- TORCH_CHECK (shape[2 ] == 3 , " Not a HWC tensor: " , shape);
52
- return hwcTensor.permute ({2 , 0 , 1 });
53
- } else if (numDimensions == 4 ) {
54
- TORCH_CHECK (shape[3 ] == 3 , " Not a NHWC tensor: " , shape);
55
- return hwcTensor.permute ({0 , 3 , 1 , 2 });
56
- } else {
57
- TORCH_CHECK (
58
- false , " Expected tensor with 3 or 4 dimensions, got " , numDimensions);
59
- }
60
- }
61
-
62
37
struct AVInput {
63
38
UniqueAVFormatContext formatContext;
64
39
std::unique_ptr<AVIOBytesContext> ioBytesContext;
@@ -136,6 +111,31 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
136
111
137
112
} // namespace
138
113
114
+ // Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
115
+ // The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
116
+ // or 4D.
117
+ // Calling permute() is guaranteed to return a view as per the docs:
118
+ // https://pytorch.org/docs/stable/generated/torch.permute.html
119
+ torch::Tensor VideoDecoder::MaybePermuteHWC2CHW (
120
+ int streamIndex,
121
+ torch::Tensor& hwcTensor) {
122
+ if (streams_[streamIndex].options .dimensionOrder == " NHWC" ) {
123
+ return hwcTensor;
124
+ }
125
+ auto numDimensions = hwcTensor.dim ();
126
+ auto shape = hwcTensor.sizes ();
127
+ if (numDimensions == 3 ) {
128
+ TORCH_CHECK (shape[2 ] == 3 , " Not a HWC tensor: " , shape);
129
+ return hwcTensor.permute ({2 , 0 , 1 });
130
+ } else if (numDimensions == 4 ) {
131
+ TORCH_CHECK (shape[3 ] == 3 , " Not a NHWC tensor: " , shape);
132
+ return hwcTensor.permute ({0 , 3 , 1 , 2 });
133
+ } else {
134
+ TORCH_CHECK (
135
+ false , " Expected tensor with 3 or 4 dimensions, got " , numDimensions);
136
+ }
137
+ }
138
+
139
139
VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions (
140
140
const std::string& optionsString) {
141
141
std::vector<std::string> tokens =
@@ -929,14 +929,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
929
929
" Invalid color conversion library: " +
930
930
std::to_string (static_cast <int >(streamInfo.colorConversionLibrary )));
931
931
}
932
- if (!preAllocatedOutputTensor.has_value ()) {
933
- // We only convert to CHW if a pre-allocated tensor wasn't passed. When a
934
- // pre-allocated tensor is passed, it's up to the caller (typically a
935
- // batch API) to do the conversion. This is more efficient as it allows
936
- // batch NHWC tensors to be permuted only once, instead of permuting HWC
937
- // tensors N times.
938
- output.frame = MaybePermuteHWC2CHW (streamInfo.options , output.frame );
939
- }
940
932
941
933
} else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
942
934
// TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
@@ -978,7 +970,9 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
978
970
return seconds >= frameStartTime && seconds < frameEndTime;
979
971
});
980
972
// Convert the frame to tensor.
981
- return convertAVFrameToDecodedOutput (rawOutput);
973
+ auto output = convertAVFrameToDecodedOutput (rawOutput);
974
+ output.frame = MaybePermuteHWC2CHW (output.streamIndex , output.frame );
975
+ return output;
982
976
}
983
977
984
978
void VideoDecoder::validateUserProvidedStreamIndex (uint64_t streamIndex) {
@@ -1015,6 +1009,16 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
1015
1009
int streamIndex,
1016
1010
int64_t frameIndex,
1017
1011
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1012
+ auto output = getFrameAtIndexInternal (
1013
+ streamIndex, frameIndex, preAllocatedOutputTensor);
1014
+ output.frame = MaybePermuteHWC2CHW (streamIndex, output.frame );
1015
+ return output;
1016
+ }
1017
+
1018
+ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal (
1019
+ int streamIndex,
1020
+ int64_t frameIndex,
1021
+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
1018
1022
validateUserProvidedStreamIndex (streamIndex);
1019
1023
validateScannedAllStreams (" getFrameAtIndex" );
1020
1024
@@ -1023,7 +1027,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
1023
1027
1024
1028
int64_t pts = stream.allFrames [frameIndex].pts ;
1025
1029
setCursorPtsInSeconds (ptsToSeconds (pts, stream.timeBase ));
1026
- return getNextDecodedOutputNoDemux (preAllocatedOutputTensor);
1030
+ return getNextFrameOutputNoDemuxInternal (preAllocatedOutputTensor);
1027
1031
}
1028
1032
1029
1033
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices (
@@ -1073,14 +1077,14 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
1073
1077
output.durationSeconds [indexInOutput] =
1074
1078
output.durationSeconds [previousIndexInOutput];
1075
1079
} else {
1076
- DecodedOutput singleOut = getFrameAtIndex (
1080
+ DecodedOutput singleOut = getFrameAtIndexInternal (
1077
1081
streamIndex, indexInVideo, output.frames [indexInOutput]);
1078
1082
output.ptsSeconds [indexInOutput] = singleOut.ptsSeconds ;
1079
1083
output.durationSeconds [indexInOutput] = singleOut.durationSeconds ;
1080
1084
}
1081
1085
previousIndexInVideo = indexInVideo;
1082
1086
}
1083
- output.frames = MaybePermuteHWC2CHW (options , output.frames );
1087
+ output.frames = MaybePermuteHWC2CHW (streamIndex , output.frames );
1084
1088
return output;
1085
1089
}
1086
1090
@@ -1150,11 +1154,12 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
1150
1154
BatchDecodedOutput output (numOutputFrames, options, streamMetadata);
1151
1155
1152
1156
for (int64_t i = start, f = 0 ; i < stop; i += step, ++f) {
1153
- DecodedOutput singleOut = getFrameAtIndex (streamIndex, i, output.frames [f]);
1157
+ DecodedOutput singleOut =
1158
+ getFrameAtIndexInternal (streamIndex, i, output.frames [f]);
1154
1159
output.ptsSeconds [f] = singleOut.ptsSeconds ;
1155
1160
output.durationSeconds [f] = singleOut.durationSeconds ;
1156
1161
}
1157
- output.frames = MaybePermuteHWC2CHW (options , output.frames );
1162
+ output.frames = MaybePermuteHWC2CHW (streamIndex , output.frames );
1158
1163
return output;
1159
1164
}
1160
1165
@@ -1207,7 +1212,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
1207
1212
// need this special case below.
1208
1213
if (startSeconds == stopSeconds) {
1209
1214
BatchDecodedOutput output (0 , options, streamMetadata);
1210
- output.frames = MaybePermuteHWC2CHW (options , output.frames );
1215
+ output.frames = MaybePermuteHWC2CHW (streamIndex , output.frames );
1211
1216
return output;
1212
1217
}
1213
1218
@@ -1243,11 +1248,12 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
1243
1248
int64_t numFrames = stopFrameIndex - startFrameIndex;
1244
1249
BatchDecodedOutput output (numFrames, options, streamMetadata);
1245
1250
for (int64_t i = startFrameIndex, f = 0 ; i < stopFrameIndex; ++i, ++f) {
1246
- DecodedOutput singleOut = getFrameAtIndex (streamIndex, i, output.frames [f]);
1251
+ DecodedOutput singleOut =
1252
+ getFrameAtIndexInternal (streamIndex, i, output.frames [f]);
1247
1253
output.ptsSeconds [f] = singleOut.ptsSeconds ;
1248
1254
output.durationSeconds [f] = singleOut.durationSeconds ;
1249
1255
}
1250
- output.frames = MaybePermuteHWC2CHW (options , output.frames );
1256
+ output.frames = MaybePermuteHWC2CHW (streamIndex , output.frames );
1251
1257
1252
1258
return output;
1253
1259
}
@@ -1261,7 +1267,13 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
1261
1267
return rawOutput;
1262
1268
}
1263
1269
1264
- VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux (
1270
+ VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux () {
1271
+ auto output = getNextFrameOutputNoDemuxInternal ();
1272
+ output.frame = MaybePermuteHWC2CHW (output.streamIndex , output.frame );
1273
+ return output;
1274
+ }
1275
+
1276
+ VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal (
1265
1277
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1266
1278
auto rawOutput = getNextRawDecodedOutputNoDemux ();
1267
1279
return convertAVFrameToDecodedOutput (rawOutput, preAllocatedOutputTensor);
@@ -1283,7 +1295,7 @@ double VideoDecoder::getPtsSecondsForFrame(
1283
1295
int streamIndex,
1284
1296
int64_t frameIndex) {
1285
1297
validateUserProvidedStreamIndex (streamIndex);
1286
- validateScannedAllStreams (" getFrameAtIndex " );
1298
+ validateScannedAllStreams (" getPtsSecondsForFrame " );
1287
1299
1288
1300
const auto & stream = streams_[streamIndex];
1289
1301
validateFrameIndex (stream, frameIndex);
0 commit comments