Skip to content

Commit a0dcafd

Browse files
committed
Convert sample format, update tests
1 parent 872b569 commit a0dcafd

File tree

3 files changed

+46
-17
lines changed

3 files changed

+46
-17
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,6 @@ AudioEncoder::AudioEncoder(
9696
// may need to convert the wf into a supported output sample format, which is
9797
// what the `.sample_fmt` defines.
9898
avCodecContext_->sample_fmt = findOutputSampleFormat(*avCodec);
99-
printf(
100-
"Will be using: %s\n",
101-
av_get_sample_fmt_name(avCodecContext_->sample_fmt));
10299

103100
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
104101
// planar (fltp).
@@ -143,6 +140,8 @@ AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) {
143140
// encoder. Right now, the output format we'll choose is just the first format
144141
// in the `sample_fmts` list that the AVCodec defines. Eventually, we may
145142
// allow the user to choose.
143+
// TODO-ENCODING: a better default would probably be to choose the highest
144+
// available precision
146145
if (avCodec.sample_fmts == nullptr) {
147146
// Can't really validate anything in this case, best we can do is hope that
148147
// FLTP is supported by the encoder. If not, FFmpeg will raise.
@@ -164,7 +163,7 @@ void AudioEncoder::encode() {
164163
int numSamplesAllocatedPerFrame =
165164
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
166165
avFrame->nb_samples = numSamplesAllocatedPerFrame;
167-
avFrame->format = avCodecContext_->sample_fmt;
166+
avFrame->format = AV_SAMPLE_FMT_FLTP;
168167
avFrame->sample_rate = avCodecContext_->sample_rate;
169168
avFrame->pts = 0;
170169
setChannelLayout(avFrame, avCodecContext_);
@@ -230,7 +229,29 @@ void AudioEncoder::encode() {
230229

231230
void AudioEncoder::encodeInnerLoop(
232231
AutoAVPacket& autoAVPacket,
233-
const UniqueAVFrame& avFrame) {
232+
const UniqueAVFrame& srcAVFrame) {
233+
bool mustConvert =
234+
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP &&
235+
srcAVFrame != nullptr);
236+
UniqueAVFrame convertedAVFrame;
237+
if (mustConvert) {
238+
if (!swrContext_) {
239+
swrContext_.reset(createSwrContext(
240+
avCodecContext_,
241+
AV_SAMPLE_FMT_FLTP,
242+
avCodecContext_->sample_fmt,
243+
srcAVFrame->sample_rate, // No sample rate conversion
244+
srcAVFrame->sample_rate));
245+
}
246+
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
247+
swrContext_,
248+
srcAVFrame,
249+
avCodecContext_->sample_fmt,
250+
srcAVFrame->sample_rate, // No sample rate conversion
251+
srcAVFrame->sample_rate);
252+
}
253+
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
254+
234255
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
235256
TORCH_CHECK(
236257
status == AVSUCCESS,
@@ -267,6 +288,9 @@ void AudioEncoder::encodeInnerLoop(
267288
}
268289

269290
void AudioEncoder::flushBuffers() {
291+
// We flush the main FFmpeg buffers, but not swresample buffers. Flushing
292+
// swresample is only necessary when converting sample rates, which we don't
293+
// do for encoding.
270294
AutoAVPacket autoAVPacket;
271295
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
272296
}

src/torchcodec/_core/Encoder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ class AudioEncoder {
2626
private:
2727
void encodeInnerLoop(
2828
AutoAVPacket& autoAVPacket,
29-
const UniqueAVFrame& avFrame);
29+
const UniqueAVFrame& srcAVFrame);
3030
void flushBuffers();
3131
AVSampleFormat findOutputSampleFormat(const AVCodec& avCodec);
3232

3333
UniqueEncodingAVFormatContext avFormatContext_;
3434
UniqueAVCodecContext avCodecContext_;
3535
int streamIndex_;
36+
UniqueSwrContext swrContext_;
3637

3738
const torch::Tensor wf_;
3839
};

test/test_ops.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,34 +1122,35 @@ def test_bad_input(self, tmp_path):
11221122
bit_rate=-1, # bad
11231123
)
11241124

1125-
def test_round_trip(self, tmp_path):
1126-
# Check that decode(encode(samples)) == samples
1125+
@pytest.mark.parametrize("output_format", ("wav", "flac"))
1126+
def test_round_trip(self, output_format, tmp_path):
1127+
# Check that decode(encode(samples)) == samples on lossless formats
11271128
asset = NASA_AUDIO_MP3
11281129
source_samples = self.decode(asset)
11291130

1130-
encoded_path = tmp_path / "output.mp3"
1131+
encoded_path = tmp_path / f"output.{output_format}"
11311132
encoder = create_audio_encoder(
11321133
wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path)
11331134
)
11341135
encode_audio(encoder)
11351136

1136-
# TODO-ENCODING: tol should be stricter. We probably need to encode
1137-
# into a lossless format.
11381137
torch.testing.assert_close(
1139-
self.decode(encoded_path), source_samples, rtol=0, atol=0.07
1138+
self.decode(encoded_path), source_samples, rtol=0, atol=1e-4
11401139
)
11411140

1142-
# TODO-ENCODING: test more encoding formats
11431141
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
11441142
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
11451143
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
1146-
def test_against_cli(self, asset, bit_rate, tmp_path):
1144+
@pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
1145+
def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
11471146
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
11481147
# that both decoded outputs are equal
11491148

1150-
encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3"
1151-
encoded_by_us = tmp_path / "our_output.mp3"
1149+
encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}"
1150+
encoded_by_us = tmp_path / f"our_output.{output_format}"
11521151

1152+
# Note: FFmpeg's output sample format may differ from ours, e.g. FFmpeg CLI would
1153+
# choose s32 while our current heuristic may choose s16.
11531154
subprocess.run(
11541155
["ffmpeg", "-i", str(asset.path)]
11551156
+ (["-b:a", f"{bit_rate}"] if bit_rate is not None else [])
@@ -1169,7 +1170,10 @@ def test_against_cli(self, asset, bit_rate, tmp_path):
11691170
encode_audio(encoder)
11701171

11711172
torch.testing.assert_close(
1172-
self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us)
1173+
self.decode(encoded_by_ffmpeg),
1174+
self.decode(encoded_by_us),
1175+
rtol=0,
1176+
atol=1e-4,
11731177
)
11741178

11751179

0 commit comments

Comments
 (0)