Skip to content

Commit 87b98e8

Browse files
authored
Encoding: address some TODOs (#667)
1 parent d717d6c commit 87b98e8

File tree

4 files changed

+39
-27
lines changed

4 files changed

+39
-27
lines changed

src/torchcodec/_core/Encoder.cpp

+30 -23
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,18 @@ torch::Tensor validateWf(torch::Tensor wf) {
1414
"waveform must have float32 dtype, got ",
1515
wf.dtype());
1616
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());
17+
18+
// We enforce this, but if we get user reports we should investigate whether
19+
// that's actually needed.
20+
int numChannels = static_cast<int>(wf.sizes()[0]);
21+
TORCH_CHECK(
22+
numChannels <= AV_NUM_DATA_POINTERS,
23+
"Trying to encode ",
24+
numChannels,
25+
" channels, but FFmpeg only supports ",
26+
AV_NUM_DATA_POINTERS,
27+
" channels per frame.");
28+
1729
return wf.contiguous();
1830
}
1931

@@ -164,18 +176,7 @@ void AudioEncoder::initializeEncoder(
164176
// what the `.sample_fmt` defines.
165177
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
166178

167-
int numChannels = static_cast<int>(wf_.sizes()[0]);
168-
TORCH_CHECK(
169-
// TODO-ENCODING is this even true / needed? We can probably support more
170-
// with non-planar data?
171-
numChannels <= AV_NUM_DATA_POINTERS,
172-
"Trying to encode ",
173-
numChannels,
174-
" channels, but FFmpeg only supports ",
175-
AV_NUM_DATA_POINTERS,
176-
" channels per frame.");
177-
178-
setDefaultChannelLayout(avCodecContext_, numChannels);
179+
setDefaultChannelLayout(avCodecContext_, static_cast<int>(wf_.sizes()[0]));
179180

180181
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
181182
TORCH_CHECK(
@@ -206,9 +207,12 @@ torch::Tensor AudioEncoder::encodeToTensor() {
206207
}
207208

208209
void AudioEncoder::encode() {
209-
// TODO-ENCODING: Need to check, but consecutive calls to encode() are
210-
// probably invalid. We can address this once we (re)design the public and
211-
// private encoding APIs.
210+
// To be on the safe side we enforce that encode() can only be called once on
211+
// an encoder object. Whether this is actually necessary is unknown, so this
212+
// may be relaxed if needed.
213+
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
214+
encodeWasCalled_ = true;
215+
212216
UniqueAVFrame avFrame(av_frame_alloc());
213217
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
214218
// Default to 256 like in torchaudio
@@ -322,14 +326,17 @@ void AudioEncoder::encodeInnerLoop(
322326
ReferenceAVPacket packet(autoAVPacket);
323327
status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
324328
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
325-
// TODO-ENCODING this is from TorchAudio, probably needed, but not sure.
326-
// if (status == AVERROR_EOF) {
327-
// status = av_interleaved_write_frame(avFormatContext_.get(),
328-
// nullptr); TORCH_CHECK(
329-
// status == AVSUCCESS,
330-
// "Failed to flush packet ",
331-
// getFFMPEGErrorStringFromErrorCode(status));
332-
// }
329+
if (status == AVERROR_EOF) {
330+
// Flush the packets that were potentially buffered by
331+
// av_interleaved_write_frame(). See corresponding block in
332+
// TorchAudio:
333+
// https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21
334+
status = av_interleaved_write_frame(avFormatContext_.get(), nullptr);
335+
TORCH_CHECK(
336+
status == AVSUCCESS,
337+
"Failed to flush packet: ",
338+
getFFMPEGErrorStringFromErrorCode(status));
339+
}
333340
return;
334341
}
335342
TORCH_CHECK(

src/torchcodec/_core/Encoder.h

+2
Original file line number | Diff line number | Diff line change
@@ -49,5 +49,7 @@ class AudioEncoder {
4949

5050
// Stores the AVIOContext for the output tensor buffer.
5151
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;
52+
53+
bool encodeWasCalled_ = false;
5254
};
5355
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

-2
Original file line number | Diff line number | Diff line change
@@ -394,8 +394,6 @@ void encode_audio_to_file(
394394
.encode();
395395
}
396396

397-
// TODO-ENCODING is "format" a good parameter name?? It kinda conflicts with
398-
// "sample_format" which we may eventually want to expose.
399397
at::Tensor encode_audio_to_tensor(
400398
const at::Tensor wf,
401399
int64_t sample_rate,

test/test_ops.py

+7 -2
Original file line number | Diff line number | Diff line change
@@ -1132,11 +1132,11 @@ def test_bad_input(self, tmp_path):
11321132

11331133
with pytest.raises(RuntimeError, match="No such file or directory"):
11341134
encode_audio_to_file(
1135-
wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3"
1135+
wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
11361136
)
11371137
with pytest.raises(RuntimeError, match="Check the desired extension"):
11381138
encode_audio_to_file(
1139-
wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension"
1139+
wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension"
11401140
)
11411141

11421142
with pytest.raises(RuntimeError, match="invalid sample rate=10"):
@@ -1153,6 +1153,11 @@ def test_bad_input(self, tmp_path):
11531153
bit_rate=-1, # bad
11541154
)
11551155

1156+
with pytest.raises(RuntimeError, match="Trying to encode 10 channels"):
1157+
encode_audio_to_file(
1158+
wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter"
1159+
)
1160+
11561161
@pytest.mark.parametrize(
11571162
"encode_method", (encode_audio_to_file, encode_audio_to_tensor)
11581163
)

0 commit comments

Comments
 (0)