Skip to content

AudioDecoder: Fix output when stop < start of first frame #665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -877,8 +877,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
while (!finished) {
try {
UniqueAVFrame avFrame =
decodeAVFrame([startPts](const UniqueAVFrame& avFrame) {
return startPts < avFrame->pts + getDuration(avFrame);
decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) {
return startPts < avFrame->pts + getDuration(avFrame) &&
stopPts > avFrame->pts;
});
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
if (!firstFramePtsSeconds.has_value()) {
Expand Down Expand Up @@ -907,9 +908,12 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
TORCH_CHECK(
frames.size() > 0 && firstFramePtsSeconds.has_value(),
"No audio frames were decoded. ",
"This is probably because start_seconds is too high? ",
"Current value is ",
startSeconds);
"This is probably because start_seconds is too high(",
startSeconds,
"),",
"or because stop_seconds(",
stopSecondsOptional,
") is too low.");

return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds};
}
Expand Down
32 changes: 31 additions & 1 deletion test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,8 @@ def test_start_equals_stop(self, asset):

def test_frame_start_is_not_zero(self):
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.138125.
# So if we request start = 0.05, we shouldn't be truncating anything.
# So if we request (start, stop) = (0.05, None), we shouldn't be
# truncating anything.

asset = NASA_AUDIO_MP3
start_seconds = 0.05 # this is less than the first frame's pts
Expand All @@ -1128,6 +1129,35 @@ def test_frame_start_is_not_zero(self):
reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index)
torch.testing.assert_close(samples.data, reference_frames)

# Non-regression test for https://github.com/pytorch/torchcodec/issues/567
# If we ask for start < stop <= first_frame_pts, we should raise.
with pytest.raises(RuntimeError, match="No audio frames were decoded"):
decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=0.05)

first_frame_pts_seconds = asset.get_frame_info(idx=0).pts_seconds
with pytest.raises(RuntimeError, match="No audio frames were decoded"):
decoder.get_samples_played_in_range(
start_seconds=0, stop_seconds=first_frame_pts_seconds
)

# Documenting an edge case: we ask for samples barely beyond the start
# of the first frame. The C++ decoder returns the first frame, which
# gets (correctly!) truncated by the AudioDecoder, and we end up with
# empty data.
samples = decoder.get_samples_played_in_range(
start_seconds=0, stop_seconds=first_frame_pts_seconds + 1e-5
)
assert samples.data.shape == (2, 0)
assert samples.pts_seconds == first_frame_pts_seconds
assert samples.duration_seconds == 0

# If we ask for slightly more samples, we get non-empty data.
samples = decoder.get_samples_played_in_range(
start_seconds=0, stop_seconds=first_frame_pts_seconds + 1e-3
)
assert samples.data.shape == (2, 8)
assert samples.pts_seconds == first_frame_pts_seconds

def test_single_channel(self):
asset = SINE_MONO_S32
decoder = AudioDecoder(asset.path)
Expand Down
17 changes: 0 additions & 17 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,23 +884,6 @@ def test_pts(self, asset):

assert pts_seconds == start_seconds

def test_decode_before_frame_start(self):
# Test illustrating bug described in
# https://github.com/pytorch/torchcodec/issues/567
asset = NASA_AUDIO_MP3

decoder = create_from_file(str(asset.path), seek_mode="approximate")
add_audio_stream(decoder)

frames, *_ = get_frames_by_pts_in_range_audio(
decoder, start_seconds=0, stop_seconds=0.05
)
all_frames, *_ = get_frames_by_pts_in_range_audio(
decoder, start_seconds=0, stop_seconds=None
)
# TODO fix this. `frames` should be empty.
torch.testing.assert_close(frames, all_frames)

def test_sample_rate_conversion(self):
def get_all_frames(asset, sample_rate=None, stop_seconds=None):
decoder = create_from_file(str(asset.path), seek_mode="approximate")
Expand Down
Loading