@@ -191,14 +191,14 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
191
191
int64_t numFrames,
192
192
const VideoStreamDecoderOptions& options,
193
193
const StreamMetadata& metadata)
194
- : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64 })),
195
- durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
196
- frames (torch::empty(
194
+ : frames(torch::empty(
197
195
{numFrames,
198
196
options.height .value_or (*metadata.height ),
199
197
options.width .value_or (*metadata.width ),
200
198
3 },
201
- {torch::kUInt8 })) {}
199
+ {torch::kUInt8 })),
200
+ ptsSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
201
+ durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {}
202
202
203
203
// Default constructor: all members rely on their in-class/default
// initialization; no decoding state is set up until a file and stream
// are added.
VideoDecoder::VideoDecoder() {}
204
@@ -1017,24 +1017,57 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
1017
1017
validateUserProvidedStreamIndex (streamIndex);
1018
1018
validateScannedAllStreams (" getFramesAtIndices" );
1019
1019
1020
+ auto indicesAreSorted =
1021
+ std::is_sorted (frameIndices.begin (), frameIndices.end ());
1022
+
1023
+ std::vector<size_t > argsort;
1024
+ if (!indicesAreSorted) {
1025
+ // if frameIndices is [13, 10, 12, 11]
1026
+ // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
1027
+ // to use to decode the frames
1028
+ // and argsort is [ 1, 3, 2, 0]
1029
+ argsort.resize (frameIndices.size ());
1030
+ for (size_t i = 0 ; i < argsort.size (); ++i) {
1031
+ argsort[i] = i;
1032
+ }
1033
+ std::sort (
1034
+ argsort.begin (), argsort.end (), [&frameIndices](size_t a, size_t b) {
1035
+ return frameIndices[a] < frameIndices[b];
1036
+ });
1037
+ }
1038
+
1020
1039
const auto & streamMetadata = containerMetadata_.streams [streamIndex];
1021
1040
const auto & stream = streams_[streamIndex];
1022
1041
const auto & options = stream.options ;
1023
1042
BatchDecodedOutput output (frameIndices.size (), options, streamMetadata);
1024
1043
1044
+ auto previousIndexInVideo = -1 ;
1025
1045
for (auto f = 0 ; f < frameIndices.size (); ++f) {
1026
- auto frameIndex = frameIndices[f];
1027
- if (frameIndex < 0 || frameIndex >= stream.allFrames .size ()) {
1046
+ auto indexInOutput = indicesAreSorted ? f : argsort[f];
1047
+ auto indexInVideo = frameIndices[indexInOutput];
1048
+ if (indexInVideo < 0 || indexInVideo >= stream.allFrames .size ()) {
1028
1049
throw std::runtime_error (
1029
- " Invalid frame index=" + std::to_string (frameIndex ));
1050
+ " Invalid frame index=" + std::to_string (indexInVideo ));
1030
1051
}
1031
- DecodedOutput singleOut =
1032
- getFrameAtIndex (streamIndex, frameIndex, output.frames [f]);
1033
- if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1034
- output.frames [f] = singleOut.frame ;
1052
+ if ((f > 0 ) && (indexInVideo == previousIndexInVideo)) {
1053
+ // Avoid decoding the same frame twice
1054
+ auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1 ];
1055
+ output.frames [indexInOutput].copy_ (output.frames [previousIndexInOutput]);
1056
+ output.ptsSeconds [indexInOutput] =
1057
+ output.ptsSeconds [previousIndexInOutput];
1058
+ output.durationSeconds [indexInOutput] =
1059
+ output.durationSeconds [previousIndexInOutput];
1060
+ } else {
1061
+ DecodedOutput singleOut = getFrameAtIndex (
1062
+ streamIndex, indexInVideo, output.frames [indexInOutput]);
1063
+ if (options.colorConversionLibrary ==
1064
+ ColorConversionLibrary::FILTERGRAPH) {
1065
+ output.frames [indexInOutput] = singleOut.frame ;
1066
+ }
1067
+ output.ptsSeconds [indexInOutput] = singleOut.ptsSeconds ;
1068
+ output.durationSeconds [indexInOutput] = singleOut.durationSeconds ;
1035
1069
}
1036
- // Note that for now we ignore the pts and duration parts of the output,
1037
- // because they're never used in any caller.
1070
+ previousIndexInVideo = indexInVideo;
1038
1071
}
1039
1072
output.frames = MaybePermuteHWC2CHW (options, output.frames );
1040
1073
return output;
0 commit comments