Skip to content

Commit 51e8a0c

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into contiguity
2 parents 49e4187 + 83f3e79 commit 51e8a0c

File tree

3 files changed

+40
-23
lines changed

3 files changed

+40
-23
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -878,8 +878,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
878878
while (!finished) {
879879
try {
880880
UniqueAVFrame avFrame =
881-
decodeAVFrame([startPts](const UniqueAVFrame& avFrame) {
882-
return startPts < avFrame->pts + getDuration(avFrame);
881+
decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) {
882+
return startPts < avFrame->pts + getDuration(avFrame) &&
883+
stopPts > avFrame->pts;
883884
});
884885
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
885886
if (!firstFramePtsSeconds.has_value()) {
@@ -908,9 +909,12 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
908909
TORCH_CHECK(
909910
frames.size() > 0 && firstFramePtsSeconds.has_value(),
910911
"No audio frames were decoded. ",
911-
"This is probably because start_seconds is too high? ",
912-
"Current value is ",
913-
startSeconds);
912+
"This is probably because start_seconds is too high(",
913+
startSeconds,
914+
"),",
915+
"or because stop_seconds(",
916+
stopSecondsOptional,
917+
") is too low.");
914918

915919
return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds};
916920
}

test/test_decoders.py

Lines changed: 31 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1112,7 +1112,8 @@ def test_start_equals_stop(self, asset):
11121112

11131113
def test_frame_start_is_not_zero(self):
11141114
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.138125.
1115-
# So if we request start = 0.05, we shouldn't be truncating anything.
1115+
# So if we request (start, stop) = (0.05, None), we shouldn't be
1116+
# truncating anything.
11161117

11171118
asset = NASA_AUDIO_MP3
11181119
start_seconds = 0.05 # this is less than the first frame's pts
@@ -1128,6 +1129,35 @@ def test_frame_start_is_not_zero(self):
11281129
reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index)
11291130
torch.testing.assert_close(samples.data, reference_frames)
11301131

1132+
# Non-regression test for https://github.com/pytorch/torchcodec/issues/567
1133+
# If we ask for start < stop <= first_frame_pts, we should raise.
1134+
with pytest.raises(RuntimeError, match="No audio frames were decoded"):
1135+
decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=0.05)
1136+
1137+
first_frame_pts_seconds = asset.get_frame_info(idx=0).pts_seconds
1138+
with pytest.raises(RuntimeError, match="No audio frames were decoded"):
1139+
decoder.get_samples_played_in_range(
1140+
start_seconds=0, stop_seconds=first_frame_pts_seconds
1141+
)
1142+
1143+
# Documenting an edge case: we ask for samples barely beyond the start
1144+
# of the first frame. The C++ decoder returns the first frame, which
1145+
# gets (correctly!) truncated by the AudioDecoder, and we end up with
1146+
# empty data.
1147+
samples = decoder.get_samples_played_in_range(
1148+
start_seconds=0, stop_seconds=first_frame_pts_seconds + 1e-5
1149+
)
1150+
assert samples.data.shape == (2, 0)
1151+
assert samples.pts_seconds == first_frame_pts_seconds
1152+
assert samples.duration_seconds == 0
1153+
1154+
# If we ask for slightly more samples, we get non-empty data.
1155+
samples = decoder.get_samples_played_in_range(
1156+
start_seconds=0, stop_seconds=first_frame_pts_seconds + 1e-3
1157+
)
1158+
assert samples.data.shape == (2, 8)
1159+
assert samples.pts_seconds == first_frame_pts_seconds
1160+
11311161
def test_single_channel(self):
11321162
asset = SINE_MONO_S32
11331163
decoder = AudioDecoder(asset.path)

test/test_ops.py

Lines changed: 0 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -884,23 +884,6 @@ def test_pts(self, asset):
884884

885885
assert pts_seconds == start_seconds
886886

887-
def test_decode_before_frame_start(self):
888-
# Test illustrating bug described in
889-
# https://github.com/pytorch/torchcodec/issues/567
890-
asset = NASA_AUDIO_MP3
891-
892-
decoder = create_from_file(str(asset.path), seek_mode="approximate")
893-
add_audio_stream(decoder)
894-
895-
frames, *_ = get_frames_by_pts_in_range_audio(
896-
decoder, start_seconds=0, stop_seconds=0.05
897-
)
898-
all_frames, *_ = get_frames_by_pts_in_range_audio(
899-
decoder, start_seconds=0, stop_seconds=None
900-
)
901-
# TODO fix this. `frames` should be empty.
902-
torch.testing.assert_close(frames, all_frames)
903-
904887
def test_sample_rate_conversion(self):
905888
def get_all_frames(asset, sample_rate=None, stop_seconds=None):
906889
decoder = create_from_file(str(asset.path), seek_mode="approximate")

0 commit comments

Comments
 (0)