From 9d6bcfffc60ab9f2f8f1275410c6d061f512f889 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 15 May 2025 11:56:23 +0100 Subject: [PATCH 1/6] WIP --- src/torchcodec/_core/SingleStreamDecoder.cpp | 38 ++++++++++++++------ 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index e2c55ef2..4214387e 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -29,6 +29,13 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) { return static_cast(std::round(seconds * timeBase.den)); } +int64_t getPtsOrDts(const UniqueAVFrame& avFrame) { + return avFrame->pts == INT64_MIN? avFrame->pkt_dts : avFrame->pts; +} +int64_t getPtsOrDts(ReferenceAVPacket& packet){ + return packet->pts == INT64_MIN? packet->dts : packet->pts; +} + } // namespace // -------------------------------------------------------------------------- @@ -225,16 +232,16 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { int streamIndex = packet->stream_index; auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; streamMetadata.minPtsFromScan = std::min( - streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts); + streamMetadata.minPtsFromScan.value_or(INT64_MAX), getPtsOrDts(packet)); streamMetadata.maxPtsFromScan = std::max( streamMetadata.maxPtsFromScan.value_or(INT64_MIN), - packet->pts + packet->duration); + getPtsOrDts(packet) + packet->duration); streamMetadata.numFramesFromScan = streamMetadata.numFramesFromScan.value_or(0) + 1; // Note that we set the other value in this struct, nextPts, only after // we have scanned all packets and sorted by pts. 
- FrameInfo frameInfo = {packet->pts}; + FrameInfo frameInfo = {getPtsOrDts(packet)}; if (packet->flags & AV_PKT_FLAG_KEY) { frameInfo.isKeyFrame = true; streamInfos_[streamIndex].keyFrames.push_back(frameInfo); @@ -496,7 +503,7 @@ FrameOutput SingleStreamDecoder::getNextFrameInternal( std::optional preAllocatedOutputTensor) { validateActiveStream(); UniqueAVFrame avFrame = decodeAVFrame( - [this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; }); + [this](const UniqueAVFrame& avFrame) { return getPtsOrDts(avFrame) >= cursor_; }); return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor); } @@ -616,6 +623,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) { validateActiveStream(AVMEDIA_TYPE_VIDEO); + printf("seconds=%f\n", seconds); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; double frameStartTime = ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase); @@ -627,14 +635,24 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) { // don't cache it locally, we have to rewind back. 
seconds = frameStartTime; } + printf("seconds=%f\n", seconds); + printf("frameStartTime=%f\n", frameStartTime); + printf("frameEndTime=%f\n", frameEndTime); + printf("TimeBase: %d %d\n", streamInfo.timeBase.num, streamInfo.timeBase.den); + printf("In decoding loop\n"); setCursorPtsInSeconds(seconds); UniqueAVFrame avFrame = decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) { StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase); + double frameStartTime = ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase); double frameEndTime = ptsToSeconds( - avFrame->pts + getDuration(avFrame), streamInfo.timeBase); + getPtsOrDts(avFrame) + getDuration(avFrame), streamInfo.timeBase); + printf("frame pts=%ld\n", avFrame->pts); + printf("frame pkt_dts=%ld\n", avFrame->pkt_dts); + printf("frameStartTime=%f\n", frameStartTime); + printf("frameEndTime=%f\n", frameEndTime); + printf("\n"); if (frameStartTime > seconds) { // FFMPEG seeked past the frame we are looking for even though we // set max_ts to be our needed timestamp in avformat_seek_file() @@ -861,8 +879,8 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( try { UniqueAVFrame avFrame = decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) { - return startPts < avFrame->pts + getDuration(avFrame) && - stopPts > avFrame->pts; + return startPts < getPtsOrDts(avFrame) + getDuration(avFrame) && + stopPts > getPtsOrDts(avFrame); }); auto frameOutput = convertAVFrameToFrameOutput(avFrame); if (!firstFramePtsSeconds.has_value()) { @@ -1132,7 +1150,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // haven't received as frames. Eventually we will either hit AVERROR_EOF from // av_receive_frame() or the user will have seeked to a different location in // the file and that will flush the decoder. 
- streamInfo.lastDecodedAvFramePts = avFrame->pts; + streamInfo.lastDecodedAvFramePts = getPtsOrDts(avFrame); streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame); return avFrame; @@ -1149,7 +1167,7 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( FrameOutput frameOutput; auto& streamInfo = streamInfos_[activeStreamIndex_]; frameOutput.ptsSeconds = ptsToSeconds( - avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base); + getPtsOrDts(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); frameOutput.durationSeconds = ptsToSeconds( getDuration(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); From a4a034a84e4544d5933a49048daa2c8d88e5f2d8 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 15 May 2025 12:48:48 +0100 Subject: [PATCH 2/6] Fix pts <-> seconds conversions --- src/torchcodec/_core/SingleStreamDecoder.cpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index e2c55ef2..42989340 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -17,16 +17,13 @@ namespace facebook::torchcodec { namespace { -double ptsToSeconds(int64_t pts, int den) { - return static_cast(pts) / den; -} - double ptsToSeconds(int64_t pts, const AVRational& timeBase) { - return ptsToSeconds(pts, timeBase.den); + return static_cast(pts) * timeBase.num / timeBase.den; } int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) { - return static_cast(std::round(seconds * timeBase.den)); + return static_cast( + std::round(seconds * timeBase.den / timeBase.num)); } } // namespace @@ -152,7 +149,7 @@ void SingleStreamDecoder::initializeDecoder() { if (formatContext_->duration > 0) { containerMetadata_.durationSeconds = - ptsToSeconds(formatContext_->duration, AV_TIME_BASE); + ptsToSeconds(formatContext_->duration, AV_TIME_BASE_Q); } if 
(formatContext_->bit_rate > 0) { From ce04b3c97b85bb3b3986934874ccc8443af2ce35 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 15 May 2025 12:59:32 +0100 Subject: [PATCH 3/6] Fallback to DTS when PTS info is missing --- src/torchcodec/_core/SingleStreamDecoder.cpp | 35 +++++++++----------- src/torchcodec/_core/_metadata.py | 3 ++ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index fcc98be6..b131e972 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -26,11 +26,16 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) { std::round(seconds * timeBase.den / timeBase.num)); } -int64_t getPtsOrDts(const UniqueAVFrame& avFrame) { - return avFrame->pts == INT64_MIN? avFrame->pkt_dts : avFrame->pts; +// Some videos aren't properly encoded and do not specify pts values (for +// packets, and thus for frames). Unset values correspond to INT64_MIN. When +// that happens, we fall-back to the dts value which hopefully exists and is +// correct. +int64_t getPtsOrDts(ReferenceAVPacket& packet) { + return packet->pts == INT64_MIN ? packet->dts : packet->pts; } -int64_t getPtsOrDts(ReferenceAVPacket& packet){ - return packet->pts == INT64_MIN? packet->dts : packet->pts; + +int64_t getPtsOrDts(const UniqueAVFrame& avFrame) { + return avFrame->pts == INT64_MIN ? 
avFrame->pkt_dts : avFrame->pts; } } // namespace @@ -499,8 +504,9 @@ FrameOutput SingleStreamDecoder::getNextFrame() { FrameOutput SingleStreamDecoder::getNextFrameInternal( std::optional preAllocatedOutputTensor) { validateActiveStream(); - UniqueAVFrame avFrame = decodeAVFrame( - [this](const UniqueAVFrame& avFrame) { return getPtsOrDts(avFrame) >= cursor_; }); + UniqueAVFrame avFrame = decodeAVFrame([this](const UniqueAVFrame& avFrame) { + return getPtsOrDts(avFrame) >= cursor_; + }); return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor); } @@ -620,7 +626,6 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) { validateActiveStream(AVMEDIA_TYPE_VIDEO); - printf("seconds=%f\n", seconds); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; double frameStartTime = ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase); @@ -632,24 +637,15 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) { // don't cache it locally, we have to rewind back. 
seconds = frameStartTime; } - printf("seconds=%f\n", seconds); - printf("frameStartTime=%f\n", frameStartTime); - printf("frameEndTime=%f\n", frameEndTime); - printf("TimeBase: %d %d\n", streamInfo.timeBase.num, streamInfo.timeBase.den); - printf("In decoding loop\n"); setCursorPtsInSeconds(seconds); UniqueAVFrame avFrame = decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) { StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - double frameStartTime = ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase); + double frameStartTime = + ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase); double frameEndTime = ptsToSeconds( getPtsOrDts(avFrame) + getDuration(avFrame), streamInfo.timeBase); - printf("frame pts=%ld\n", avFrame->pts); - printf("frame pkt_dts=%ld\n", avFrame->pkt_dts); - printf("frameStartTime=%f\n", frameStartTime); - printf("frameEndTime=%f\n", frameEndTime); - printf("\n"); if (frameStartTime > seconds) { // FFMPEG seeked past the frame we are looking for even though we // set max_ts to be our needed timestamp in avformat_seek_file() @@ -1164,7 +1160,8 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( FrameOutput frameOutput; auto& streamInfo = streamInfos_[activeStreamIndex_]; frameOutput.ptsSeconds = ptsToSeconds( - getPtsOrDts(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); + getPtsOrDts(avFrame), + formatContext_->streams[activeStreamIndex_]->time_base); frameOutput.durationSeconds = ptsToSeconds( getDuration(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 93c4448c..58c16366 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -142,6 +142,9 @@ def average_fps(self) -> Optional[float]: self.end_stream_seconds_from_content is None or self.begin_stream_seconds_from_content is None or self.num_frames is None + # Should never happen, but prevents 
ZeroDivisionError: + or self.end_stream_seconds_from_content + == self.begin_stream_seconds_from_content ): return self.average_fps_from_header return self.num_frames / ( From 06cd220261b7f829aa43a2bc6f35cc70363a37c9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 15 May 2025 13:07:15 +0100 Subject: [PATCH 4/6] Fix compile issue --- src/torchcodec/_core/SingleStreamDecoder.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 42989340..6299e319 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -148,8 +148,9 @@ void SingleStreamDecoder::initializeDecoder() { } if (formatContext_->duration > 0) { + AVRational defaultTimeBase{1, AV_TIME_BASE}; containerMetadata_.durationSeconds = - ptsToSeconds(formatContext_->duration, AV_TIME_BASE_Q); + ptsToSeconds(formatContext_->duration, defaultTimeBase); } if (formatContext_->bit_rate > 0) { From 8bb3c0849a73f06d330708c8c03d79b04bcad759 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 15 May 2025 13:18:05 +0100 Subject: [PATCH 5/6] Comment --- src/torchcodec/_core/SingleStreamDecoder.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 66615a3d..770f78bd 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -26,10 +26,12 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) { std::round(seconds * timeBase.den / timeBase.num)); } -// Some videos aren't properly encoded and do not specify pts values (for -// packets, and thus for frames). Unset values correspond to INT64_MIN. When -// that happens, we fall-back to the dts value which hopefully exists and is -// correct. 
+// Some videos aren't properly encoded and do not specify pts values for +// packets, and thus for frames. Unset values correspond to INT64_MIN. When that +// happens, we fall back to the dts value which hopefully exists and is correct. +// Accessing AVFrames' and AVPackets' pts values should **always** go through +// the helpers below. Then, the "pts" fields in our structs like FrameInfo.pts +// should be interpreted as "pts if it exists, dts otherwise". int64_t getPtsOrDts(ReferenceAVPacket& packet) { return packet->pts == INT64_MIN ? packet->dts : packet->pts; } From 9a2c5be59454bce49dc1fb4e2a8274ab8faeabae Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 15 May 2025 13:33:20 +0100 Subject: [PATCH 6/6] Add test --- test/test_decoders.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/test/test_decoders.py b/test/test_decoders.py index ddd35ff3..133b3cb6 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -986,6 +986,47 @@ def get_some_frames(decoder): assert_frames_equal(ref_frame3, frames[1].data) assert_frames_equal(ref_frame5, frames[2].data) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_pts_to_dts_fallback(self, seek_mode): + # Non-regression test for + # https://github.com/pytorch/torchcodec/issues/677 and + # https://github.com/pytorch/torchcodec/issues/676. + # More accurately, this is a non-regression test for videos which do + # *not* specify pts values (all pts values are N/A and set to + # INT64_MIN), but specify *dts* value - which we fall back to. + # + # The test video we have is from + # https://huggingface.co/datasets/raushan-testing-hf/videos-test/blob/main/sample_video_2.avi + # We can't check it into the repo due to potential licensing issues, so + # we have to unconditionally skip this test. + # TODO: encode a video with no pts values to unskip this test. 
Couldn't # find a way to do that with FFmpeg's CLI, but this should be doable # once we have our own video encoder. pytest.skip(reason="TODO: Need video with no pts values.") + path = "/home/nicolashug/Downloads/sample_video_2.avi" + decoder = VideoDecoder(path, seek_mode=seek_mode) + metadata = decoder.metadata + + assert metadata.average_fps == pytest.approx(29.916667) + assert metadata.duration_seconds_from_header == 9.02507 + assert metadata.duration_seconds == 9.02507 + assert metadata.begin_stream_seconds_from_content == ( + None if seek_mode == "approximate" else 0 + ) + assert metadata.end_stream_seconds_from_content == ( + None if seek_mode == "approximate" else 9.02507 + ) + + assert decoder[0].shape == (3, 240, 320) + assert decoder[10].shape == (3, 240, 320) + assert decoder.get_frame_at(2).data.shape == (3, 240, 320) + assert decoder.get_frames_at([2, 10]).data.shape == (2, 3, 240, 320) + assert decoder.get_frame_played_at(9).data.shape == (3, 240, 320) + assert decoder.get_frames_played_at([2, 4]).data.shape == (2, 3, 240, 320) + with pytest.raises(AssertionError, match="not equal"): + torch.testing.assert_close(decoder[0], decoder[10]) + class TestAudioDecoder: @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))