diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
index 6299e319..770f78bd 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -26,6 +26,20 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
       std::round(seconds * timeBase.den / timeBase.num));
 }
 
+// Some videos aren't properly encoded and do not specify pts values for
+// packets, and thus for frames. Unset values correspond to INT64_MIN. When
+// that happens, we fall back to the dts value, which hopefully exists and is
+// correct. Accessing AVFrame and AVPacket pts values should **always** go
+// through the helpers below. Then, the "pts" fields in our structs like
+// FrameInfo.pts should be interpreted as "pts if it exists, dts otherwise".
+int64_t getPtsOrDts(ReferenceAVPacket& packet) {
+  return packet->pts == INT64_MIN ? packet->dts : packet->pts;
+}
+
+int64_t getPtsOrDts(const UniqueAVFrame& avFrame) {
+  return avFrame->pts == INT64_MIN ? avFrame->pkt_dts : avFrame->pts;
+}
+
 } // namespace
 
 // --------------------------------------------------------------------------
@@ -223,16 +237,16 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
     int streamIndex = packet->stream_index;
     auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
     streamMetadata.minPtsFromScan = std::min(
-        streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
+        streamMetadata.minPtsFromScan.value_or(INT64_MAX), getPtsOrDts(packet));
     streamMetadata.maxPtsFromScan = std::max(
         streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
-        packet->pts + packet->duration);
+        getPtsOrDts(packet) + packet->duration);
     streamMetadata.numFramesFromScan =
         streamMetadata.numFramesFromScan.value_or(0) + 1;
 
     // Note that we set the other value in this struct, nextPts, only after
     // we have scanned all packets and sorted by pts.
-    FrameInfo frameInfo = {packet->pts};
+    FrameInfo frameInfo = {getPtsOrDts(packet)};
     if (packet->flags & AV_PKT_FLAG_KEY) {
       frameInfo.isKeyFrame = true;
       streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
@@ -493,8 +507,9 @@ FrameOutput SingleStreamDecoder::getNextFrame() {
 FrameOutput SingleStreamDecoder::getNextFrameInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateActiveStream();
-  UniqueAVFrame avFrame = decodeAVFrame(
-      [this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
+  UniqueAVFrame avFrame = decodeAVFrame([this](const UniqueAVFrame& avFrame) {
+    return getPtsOrDts(avFrame) >= cursor_;
+  });
   return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
 }
 
@@ -630,9 +645,10 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
   UniqueAVFrame avFrame =
       decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
         StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
-        double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
+        double frameStartTime =
+            ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase);
         double frameEndTime = ptsToSeconds(
-            avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
+            getPtsOrDts(avFrame) + getDuration(avFrame), streamInfo.timeBase);
         if (frameStartTime > seconds) {
           // FFMPEG seeked past the frame we are looking for even though we
           // set max_ts to be our needed timestamp in avformat_seek_file()
@@ -859,8 +875,8 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
     try {
       UniqueAVFrame avFrame =
           decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) {
-            return startPts < avFrame->pts + getDuration(avFrame) &&
-                stopPts > avFrame->pts;
+            return startPts < getPtsOrDts(avFrame) + getDuration(avFrame) &&
+                stopPts > getPtsOrDts(avFrame);
           });
       auto frameOutput = convertAVFrameToFrameOutput(avFrame);
       if (!firstFramePtsSeconds.has_value()) {
@@ -1130,7 +1146,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
   // haven't received as frames. Eventually we will either hit AVERROR_EOF from
   // av_receive_frame() or the user will have seeked to a different location in
   // the file and that will flush the decoder.
-  streamInfo.lastDecodedAvFramePts = avFrame->pts;
+  streamInfo.lastDecodedAvFramePts = getPtsOrDts(avFrame);
   streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
 
   return avFrame;
@@ -1147,7 +1163,8 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
   FrameOutput frameOutput;
   auto& streamInfo = streamInfos_[activeStreamIndex_];
   frameOutput.ptsSeconds = ptsToSeconds(
-      avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
+      getPtsOrDts(avFrame),
+      formatContext_->streams[activeStreamIndex_]->time_base);
   frameOutput.durationSeconds = ptsToSeconds(
       getDuration(avFrame),
       formatContext_->streams[activeStreamIndex_]->time_base);
diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py
index 93c4448c..58c16366 100644
--- a/src/torchcodec/_core/_metadata.py
+++ b/src/torchcodec/_core/_metadata.py
@@ -142,6 +142,9 @@ def average_fps(self) -> Optional[float]:
             self.end_stream_seconds_from_content is None
             or self.begin_stream_seconds_from_content is None
             or self.num_frames is None
+            # Should never happen, but prevents ZeroDivisionError:
+            or self.end_stream_seconds_from_content
+            == self.begin_stream_seconds_from_content
         ):
             return self.average_fps_from_header
         return self.num_frames / (
diff --git a/test/test_decoders.py b/test/test_decoders.py
index ddd35ff3..133b3cb6 100644
--- a/test/test_decoders.py
+++ b/test/test_decoders.py
@@ -986,6 +986,47 @@ def get_some_frames(decoder):
         assert_frames_equal(ref_frame3, frames[1].data)
         assert_frames_equal(ref_frame5, frames[2].data)
 
+    @pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
+    def test_pts_to_dts_fallback(self, seek_mode):
+        # Non-regression test for
+        # https://github.com/pytorch/torchcodec/issues/677 and
+        # https://github.com/pytorch/torchcodec/issues/676.
+        # More accurately, this is a non-regression test for videos which do
+        # *not* specify pts values (all pts values are N/A and set to
+        # INT64_MIN), but specify *dts* values, which we fall back to.
+        #
+        # The test video we have is from
+        # https://huggingface.co/datasets/raushan-testing-hf/videos-test/blob/main/sample_video_2.avi
+        # We can't check it into the repo due to potential licensing issues, so
+        # we have to unconditionally skip this test.
+        # TODO: encode a video with no pts values to unskip this test. Couldn't
+        # find a way to do that with FFmpeg's CLI, but this should be doable
+        # once we have our own video encoder.
+        pytest.skip(reason="TODO: Need video with no pts values.")
+
+        path = "/home/nicolashug/Downloads/sample_video_2.avi"
+        decoder = VideoDecoder(path, seek_mode=seek_mode)
+        metadata = decoder.metadata
+
+        assert metadata.average_fps == pytest.approx(29.916667)
+        assert metadata.duration_seconds_from_header == 9.02507
+        assert metadata.duration_seconds == 9.02507
+        assert metadata.begin_stream_seconds_from_content == (
+            None if seek_mode == "approximate" else 0
+        )
+        assert metadata.end_stream_seconds_from_content == (
+            None if seek_mode == "approximate" else 9.02507
+        )
+
+        assert decoder[0].shape == (3, 240, 320)
+        assert decoder[10].shape == (3, 240, 320)
+        assert decoder.get_frame_at(2).data.shape == (3, 240, 320)
+        assert decoder.get_frames_at([2, 10]).data.shape == (2, 3, 240, 320)
+        assert decoder.get_frame_played_at(9).data.shape == (3, 240, 320)
+        assert decoder.get_frames_played_at([2, 4]).data.shape == (2, 3, 240, 320)
+        with pytest.raises(AssertionError, match="not equal"):
+            torch.testing.assert_close(decoder[0], decoder[10])
+
 
 class TestAudioDecoder:
     @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))
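Note on the skipped test above: before a replacement asset is checked in, it helps to confirm that the candidate file really leaves pts unset on every video packet while still carrying dts values, since that is exactly the condition the new getPtsOrDts() helpers handle. Below is a minimal sketch of such a check, assuming PyAV is available; it is not a torchcodec dependency, and has_dts_but_no_pts is a hypothetical helper, not part of this patch.

import av  # PyAV, assumed to be installed; not a torchcodec dependency


def has_dts_but_no_pts(path: str) -> bool:
    # Returns True when every video packet in `path` carries a dts but leaves
    # pts unset. PyAV exposes FFmpeg's AV_NOPTS_VALUE (INT64_MIN) as None, so
    # this mirrors the pts -> dts fallback implemented by getPtsOrDts().
    saw_packet = False
    with av.open(path) as container:
        for packet in container.demux(video=0):
            if packet.size == 0:  # empty flush packet emitted at EOF
                continue
            saw_packet = True
            if packet.pts is not None or packet.dts is None:
                return False
    return saw_packet


# Hypothetical usage against the asset referenced in the skipped test:
# assert has_dts_but_no_pts("/home/nicolashug/Downloads/sample_video_2.avi")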
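Similarly, the guard added to StreamMetadata.average_fps covers the degenerate case where the content-based begin and end timestamps coincide, which would make the denominator zero; the patch comment notes this should never happen, but it falls back to the header value defensively. Here is a standalone sketch of that decision, using a hypothetical free function rather than the actual property.

from typing import Optional


def average_fps_fallback(
    num_frames: Optional[int],
    begin_seconds: Optional[float],
    end_seconds: Optional[float],
    fps_from_header: Optional[float],
) -> Optional[float]:
    # Mirrors the guard in StreamMetadata.average_fps: use the header value
    # whenever the content-based duration is unknown or zero.
    if (
        num_frames is None
        or begin_seconds is None
        or end_seconds is None
        or end_seconds == begin_seconds
    ):
        return fps_from_header
    return num_frames / (end_seconds - begin_seconds)


assert average_fps_fallback(100, 0.0, 4.0, 30.0) == 25.0   # content-based fps
assert average_fps_fallback(100, 0.0, 0.0, 30.0) == 30.0   # zero duration -> header
assert average_fps_fallback(None, 0.0, 4.0, 30.0) == 30.0  # missing data -> header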