Skip to content

Commit a3867d1

Browse files
authored
Fallback to DTS if PTS info doesn't exist (#683)
1 parent 5013118 commit a3867d1

File tree

3 files changed

+72
-11
lines changed

3 files changed

+72
-11
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

+28-11
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
2626
std::round(seconds * timeBase.den / timeBase.num));
2727
}
2828

29+
// Some videos aren't properly encoded and do not specify pts values for
30+
// packets, and thus for frames. Unset values correspond to INT64_MIN. When that
31+
// happens, we fall back to the dts value which hopefully exists and is correct.
32+
// Accessing AVFrames' and AVPackets' pts values should **always** go through
33+
// the helpers below. Then, the "pts" fields in our structs like FrameInfo.pts
34+
// should be interpreted as "pts if it exists, dts otherwise".
35+
// Returns the packet's pts, unless it equals INT64_MIN, in which case the
// packet's dts is returned instead.
int64_t getPtsOrDts(ReferenceAVPacket& packet) {
  if (packet->pts != INT64_MIN) {
    return packet->pts;
  }
  return packet->dts;
}
38+
39+
int64_t getPtsOrDts(const UniqueAVFrame& avFrame) {
40+
return avFrame->pts == INT64_MIN ? avFrame->pkt_dts : avFrame->pts;
41+
}
42+
2943
} // namespace
3044

3145
// --------------------------------------------------------------------------
@@ -223,16 +237,16 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
223237
int streamIndex = packet->stream_index;
224238
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
225239
streamMetadata.minPtsFromScan = std::min(
226-
streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
240+
streamMetadata.minPtsFromScan.value_or(INT64_MAX), getPtsOrDts(packet));
227241
streamMetadata.maxPtsFromScan = std::max(
228242
streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
229-
packet->pts + packet->duration);
243+
getPtsOrDts(packet) + packet->duration);
230244
streamMetadata.numFramesFromScan =
231245
streamMetadata.numFramesFromScan.value_or(0) + 1;
232246

233247
// Note that we set the other value in this struct, nextPts, only after
234248
// we have scanned all packets and sorted by pts.
235-
FrameInfo frameInfo = {packet->pts};
249+
FrameInfo frameInfo = {getPtsOrDts(packet)};
236250
if (packet->flags & AV_PKT_FLAG_KEY) {
237251
frameInfo.isKeyFrame = true;
238252
streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
@@ -493,8 +507,9 @@ FrameOutput SingleStreamDecoder::getNextFrame() {
493507
FrameOutput SingleStreamDecoder::getNextFrameInternal(
494508
std::optional<torch::Tensor> preAllocatedOutputTensor) {
495509
validateActiveStream();
496-
UniqueAVFrame avFrame = decodeAVFrame(
497-
[this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
510+
UniqueAVFrame avFrame = decodeAVFrame([this](const UniqueAVFrame& avFrame) {
511+
return getPtsOrDts(avFrame) >= cursor_;
512+
});
498513
return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
499514
}
500515

@@ -630,9 +645,10 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
630645
UniqueAVFrame avFrame =
631646
decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
632647
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
633-
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
648+
double frameStartTime =
649+
ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase);
634650
double frameEndTime = ptsToSeconds(
635-
avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
651+
getPtsOrDts(avFrame) + getDuration(avFrame), streamInfo.timeBase);
636652
if (frameStartTime > seconds) {
637653
// FFMPEG seeked past the frame we are looking for even though we
638654
// set max_ts to be our needed timestamp in avformat_seek_file()
@@ -859,8 +875,8 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
859875
try {
860876
UniqueAVFrame avFrame =
861877
decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) {
862-
return startPts < avFrame->pts + getDuration(avFrame) &&
863-
stopPts > avFrame->pts;
878+
return startPts < getPtsOrDts(avFrame) + getDuration(avFrame) &&
879+
stopPts > getPtsOrDts(avFrame);
864880
});
865881
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
866882
if (!firstFramePtsSeconds.has_value()) {
@@ -1130,7 +1146,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11301146
// haven't received as frames. Eventually we will either hit AVERROR_EOF from
11311147
// av_receive_frame() or the user will have seeked to a different location in
11321148
// the file and that will flush the decoder.
1133-
streamInfo.lastDecodedAvFramePts = avFrame->pts;
1149+
streamInfo.lastDecodedAvFramePts = getPtsOrDts(avFrame);
11341150
streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
11351151

11361152
return avFrame;
@@ -1147,7 +1163,8 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
11471163
FrameOutput frameOutput;
11481164
auto& streamInfo = streamInfos_[activeStreamIndex_];
11491165
frameOutput.ptsSeconds = ptsToSeconds(
1150-
avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
1166+
getPtsOrDts(avFrame),
1167+
formatContext_->streams[activeStreamIndex_]->time_base);
11511168
frameOutput.durationSeconds = ptsToSeconds(
11521169
getDuration(avFrame),
11531170
formatContext_->streams[activeStreamIndex_]->time_base);

src/torchcodec/_core/_metadata.py

+3
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ def average_fps(self) -> Optional[float]:
142142
self.end_stream_seconds_from_content is None
143143
or self.begin_stream_seconds_from_content is None
144144
or self.num_frames is None
145+
# Should never happen, but prevents ZeroDivisionError:
146+
or self.end_stream_seconds_from_content
147+
== self.begin_stream_seconds_from_content
145148
):
146149
return self.average_fps_from_header
147150
return self.num_frames / (

test/test_decoders.py

+41
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,47 @@ def get_some_frames(decoder):
986986
assert_frames_equal(ref_frame3, frames[1].data)
987987
assert_frames_equal(ref_frame5, frames[2].data)
988988

989+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_pts_to_dts_fallback(self, seek_mode):
    # Non-regression test for
    # https://github.com/pytorch/torchcodec/issues/677 and
    # https://github.com/pytorch/torchcodec/issues/676.
    # More accurately, this is a non-regression test for videos which do
    # *not* specify pts values (all pts values are N/A and set to
    # INT64_MIN), but specify *dts* values - which we fall back to.
    #
    # The test video we have is from
    # https://huggingface.co/datasets/raushan-testing-hf/videos-test/blob/main/sample_video_2.avi
    # We can't check it into the repo due to potential licensing issues, so
    # we have to unconditionally skip this test.
    # TODO: encode a video with no pts values to unskip this test. Couldn't
    # find a way to do that with FFmpeg's CLI, but this should be doable
    # once we have our own video encoder.
    pytest.skip(reason="TODO: Need video with no pts values.")

    path = "/home/nicolashug/Downloads/sample_video_2.avi"
    decoder = VideoDecoder(path, seek_mode=seek_mode)
    metadata = decoder.metadata

    assert metadata.average_fps == pytest.approx(29.916667)
    assert metadata.duration_seconds_from_header == 9.02507
    assert metadata.duration_seconds == 9.02507
    # Content-derived stream bounds are only expected in "exact" seek mode;
    # in "approximate" mode they are expected to be None.
    assert metadata.begin_stream_seconds_from_content == (
        None if seek_mode == "approximate" else 0
    )
    assert metadata.end_stream_seconds_from_content == (
        None if seek_mode == "approximate" else 9.02507
    )

    # Every shape comparison must be asserted: a bare `a == b` expression is
    # evaluated and discarded, so it could never make the test fail.
    assert decoder[0].shape == (3, 240, 320)
    assert decoder[10].shape == (3, 240, 320)
    assert decoder.get_frame_at(2).data.shape == (3, 240, 320)
    assert decoder.get_frames_at([2, 10]).data.shape == (2, 3, 240, 320)
    assert decoder.get_frame_played_at(9).data.shape == (3, 240, 320)
    assert decoder.get_frames_played_at([2, 4]).data.shape == (2, 3, 240, 320)
    # Frames 0 and 10 must differ; identical frames would suggest that the
    # decoder is not actually advancing through the stream.
    with pytest.raises(AssertionError, match="not equal"):
        torch.testing.assert_close(decoder[0], decoder[10])
1029+
9891030

9901031
class TestAudioDecoder:
9911032
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))

0 commit comments

Comments
 (0)