Skip to content

Commit 99d21db

Browse files
authored
Audio: Fix output shape when start==stop (#664)
1 parent 51d1135 commit 99d21db

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -848,13 +848,14 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
848848
std::to_string(*stopSecondsOptional) + ").");
849849
}
850850

851+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
852+
851853
if (stopSecondsOptional.has_value() && startSeconds == *stopSecondsOptional) {
852854
// For consistency with video
853-
return AudioFramesOutput{torch::empty({0, 0}), 0.0};
855+
int numChannels = getNumChannels(streamInfo.codecContext);
856+
return AudioFramesOutput{torch::empty({numChannels, 0}), 0.0};
854857
}
855858

856-
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
857-
858859
auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
859860
if (startPts < streamInfo.lastDecodedAvFramePts +
860861
streamInfo.lastDecodedAvFrameDuration) {

test/test_decoders.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1108,7 +1108,7 @@ def test_not_at_frame_boundaries(self, asset):
11081108
def test_start_equals_stop(self, asset):
11091109
decoder = AudioDecoder(asset.path)
11101110
samples = decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=3)
1111-
assert samples.data.shape == (0, 0)
1111+
assert samples.data.shape == (asset.num_channels, 0)
11121112

11131113
def test_frame_start_is_not_zero(self):
11141114
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.138125.

test/test_ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,7 @@ def test_decode_start_equal_stop(self, asset):
782782
frames, pts_seconds = get_frames_by_pts_in_range_audio(
783783
decoder, start_seconds=1, stop_seconds=1
784784
)
785-
assert frames.shape == (0, 0)
785+
assert frames.shape == (asset.num_channels, 0)
786786
assert pts_seconds == 0
787787

788788
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))

0 commit comments

Comments
 (0)