Skip to content

Commit 721c315

Browse files
authored
AudioDecoder: specify desired num_channels (#678)
1 parent 0cced7c commit 721c315

File tree

9 files changed

+173
-40
lines changed

9 files changed

+173
-40
lines changed

src/torchcodec/_core/Encoder.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -293,18 +293,20 @@ void AudioEncoder::encodeInnerLoop(
293293
if (mustConvert) {
294294
if (!swrContext_) {
295295
swrContext_.reset(createSwrContext(
296-
avCodecContext_,
297296
AV_SAMPLE_FMT_FLTP,
298297
avCodecContext_->sample_fmt,
299298
srcAVFrame->sample_rate, // No sample rate conversion
300-
srcAVFrame->sample_rate));
299+
srcAVFrame->sample_rate,
300+
srcAVFrame,
301+
getNumChannels(srcAVFrame) // No num_channel conversion
302+
));
301303
}
302-
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
304+
convertedAVFrame = convertAudioAVFrameSamples(
303305
swrContext_,
304306
srcAVFrame,
305307
avCodecContext_->sample_fmt,
306308
srcAVFrame->sample_rate, // No sample rate conversion
307-
srcAVFrame->sample_rate);
309+
getNumChannels(srcAVFrame)); // No num_channel conversion
308310
TORCH_CHECK(
309311
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
310312
"convertedAVFrame->nb_samples=",

src/torchcodec/_core/FFMPEGCommon.cpp

+66-16
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ void setDefaultChannelLayout(
8181
AVChannelLayout channel_layout;
8282
av_channel_layout_default(&channel_layout, numChannels);
8383
avCodecContext->ch_layout = channel_layout;
84-
8584
#else
8685
uint64_t channel_layout = av_get_default_channel_layout(numChannels);
8786
avCodecContext->channel_layout = channel_layout;
@@ -106,32 +105,79 @@ void setChannelLayout(
106105
#endif
107106
}
108107

108+
namespace {
109+
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
110+
111+
// Returns:
112+
// - the srcAVFrame's channel layout if srcAVFrame has desiredNumChannels
113+
// - the default channel layout with desiredNumChannels otherwise.
114+
AVChannelLayout getDesiredChannelLayout(
115+
int desiredNumChannels,
116+
const UniqueAVFrame& srcAVFrame) {
117+
AVChannelLayout desiredLayout;
118+
if (desiredNumChannels == getNumChannels(srcAVFrame)) {
119+
desiredLayout = srcAVFrame->ch_layout;
120+
} else {
121+
av_channel_layout_default(&desiredLayout, desiredNumChannels);
122+
}
123+
return desiredLayout;
124+
}
125+
126+
#else
127+
128+
// Same as above
129+
int64_t getDesiredChannelLayout(
130+
int desiredNumChannels,
131+
const UniqueAVFrame& srcAVFrame) {
132+
int64_t desiredLayout;
133+
if (desiredNumChannels == getNumChannels(srcAVFrame)) {
134+
desiredLayout = srcAVFrame->channel_layout;
135+
} else {
136+
desiredLayout = av_get_default_channel_layout(desiredNumChannels);
137+
}
138+
return desiredLayout;
139+
}
140+
#endif
141+
} // namespace
142+
143+
// Sets dstAVFrame' channel layout to getDesiredChannelLayout(): see doc above
109144
void setChannelLayout(
110145
UniqueAVFrame& dstAVFrame,
111-
const UniqueAVFrame& srcAVFrame) {
146+
const UniqueAVFrame& srcAVFrame,
147+
int desiredNumChannels) {
112148
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
113-
dstAVFrame->ch_layout = srcAVFrame->ch_layout;
149+
AVChannelLayout desiredLayout =
150+
getDesiredChannelLayout(desiredNumChannels, srcAVFrame);
151+
auto status = av_channel_layout_copy(&dstAVFrame->ch_layout, &desiredLayout);
152+
TORCH_CHECK(
153+
status == AVSUCCESS,
154+
"Couldn't copy channel layout to avFrame: ",
155+
getFFMPEGErrorStringFromErrorCode(status));
114156
#else
115-
dstAVFrame->channel_layout = srcAVFrame->channel_layout;
157+
dstAVFrame->channel_layout =
158+
getDesiredChannelLayout(desiredNumChannels, srcAVFrame);
159+
dstAVFrame->channels = desiredNumChannels;
116160
#endif
117161
}
118162

119163
SwrContext* createSwrContext(
120-
UniqueAVCodecContext& avCodecContext,
121164
AVSampleFormat sourceSampleFormat,
122165
AVSampleFormat desiredSampleFormat,
123166
int sourceSampleRate,
124-
int desiredSampleRate) {
167+
int desiredSampleRate,
168+
const UniqueAVFrame& srcAVFrame,
169+
int desiredNumChannels) {
125170
SwrContext* swrContext = nullptr;
126171
int status = AVSUCCESS;
127172
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
128-
AVChannelLayout layout = avCodecContext->ch_layout;
173+
AVChannelLayout desiredLayout =
174+
getDesiredChannelLayout(desiredNumChannels, srcAVFrame);
129175
status = swr_alloc_set_opts2(
130176
&swrContext,
131-
&layout,
177+
&desiredLayout,
132178
desiredSampleFormat,
133179
desiredSampleRate,
134-
&layout,
180+
&srcAVFrame->ch_layout,
135181
sourceSampleFormat,
136182
sourceSampleRate,
137183
0,
@@ -142,13 +188,14 @@ SwrContext* createSwrContext(
142188
"Couldn't create SwrContext: ",
143189
getFFMPEGErrorStringFromErrorCode(status));
144190
#else
145-
int64_t layout = static_cast<int64_t>(avCodecContext->channel_layout);
191+
int64_t desiredLayout =
192+
getDesiredChannelLayout(desiredNumChannels, srcAVFrame);
146193
swrContext = swr_alloc_set_opts(
147194
nullptr,
148-
layout,
195+
desiredLayout,
149196
desiredSampleFormat,
150197
desiredSampleRate,
151-
layout,
198+
srcAVFrame->channel_layout,
152199
sourceSampleFormat,
153200
sourceSampleRate,
154201
0,
@@ -167,20 +214,21 @@ SwrContext* createSwrContext(
167214
return swrContext;
168215
}
169216

170-
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
217+
UniqueAVFrame convertAudioAVFrameSamples(
171218
const UniqueSwrContext& swrContext,
172219
const UniqueAVFrame& srcAVFrame,
173220
AVSampleFormat desiredSampleFormat,
174-
int sourceSampleRate,
175-
int desiredSampleRate) {
221+
int desiredSampleRate,
222+
int desiredNumChannels) {
176223
UniqueAVFrame convertedAVFrame(av_frame_alloc());
177224
TORCH_CHECK(
178225
convertedAVFrame,
179226
"Could not allocate frame for sample format conversion.");
180227

181-
setChannelLayout(convertedAVFrame, srcAVFrame);
182228
convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
229+
183230
convertedAVFrame->sample_rate = desiredSampleRate;
231+
int sourceSampleRate = srcAVFrame->sample_rate;
184232
if (sourceSampleRate != desiredSampleRate) {
185233
// Note that this is an upper bound on the number of output samples.
186234
// `swr_convert()` will likely not fill convertedAVFrame with that many
@@ -200,6 +248,8 @@ UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
200248
convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
201249
}
202250

251+
setChannelLayout(convertedAVFrame, srcAVFrame, desiredNumChannels);
252+
203253
auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
204254
TORCH_CHECK(
205255
status == AVSUCCESS,

src/torchcodec/_core/FFMPEGCommon.h

+15-7
Original file line numberDiff line numberDiff line change
@@ -157,20 +157,28 @@ void setChannelLayout(
157157

158158
void setChannelLayout(
159159
UniqueAVFrame& dstAVFrame,
160-
const UniqueAVFrame& srcAVFrame);
160+
const UniqueAVFrame& srcAVFrame,
161+
int desiredNumChannels);
162+
161163
SwrContext* createSwrContext(
162-
UniqueAVCodecContext& avCodecContext,
163164
AVSampleFormat sourceSampleFormat,
164165
AVSampleFormat desiredSampleFormat,
165166
int sourceSampleRate,
166-
int desiredSampleRate);
167-
168-
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
167+
int desiredSampleRate,
168+
const UniqueAVFrame& srcAVFrame,
169+
int desiredNumChannels);
170+
171+
// Converts, if needed:
172+
// - sample format
173+
// - sample rate
174+
// - number of channels.
175+
// createSwrContext must have been previously called with matching parameters.
176+
UniqueAVFrame convertAudioAVFrameSamples(
169177
const UniqueSwrContext& swrContext,
170178
const UniqueAVFrame& srcAVFrame,
171179
AVSampleFormat desiredSampleFormat,
172-
int sourceSampleRate,
173-
int desiredSampleRate);
180+
int desiredSampleRate,
181+
int desiredNumChannels);
174182

175183
// Returns true if sws_scale can handle unaligned data.
176184
bool canSwsScaleHandleUnalignedData();

src/torchcodec/_core/SingleStreamDecoder.cpp

+40-8
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,13 @@ void SingleStreamDecoder::addAudioStream(
453453
TORCH_CHECK(
454454
seekMode_ == SeekMode::approximate,
455455
"seek_mode must be 'approximate' for audio streams.");
456+
if (audioStreamOptions.numChannels.has_value()) {
457+
TORCH_CHECK(
458+
*audioStreamOptions.numChannels > 0 &&
459+
*audioStreamOptions.numChannels <= AV_NUM_DATA_POINTERS,
460+
"num_channels must be > 0 and <= AV_NUM_DATA_POINTERS (usually 8). Got: ",
461+
*audioStreamOptions.numChannels);
462+
}
456463

457464
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
458465

@@ -1171,27 +1178,42 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
11711178
int desiredSampleRate =
11721179
streamInfo.audioStreamOptions.sampleRate.value_or(sourceSampleRate);
11731180

1181+
int sourceNumChannels = getNumChannels(streamInfo.codecContext);
1182+
TORCH_CHECK(
1183+
sourceNumChannels == getNumChannels(srcAVFrame),
1184+
"The frame has ",
1185+
getNumChannels(srcAVFrame),
1186+
" channels, expected ",
1187+
sourceNumChannels,
1188+
". If you are hitting this, it may be because you are using "
1189+
"a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
1190+
"valid scenarios. Try to upgrade FFmpeg?");
1191+
int desiredNumChannels =
1192+
streamInfo.audioStreamOptions.numChannels.value_or(sourceNumChannels);
1193+
11741194
bool mustConvert =
11751195
(sourceSampleFormat != desiredSampleFormat ||
1176-
sourceSampleRate != desiredSampleRate);
1196+
sourceSampleRate != desiredSampleRate ||
1197+
sourceNumChannels != desiredNumChannels);
11771198

11781199
UniqueAVFrame convertedAVFrame;
11791200
if (mustConvert) {
11801201
if (!streamInfo.swrContext) {
11811202
streamInfo.swrContext.reset(createSwrContext(
1182-
streamInfo.codecContext,
11831203
sourceSampleFormat,
11841204
desiredSampleFormat,
11851205
sourceSampleRate,
1186-
desiredSampleRate));
1206+
desiredSampleRate,
1207+
srcAVFrame,
1208+
desiredNumChannels));
11871209
}
11881210

1189-
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
1211+
convertedAVFrame = convertAudioAVFrameSamples(
11901212
streamInfo.swrContext,
11911213
srcAVFrame,
11921214
desiredSampleFormat,
1193-
sourceSampleRate,
1194-
desiredSampleRate);
1215+
desiredSampleRate,
1216+
desiredNumChannels);
11951217
}
11961218
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
11971219

@@ -1204,8 +1226,17 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
12041226
"source format = ",
12051227
av_get_sample_fmt_name(format));
12061228

1229+
int numChannels = getNumChannels(avFrame);
1230+
TORCH_CHECK(
1231+
numChannels == desiredNumChannels,
1232+
"Something went wrong, the frame didn't get converted to the desired ",
1233+
"number of channels = ",
1234+
desiredNumChannels,
1235+
". Got ",
1236+
numChannels,
1237+
" instead.");
1238+
12071239
auto numSamples = avFrame->nb_samples; // per channel
1208-
auto numChannels = getNumChannels(avFrame);
12091240

12101241
frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32);
12111242

@@ -1240,7 +1271,8 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
12401271
return std::nullopt;
12411272
}
12421273

1243-
auto numChannels = getNumChannels(streamInfo.codecContext);
1274+
int numChannels = streamInfo.audioStreamOptions.numChannels.value_or(
1275+
getNumChannels(streamInfo.codecContext));
12441276
torch::Tensor lastSamples =
12451277
torch::empty({numChannels, numRemainingSamples}, torch::kFloat32);
12461278

src/torchcodec/_core/StreamOptions.h

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ struct AudioStreamOptions {
4444
AudioStreamOptions() {}
4545

4646
std::optional<int> sampleRate;
47+
std::optional<int> numChannels;
4748
};
4849

4950
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
m.def(
4141
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()");
4242
m.def(
43-
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None) -> ()");
43+
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()");
4444
m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
4545
m.def("get_next_frame(Tensor(a!) decoder) -> (Tensor, Tensor, Tensor)");
4646
m.def(
@@ -280,9 +280,11 @@ void add_video_stream(
280280
void add_audio_stream(
281281
at::Tensor& decoder,
282282
std::optional<int64_t> stream_index = std::nullopt,
283-
std::optional<int64_t> sample_rate = std::nullopt) {
283+
std::optional<int64_t> sample_rate = std::nullopt,
284+
std::optional<int64_t> num_channels = std::nullopt) {
284285
AudioStreamOptions audioStreamOptions;
285286
audioStreamOptions.sampleRate = sample_rate;
287+
audioStreamOptions.numChannels = num_channels;
286288

287289
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
288290
videoDecoder->addAudioStream(stream_index.value_or(-1), audioStreamOptions);

src/torchcodec/_core/ops.py

+2
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ def add_audio_stream_abstract(
221221
decoder: torch.Tensor,
222222
*,
223223
stream_index: Optional[int] = None,
224+
sample_rate: Optional[int] = None,
225+
num_channels: Optional[int] = None,
224226
) -> None:
225227
return
226228

src/torchcodec/decoders/_audio_decoder.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ class AudioDecoder:
3939
Note that this index is absolute across all media types. If left unspecified, then
4040
the :term:`best stream` is used.
4141
sample_rate (int, optional): The desired output sample rate of the decoded samples.
42-
By default, the samples are returned in their original sample rate.
42+
By default, the sample rate of the source is used.
43+
num_channels (int, optional): The desired number of channels of the decoded samples.
44+
By default, the number of channels of the source is used.
4345
4446
Attributes:
4547
metadata (AudioStreamMetadata): Metadata of the audio stream.
@@ -54,11 +56,15 @@ def __init__(
5456
*,
5557
stream_index: Optional[int] = None,
5658
sample_rate: Optional[int] = None,
59+
num_channels: Optional[int] = None,
5760
):
5861
self._decoder = create_decoder(source=source, seek_mode="approximate")
5962

6063
core.add_audio_stream(
61-
self._decoder, stream_index=stream_index, sample_rate=sample_rate
64+
self._decoder,
65+
stream_index=stream_index,
66+
sample_rate=sample_rate,
67+
num_channels=num_channels,
6268
)
6369

6470
container_metadata = core.get_container_metadata(self._decoder)

0 commit comments

Comments
 (0)