@@ -14,6 +14,18 @@ torch::Tensor validateWf(torch::Tensor wf) {
14
14
" waveform must have float32 dtype, got " ,
15
15
wf.dtype ());
16
16
TORCH_CHECK (wf.dim () == 2 , " waveform must have 2 dimensions, got " , wf.dim ());
17
+
18
+ // We enforce this, but if we get user reports we should investigate whether
19
+ // that's actually needed.
20
+ int numChannels = static_cast <int >(wf.sizes ()[0 ]);
21
+ TORCH_CHECK (
22
+ numChannels <= AV_NUM_DATA_POINTERS,
23
+ " Trying to encode " ,
24
+ numChannels,
25
+ " channels, but FFmpeg only supports " ,
26
+ AV_NUM_DATA_POINTERS,
27
+ " channels per frame." );
28
+
17
29
return wf.contiguous ();
18
30
}
19
31
@@ -164,18 +176,7 @@ void AudioEncoder::initializeEncoder(
164
176
// what the `.sample_fmt` defines.
165
177
avCodecContext_->sample_fmt = findBestOutputSampleFormat (*avCodec);
166
178
167
- int numChannels = static_cast <int >(wf_.sizes ()[0 ]);
168
- TORCH_CHECK (
169
- // TODO-ENCODING is this even true / needed? We can probably support more
170
- // with non-planar data?
171
- numChannels <= AV_NUM_DATA_POINTERS,
172
- " Trying to encode " ,
173
- numChannels,
174
- " channels, but FFmpeg only supports " ,
175
- AV_NUM_DATA_POINTERS,
176
- " channels per frame." );
177
-
178
- setDefaultChannelLayout (avCodecContext_, numChannels);
179
+ setDefaultChannelLayout (avCodecContext_, static_cast <int >(wf_.sizes ()[0 ]));
179
180
180
181
int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
181
182
TORCH_CHECK (
@@ -206,9 +207,12 @@ torch::Tensor AudioEncoder::encodeToTensor() {
206
207
}
207
208
208
209
void AudioEncoder::encode () {
209
- // TODO-ENCODING: Need to check, but consecutive calls to encode() are
210
- // probably invalid. We can address this once we (re)design the public and
211
- // private encoding APIs.
210
+ // To be on the safe side we enforce that encode() can only be called once on
211
+ // an encoder object. Whether this is actually necessary is unknown, so this
212
+ // may be relaxed if needed.
213
+ TORCH_CHECK (!encodeWasCalled_, " Cannot call encode() twice." );
214
+ encodeWasCalled_ = true ;
215
+
212
216
UniqueAVFrame avFrame (av_frame_alloc ());
213
217
TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
214
218
// Default to 256 like in torchaudio
@@ -322,14 +326,17 @@ void AudioEncoder::encodeInnerLoop(
322
326
ReferenceAVPacket packet (autoAVPacket);
323
327
status = avcodec_receive_packet (avCodecContext_.get (), packet.get ());
324
328
if (status == AVERROR (EAGAIN) || status == AVERROR_EOF) {
325
- // TODO-ENCODING this is from TorchAudio, probably needed, but not sure.
326
- // if (status == AVERROR_EOF) {
327
- // status = av_interleaved_write_frame(avFormatContext_.get(),
328
- // nullptr); TORCH_CHECK(
329
- // status == AVSUCCESS,
330
- // "Failed to flush packet ",
331
- // getFFMPEGErrorStringFromErrorCode(status));
332
- // }
329
+ if (status == AVERROR_EOF) {
330
+ // Flush the packets that were potentially buffered by
331
+ // av_interleaved_write_frame(). See corresponding block in
332
+ // TorchAudio:
333
+ // https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21
334
+ status = av_interleaved_write_frame (avFormatContext_.get (), nullptr );
335
+ TORCH_CHECK (
336
+ status == AVSUCCESS,
337
+ " Failed to flush packet: " ,
338
+ getFFMPEGErrorStringFromErrorCode (status));
339
+ }
333
340
return ;
334
341
}
335
342
TORCH_CHECK (
0 commit comments