
Encoding: support wav, flac etc. #630


Merged: 19 commits merged on Apr 14, 2025
101 changes: 82 additions & 19 deletions src/torchcodec/_core/Encoder.cpp
@@ -47,6 +47,8 @@ AudioEncoder::AudioEncoder(
wf_.dtype() == torch::kFloat32,
"waveform must have float32 dtype, got ",
wf_.dtype());
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
// planar (fltp).
TORCH_CHECK(
wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());

@@ -92,14 +94,10 @@ AudioEncoder::AudioEncoder(
validateSampleRate(*avCodec, sampleRate);
avCodecContext_->sample_rate = sampleRate;

// Note: This is the format of the **input** waveform. This doesn't determine
// the output.
Comment from the PR author:
My original comment was wrong: that's not the format of the input waveform. It's the format of the input AVFrame that we pass to avcodec_send_frame(). And it needs to be a format that the codec supports.

// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
// planar.
// TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will
// raise. We need to handle this, probably converting the format with
// libswresample.
avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;
// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
// may need to convert the wf into a supported output sample format, which is
// what the `.sample_fmt` defines.
avCodecContext_->sample_fmt = findOutputSampleFormat(*avCodec);

int numChannels = static_cast<int>(wf_.sizes()[0]);
TORCH_CHECK(
@@ -120,12 +118,6 @@ AudioEncoder::AudioEncoder(
"avcodec_open2 failed: ",
getFFMPEGErrorStringFromErrorCode(status));

TORCH_CHECK(
avCodecContext_->frame_size > 0,
"frame_size is ",
avCodecContext_->frame_size,
". Cannot encode. This should probably never happen?");
Comment from the PR author:
This won't always be non-zero, see below.

@scotts (Contributor) commented on Apr 11, 2025:
Interesting - it might be worth noting why from the docs (https://ffmpeg.org/doxygen/6.0/structAVCodecContext.html#aec57f0d859a6df8b479cd93ca3a44a33, which I admit to not understanding) when we turn 0 into our default.


// We're allocating the stream here. Streams are meant to be freed by
// avformat_free_context(avFormatContext), which we call in the
// avFormatContext_'s destructor.
@@ -140,11 +132,50 @@ AudioEncoder::AudioEncoder(
streamIndex_ = avStream->index;
}

AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) {
// Find a sample format that the encoder supports. We prefer using FLT[P],
// since this is the format of the input waveform. If FLTP isn't supported
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
// into the format with the highest resolution.
if (avCodec.sample_fmts == nullptr) {
// Can't really validate anything in this case, best we can do is hope that
// FLTP is supported by the encoder. If not, FFmpeg will raise.
return AV_SAMPLE_FMT_FLTP;
}

std::vector<AVSampleFormat> preferredFormatsOrder = {
AV_SAMPLE_FMT_FLTP,
AV_SAMPLE_FMT_FLT,
AV_SAMPLE_FMT_DBLP,
AV_SAMPLE_FMT_DBL,
AV_SAMPLE_FMT_S64P,
AV_SAMPLE_FMT_S64,
AV_SAMPLE_FMT_S32P,
AV_SAMPLE_FMT_S32,
AV_SAMPLE_FMT_S16P,
AV_SAMPLE_FMT_S16,
AV_SAMPLE_FMT_U8P,
AV_SAMPLE_FMT_U8};

for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
for (auto i = 0; avCodec.sample_fmts[i] != -1; ++i) {
if (avCodec.sample_fmts[i] == preferredFormat) {
return preferredFormat;
}
}
}
// Should never happen, but just in case
return avCodec.sample_fmts[0];
}
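The format-selection heuristic above can be sketched in plain Python. This is an illustrative stand-in, not the real implementation: formats are represented as strings rather than `AVSampleFormat` enum values, and `find_output_sample_format` mirrors the C++ logic (prefer FLTP since that is the input layout, fall back through progressively lower-resolution formats, and hope for FLTP when the codec doesn't advertise its supported formats).

```python
# Preference order, from the input format (FLTP) down to the lowest
# resolution. Planar variants ("p" suffix) come before packed ones.
PREFERRED_ORDER = [
    "fltp", "flt", "dblp", "dbl", "s64p", "s64",
    "s32p", "s32", "s16p", "s16", "u8p", "u8",
]

def find_output_sample_format(codec_sample_fmts):
    # codec_sample_fmts is None when the codec doesn't advertise its
    # supported formats (avCodec.sample_fmts == nullptr). Best we can do
    # is hope FLTP works; FFmpeg will raise otherwise.
    if codec_sample_fmts is None:
        return "fltp"
    for fmt in PREFERRED_ORDER:
        if fmt in codec_sample_fmts:
            return fmt
    # Should never happen, but fall back to the codec's first listed format.
    return codec_sample_fmts[0]
```

For example, a codec advertising only integer formats like `["s16", "s32"]` would get `"s32"`, the highest-resolution supported format.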

void AudioEncoder::encode() {
UniqueAVFrame avFrame(av_frame_alloc());
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
avFrame->nb_samples = avCodecContext_->frame_size;
avFrame->format = avCodecContext_->sample_fmt;
// Default to 256 like in torchaudio
int numSamplesAllocatedPerFrame =
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
avFrame->nb_samples = numSamplesAllocatedPerFrame;
avFrame->format = AV_SAMPLE_FMT_FLTP;
avFrame->sample_rate = avCodecContext_->sample_rate;
avFrame->pts = 0;
setChannelLayout(avFrame, avCodecContext_);
@@ -160,7 +191,6 @@ void AudioEncoder::encode() {
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
int numEncodedSamples = 0; // per channel
int numSamplesPerFrame = avCodecContext_->frame_size; // per channel
int numBytesPerSample = static_cast<int>(wf_.element_size());
int numBytesPerChannel = numSamples * numBytesPerSample;
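The chunking scheme above can be sketched as follows. This is a hypothetical helper, assuming per-channel sample counts: the waveform is split into frames of the codec's `frame_size` samples, and when the codec reports `frame_size == 0` (variable frame sizes allowed), the allocation defaults to 256 samples per frame, like in torchaudio.

```python
def frame_sizes(num_samples, codec_frame_size):
    # Default to 256 like in torchaudio when the codec doesn't impose a
    # fixed frame size (codec_frame_size == 0).
    per_frame = codec_frame_size if codec_frame_size > 0 else 256
    sizes = []
    encoded = 0
    while encoded < num_samples:
        # The last frame may be shorter than per_frame.
        to_encode = min(per_frame, num_samples - encoded)
        sizes.append(to_encode)
        encoded += to_encode
    return sizes
```

For instance, 1000 samples with `frame_size == 0` yields three full frames of 256 samples and a final frame of 232.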

@@ -178,7 +208,7 @@
getFFMPEGErrorStringFromErrorCode(status));

int numSamplesToEncode =
std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;

for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
@@ -211,7 +241,37 @@

void AudioEncoder::encodeInnerLoop(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& avFrame) {
const UniqueAVFrame& srcAVFrame) {
bool mustConvert =
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP &&
srcAVFrame != nullptr);
UniqueAVFrame convertedAVFrame;
if (mustConvert) {
if (!swrContext_) {
swrContext_.reset(createSwrContext(
avCodecContext_,
AV_SAMPLE_FMT_FLTP,
avCodecContext_->sample_fmt,
srcAVFrame->sample_rate, // No sample rate conversion
srcAVFrame->sample_rate));
}
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
swrContext_,
srcAVFrame,
avCodecContext_->sample_fmt,
srcAVFrame->sample_rate, // No sample rate conversion
srcAVFrame->sample_rate);
TORCH_CHECK(
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
"convertedAVFrame->nb_samples=",
convertedAVFrame->nb_samples,
" differs from ",
"srcAVFrame->nb_samples=",
srcAVFrame->nb_samples,
"This is unexpected, please report on the TorchCodec bug tracker.");
}
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;

auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
TORCH_CHECK(
status == AVSUCCESS,
@@ -248,6 +308,9 @@ void AudioEncoder::encodeInnerLoop(
}

void AudioEncoder::flushBuffers() {
// We flush the main FFmpeg buffers, but not swresample buffers. Flushing
// swresample is only necessary when converting sample rates, which we don't
// do for encoding.
AutoAVPacket autoAVPacket;
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
}
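The conversion decision in `encodeInnerLoop` can be summarized with a small sketch. Names and string formats here are illustrative stand-ins: the input AVFrame is always FLTP, so a swresample context is only needed (and is lazily created, then reused across frames) when the codec's output sample format differs; during flushing the frame is null and no conversion happens.

```python
class ConversionPolicy:
    """Illustrative stand-in for encodeInnerLoop's conversion decision."""

    def __init__(self, output_sample_fmt):
        self.output_sample_fmt = output_sample_fmt
        self.swr_context = None  # created lazily, reused across frames

    def needs_conversion(self, frame):
        # frame is None on the flush path; nothing to convert then.
        must_convert = frame is not None and self.output_sample_fmt != "fltp"
        if must_convert and self.swr_context is None:
            # Stand-in for createSwrContext(FLTP -> output format, with the
            # same sample rate on both sides: no sample rate conversion).
            self.swr_context = ("fltp", self.output_sample_fmt)
        return must_convert
```

A FLAC-like codec selecting `"s16"` would trigger conversion on every frame (with one shared context), while a wav-like codec keeping `"fltp"` never would.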
4 changes: 3 additions & 1 deletion src/torchcodec/_core/Encoder.h
@@ -26,12 +26,14 @@ class AudioEncoder {
private:
void encodeInnerLoop(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& avFrame);
const UniqueAVFrame& srcAVFrame);
void flushBuffers();
AVSampleFormat findOutputSampleFormat(const AVCodec& avCodec);

UniqueEncodingAVFormatContext avFormatContext_;
UniqueAVCodecContext avCodecContext_;
int streamIndex_;
UniqueSwrContext swrContext_;

const torch::Tensor wf_;
};
34 changes: 23 additions & 11 deletions test/test_ops.py
@@ -44,6 +44,7 @@
from .utils import (
assert_frames_equal,
cpu_and_cuda,
get_ffmpeg_major_version,
in_fbcode,
NASA_AUDIO,
NASA_AUDIO_MP3,
@@ -1122,33 +1123,40 @@ def test_bad_input(self, tmp_path):
bit_rate=-1, # bad
)

def test_round_trip(self, tmp_path):
# Check that decode(encode(samples)) == samples
@pytest.mark.parametrize("output_format", ("wav", "flac"))
def test_round_trip(self, output_format, tmp_path):
# Check that decode(encode(samples)) == samples on lossless formats

if get_ffmpeg_major_version() == 4 and output_format == "wav":
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
Comment from the PR author:
FYI, we're hitting this:

status = swr_init(swrContext);
TORCH_CHECK(
status == AVSUCCESS,
"Couldn't initialize SwrContext: ",
getFFMPEGErrorStringFromErrorCode(status),
". If the error says 'Invalid argument', it's likely that you are using "
"a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
"valid scenarios. Try to upgrade FFmpeg?");


asset = NASA_AUDIO_MP3
source_samples = self.decode(asset)

encoded_path = tmp_path / "output.mp3"
encoded_path = tmp_path / f"output.{output_format}"
encoder = create_audio_encoder(
wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path)
)
encode_audio(encoder)

# TODO-ENCODING: tol should be stricter. We probably need to encode
# into a lossless format.
rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
torch.testing.assert_close(
self.decode(encoded_path), source_samples, rtol=0, atol=0.07
self.decode(encoded_path), source_samples, rtol=rtol, atol=atol
)

# TODO-ENCODING: test more encoding formats
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
def test_against_cli(self, asset, bit_rate, tmp_path):
@pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
# that both decoded outputs are equal

encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3"
encoded_by_us = tmp_path / "our_output.mp3"
if get_ffmpeg_major_version() == 4 and output_format == "wav":
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")

encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}"
encoded_by_us = tmp_path / f"our_output.{output_format}"

subprocess.run(
["ffmpeg", "-i", str(asset.path)]
@@ -1168,8 +1176,12 @@ def test_against_cli(self, asset, bit_rate, tmp_path):
)
encode_audio(encoder)

rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
torch.testing.assert_close(
self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us)
self.decode(encoded_by_ffmpeg),
self.decode(encoded_by_us),
rtol=rtol,
atol=atol,
)

