Skip to content

Fallback to DTS if PTS info doesn't exist #683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
std::round(seconds * timeBase.den / timeBase.num));
}

// Some videos aren't properly encoded and do not specify pts values for
// packets, and thus for frames. Unset values correspond to INT64_MIN. When that
// happens, we fall back to the dts value which hopefully exists and is correct.
// Accessing AVFrames' and AVPackets' pts values should **always** go through
// the helpers below. Then, the "pts" fields in our structs like FrameInfo.pts
// should be interpreted as "pts if it exists, dts otherwise".
// Returns the packet's pts, or its dts when the pts is unset (INT64_MIN).
// NOTE(review): FFmpeg names this sentinel AV_NOPTS_VALUE; consider comparing
// against that constant instead of the raw INT64_MIN literal.
int64_t getPtsOrDts(ReferenceAVPacket& packet) {
  const bool ptsIsUnset = (packet->pts == INT64_MIN);
  if (ptsIsUnset) {
    return packet->dts;
  }
  return packet->pts;
}

// Frame overload: yields avFrame->pts unless it is unset (INT64_MIN), in which
// case avFrame->pkt_dts is returned instead.
int64_t getPtsOrDts(const UniqueAVFrame& avFrame) {
  if (avFrame->pts != INT64_MIN) {
    return avFrame->pts;
  }
  return avFrame->pkt_dts;
}

} // namespace

// --------------------------------------------------------------------------
Expand Down Expand Up @@ -223,16 +237,16 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
int streamIndex = packet->stream_index;
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
streamMetadata.minPtsFromScan = std::min(
streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
streamMetadata.minPtsFromScan.value_or(INT64_MAX), getPtsOrDts(packet));
streamMetadata.maxPtsFromScan = std::max(
streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
packet->pts + packet->duration);
getPtsOrDts(packet) + packet->duration);
streamMetadata.numFramesFromScan =
streamMetadata.numFramesFromScan.value_or(0) + 1;

// Note that we set the other value in this struct, nextPts, only after
// we have scanned all packets and sorted by pts.
FrameInfo frameInfo = {packet->pts};
FrameInfo frameInfo = {getPtsOrDts(packet)};
if (packet->flags & AV_PKT_FLAG_KEY) {
frameInfo.isKeyFrame = true;
streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
Expand Down Expand Up @@ -493,8 +507,9 @@ FrameOutput SingleStreamDecoder::getNextFrame() {
FrameOutput SingleStreamDecoder::getNextFrameInternal(
std::optional<torch::Tensor> preAllocatedOutputTensor) {
validateActiveStream();
UniqueAVFrame avFrame = decodeAVFrame(
[this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
UniqueAVFrame avFrame = decodeAVFrame([this](const UniqueAVFrame& avFrame) {
return getPtsOrDts(avFrame) >= cursor_;
});
return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
}

Expand Down Expand Up @@ -630,9 +645,10 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
UniqueAVFrame avFrame =
decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
double frameStartTime =
ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase);
double frameEndTime = ptsToSeconds(
avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
getPtsOrDts(avFrame) + getDuration(avFrame), streamInfo.timeBase);
if (frameStartTime > seconds) {
// FFMPEG seeked past the frame we are looking for even though we
// set max_ts to be our needed timestamp in avformat_seek_file()
Expand Down Expand Up @@ -859,8 +875,8 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
try {
UniqueAVFrame avFrame =
decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) {
return startPts < avFrame->pts + getDuration(avFrame) &&
stopPts > avFrame->pts;
return startPts < getPtsOrDts(avFrame) + getDuration(avFrame) &&
stopPts > getPtsOrDts(avFrame);
});
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
if (!firstFramePtsSeconds.has_value()) {
Expand Down Expand Up @@ -1130,7 +1146,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
// haven't received as frames. Eventually we will either hit AVERROR_EOF from
// av_receive_frame() or the user will have seeked to a different location in
// the file and that will flush the decoder.
streamInfo.lastDecodedAvFramePts = avFrame->pts;
streamInfo.lastDecodedAvFramePts = getPtsOrDts(avFrame);
streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);

return avFrame;
Expand All @@ -1147,7 +1163,8 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
FrameOutput frameOutput;
auto& streamInfo = streamInfos_[activeStreamIndex_];
frameOutput.ptsSeconds = ptsToSeconds(
avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
getPtsOrDts(avFrame),
formatContext_->streams[activeStreamIndex_]->time_base);
frameOutput.durationSeconds = ptsToSeconds(
getDuration(avFrame),
formatContext_->streams[activeStreamIndex_]->time_base);
Expand Down
3 changes: 3 additions & 0 deletions src/torchcodec/_core/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def average_fps(self) -> Optional[float]:
self.end_stream_seconds_from_content is None
or self.begin_stream_seconds_from_content is None
or self.num_frames is None
# Should never happen, but prevents ZeroDivisionError:
or self.end_stream_seconds_from_content
== self.begin_stream_seconds_from_content
Comment on lines +145 to +147
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't actually needed anymore, i.e. the dts fallback already fixed the metadata issue. But it's probably safer to have this regardless?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, let's keep it. I think a single frame video is allowed, and we would encounter this problem with such a video.

):
return self.average_fps_from_header
return self.num_frames / (
Expand Down
41 changes: 41 additions & 0 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,47 @@ def get_some_frames(decoder):
assert_frames_equal(ref_frame3, frames[1].data)
assert_frames_equal(ref_frame5, frames[2].data)

@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_pts_to_dts_fallback(self, seek_mode):
# Non-regression test for
# https://github.com/pytorch/torchcodec/issues/677 and
# https://github.com/pytorch/torchcodec/issues/676.
# More accurately, this is a non-regression test for videos which do
# *not* specify pts values (all pts values are N/A and set to
# INT64_MIN), but specify *dts* values - which we fall back to.
#
# The test video we have is from
# https://huggingface.co/datasets/raushan-testing-hf/videos-test/blob/main/sample_video_2.avi
# We can't check it into the repo due to potential licensing issues, so
# we have to unconditionally skip this test.
# TODO: encode a video with no pts values to unskip this test. Couldn't
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Create an issue?

# find a way to do that with FFmpeg's CLI, but this should be doable
# once we have our own video encoder.
pytest.skip(reason="TODO: Need video with no pts values.")

path = "/home/nicolashug/Downloads/sample_video_2.avi"
decoder = VideoDecoder(path, seek_mode=seek_mode)
metadata = decoder.metadata

# Metadata must be fully populated even though the file carries no pts
# values. The *_from_content fields are None in approximate mode and set in
# exact mode.
assert metadata.average_fps == pytest.approx(29.916667)
assert metadata.duration_seconds_from_header == 9.02507
assert metadata.duration_seconds == 9.02507
assert metadata.begin_stream_seconds_from_content == (
    None if seek_mode == "approximate" else 0
)
assert metadata.end_stream_seconds_from_content == (
    None if seek_mode == "approximate" else 9.02507
)

# Every decoding entry point must return frames of the expected shape.
# Each check needs an explicit `assert`: a bare `==` expression is evaluated
# and discarded, so it could never fail the test.
assert decoder[0].shape == (3, 240, 320)
assert decoder[10].shape == (3, 240, 320)
assert decoder.get_frame_at(2).data.shape == (3, 240, 320)
assert decoder.get_frames_at([2, 10]).data.shape == (2, 3, 240, 320)
assert decoder.get_frame_played_at(9).data.shape == (3, 240, 320)
assert decoder.get_frames_played_at([2, 4]).data.shape == (2, 3, 240, 320)
# Frames at different indices must differ, i.e. the decoder is not returning
# the same frame for every index.
with pytest.raises(AssertionError, match="not equal"):
    torch.testing.assert_close(decoder[0], decoder[10])


class TestAudioDecoder:
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))
Expand Down
Loading