Skip to content

Commit b979201

Browse files
authored
Ensure we don't read invalid memory when seeking to pts before the first frame (pytorch#331)
1 parent 2640fa9 commit b979201

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
3434
return ptsToSeconds(pts, timeBase.den);
3535
}
3636

37+
int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
38+
return static_cast<int64_t>(std::round(seconds * timeBase.den));
39+
}
40+
3741
struct AVInput {
3842
UniqueAVFormatContext formatContext;
3943
std::unique_ptr<AVIOBytesContext> ioBytesContext;
@@ -663,7 +667,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
663667
for (int streamIndex : activeStreamIndices_) {
664668
StreamInfo& streamInfo = streams_[streamIndex];
665669
// clang-format off: clang format clashes
666-
streamInfo.discardFramesBeforePts = *maybeDesiredPts_ * streamInfo.timeBase.den;
670+
streamInfo.discardFramesBeforePts = secondsToClosestPts(*maybeDesiredPts_, streamInfo.timeBase);
667671
// clang-format on
668672
}
669673

@@ -686,16 +690,18 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
686690
}
687691
int firstActiveStreamIndex = *activeStreamIndices_.begin();
688692
const auto& firstStreamInfo = streams_[firstActiveStreamIndex];
689-
int64_t desiredPts = *maybeDesiredPts_ * firstStreamInfo.timeBase.den;
693+
int64_t desiredPts =
694+
secondsToClosestPts(*maybeDesiredPts_, firstStreamInfo.timeBase);
690695

691696
// For some encodings like H265, FFMPEG sometimes seeks past the point we
692697
// set as the max_ts. So we use our own index to give it the exact pts of
693698
// the key frame that we want to seek to.
694699
// See https://github.com/pytorch/torchcodec/issues/179 for more details.
695700
// See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
696701
if (!firstStreamInfo.keyFrames.empty()) {
697-
int desiredKeyFrameIndex =
698-
getKeyFrameIndexForPts(firstStreamInfo, desiredPts);
702+
int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
703+
firstStreamInfo.keyFrames, desiredPts);
704+
desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
699705
desiredPts = firstStreamInfo.keyFrames[desiredKeyFrameIndex].pts;
700706
}
701707

test/decoders/test_video_decoder_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,20 @@ def test_seek_and_next(self, device):
8484
)
8585
frame_compare_function(frame_time6, reference_frame_time6.to(device))
8686

87+
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_seek_to_negative_pts(self, device):
    """Seeking to a slightly negative pts must still yield the first frame.

    The decoder first returns frame 0 normally; after seeking to a timestamp
    just before zero, the next decoded frame is compared against the same
    reference frame 0.
    """
    video_decoder = create_from_file(str(NASA_VIDEO.path))
    scan_all_streams_to_update_metadata(video_decoder)
    add_video_stream(video_decoder, device=device)
    assert_frames_match = get_frame_compare_function(device)

    # Baseline: the very first decoded frame matches reference frame 0.
    first_decoded, _, _ = get_next_frame(video_decoder)
    expected_first = NASA_VIDEO.get_frame_data_by_index(0)
    assert_frames_match(first_decoded, expected_first.to(device))

    # Seek to a pts just before the start of the stream, then decode again.
    seek_to_pts(video_decoder, -1e-4)
    decoded_after_seek, _, _ = get_next_frame(video_decoder)
    assert_frames_match(decoded_after_seek, expected_first.to(device))
100+
87101
@pytest.mark.parametrize("device", cpu_and_cuda())
88102
def test_get_frame_at_pts(self, device):
89103
decoder = create_from_file(str(NASA_VIDEO.path))

0 commit comments

Comments
 (0)