-
Notifications
You must be signed in to change notification settings - Fork 35
Encoding: support wav, flac etc. #630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
7921558
2a19014
73bdc85
54f5543
24842b6
c3ac80a
5b39c8f
1f9f904
f525848
9150137
872b569
a0dcafd
f49d507
ee3a199
485ee2e
27fdbac
8467b92
9ab1bd1
67dba5a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,8 @@ AudioEncoder::AudioEncoder( | |
wf_.dtype() == torch::kFloat32, | ||
"waveform must have float32 dtype, got ", | ||
wf_.dtype()); | ||
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed | ||
// planar (fltp). | ||
TORCH_CHECK( | ||
wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim()); | ||
|
||
|
@@ -92,14 +94,10 @@ AudioEncoder::AudioEncoder( | |
validateSampleRate(*avCodec, sampleRate); | ||
avCodecContext_->sample_rate = sampleRate; | ||
|
||
// Note: This is the format of the **input** waveform. This doesn't determine | ||
// the output. | ||
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed | ||
// planar. | ||
// TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will | ||
// raise. We need to handle this, probably converting the format with | ||
// libswresample. | ||
avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; | ||
// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we | ||
// may need to convert the wf into a supported output sample format, which is | ||
// what the `.sample_fmt` defines. | ||
avCodecContext_->sample_fmt = findOutputSampleFormat(*avCodec); | ||
|
||
int numChannels = static_cast<int>(wf_.sizes()[0]); | ||
TORCH_CHECK( | ||
|
@@ -120,12 +118,6 @@ AudioEncoder::AudioEncoder( | |
"avcodec_open2 failed: ", | ||
getFFMPEGErrorStringFromErrorCode(status)); | ||
|
||
TORCH_CHECK( | ||
avCodecContext_->frame_size > 0, | ||
"frame_size is ", | ||
avCodecContext_->frame_size, | ||
". Cannot encode. This should probably never happen?"); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. `frame_size` won't always be non-zero, see below. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Interesting - when we turn 0 into our default, it might be worth noting why, citing the docs (https://ffmpeg.org/doxygen/6.0/structAVCodecContext.html#aec57f0d859a6df8b479cd93ca3a44a33, which I admit to not understanding). |
||
|
||
// We're allocating the stream here. Streams are meant to be freed by | ||
// avformat_free_context(avFormatContext), which we call in the | ||
// avFormatContext_'s destructor. | ||
|
@@ -140,11 +132,50 @@ AudioEncoder::AudioEncoder( | |
streamIndex_ = avStream->index; | ||
} | ||
|
||
AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) { | ||
// Find a sample format that the encoder supports. We prefer using FLT[P], | ||
// since this is the format of the input waveform. If FLTP isn't supported | ||
// then we'll need to convert the AVFrame's format. Our heuristic is to encode | ||
// into the format with the highest resolution. | ||
if (avCodec.sample_fmts == nullptr) { | ||
// Can't really validate anything in this case, best we can do is hope that | ||
// FLTP is supported by the encoder. If not, FFmpeg will raise. | ||
return AV_SAMPLE_FMT_FLTP; | ||
} | ||
|
||
std::vector<AVSampleFormat> preferredFormatsOrder = { | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
AV_SAMPLE_FMT_FLTP, | ||
AV_SAMPLE_FMT_FLT, | ||
AV_SAMPLE_FMT_DBLP, | ||
AV_SAMPLE_FMT_DBL, | ||
AV_SAMPLE_FMT_S64P, | ||
AV_SAMPLE_FMT_S64, | ||
AV_SAMPLE_FMT_S32P, | ||
AV_SAMPLE_FMT_S32, | ||
AV_SAMPLE_FMT_S16P, | ||
AV_SAMPLE_FMT_S16, | ||
AV_SAMPLE_FMT_U8P, | ||
AV_SAMPLE_FMT_U8}; | ||
|
||
for (AVSampleFormat preferredFormat : preferredFormatsOrder) { | ||
for (auto i = 0; avCodec.sample_fmts[i] != -1; ++i) { | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (avCodec.sample_fmts[i] == preferredFormat) { | ||
return preferredFormat; | ||
} | ||
} | ||
} | ||
// Should never happen, but just in case | ||
return avCodec.sample_fmts[0]; | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
void AudioEncoder::encode() { | ||
UniqueAVFrame avFrame(av_frame_alloc()); | ||
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); | ||
avFrame->nb_samples = avCodecContext_->frame_size; | ||
avFrame->format = avCodecContext_->sample_fmt; | ||
// Default to 256 like in torchaudio | ||
int numSamplesAllocatedPerFrame = | ||
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256; | ||
avFrame->nb_samples = numSamplesAllocatedPerFrame; | ||
avFrame->format = AV_SAMPLE_FMT_FLTP; | ||
avFrame->sample_rate = avCodecContext_->sample_rate; | ||
avFrame->pts = 0; | ||
setChannelLayout(avFrame, avCodecContext_); | ||
|
@@ -160,7 +191,6 @@ void AudioEncoder::encode() { | |
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr()); | ||
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel | ||
int numEncodedSamples = 0; // per channel | ||
int numSamplesPerFrame = avCodecContext_->frame_size; // per channel | ||
int numBytesPerSample = static_cast<int>(wf_.element_size()); | ||
int numBytesPerChannel = numSamples * numBytesPerSample; | ||
|
||
|
@@ -178,7 +208,7 @@ void AudioEncoder::encode() { | |
getFFMPEGErrorStringFromErrorCode(status)); | ||
|
||
int numSamplesToEncode = | ||
std::min(numSamplesPerFrame, numSamples - numEncodedSamples); | ||
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples); | ||
int numBytesToEncode = numSamplesToEncode * numBytesPerSample; | ||
|
||
for (int ch = 0; ch < wf_.sizes()[0]; ch++) { | ||
|
@@ -211,7 +241,37 @@ void AudioEncoder::encode() { | |
|
||
void AudioEncoder::encodeInnerLoop( | ||
AutoAVPacket& autoAVPacket, | ||
const UniqueAVFrame& avFrame) { | ||
const UniqueAVFrame& srcAVFrame) { | ||
bool mustConvert = | ||
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP && | ||
srcAVFrame != nullptr); | ||
UniqueAVFrame convertedAVFrame; | ||
if (mustConvert) { | ||
if (!swrContext_) { | ||
swrContext_.reset(createSwrContext( | ||
avCodecContext_, | ||
AV_SAMPLE_FMT_FLTP, | ||
avCodecContext_->sample_fmt, | ||
srcAVFrame->sample_rate, // No sample rate conversion | ||
srcAVFrame->sample_rate)); | ||
} | ||
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate( | ||
swrContext_, | ||
srcAVFrame, | ||
avCodecContext_->sample_fmt, | ||
srcAVFrame->sample_rate, // No sample rate conversion | ||
srcAVFrame->sample_rate); | ||
TORCH_CHECK( | ||
convertedAVFrame->nb_samples == srcAVFrame->nb_samples, | ||
"convertedAVFrame->nb_samples=", | ||
convertedAVFrame->nb_samples, | ||
" differs from ", | ||
"srcAVFrame->nb_samples=", | ||
srcAVFrame->nb_samples, | ||
"This is unexpected, please report on the TorchCodec bug tracker."); | ||
} | ||
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; | ||
|
||
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); | ||
TORCH_CHECK( | ||
status == AVSUCCESS, | ||
|
@@ -248,6 +308,9 @@ void AudioEncoder::encodeInnerLoop( | |
} | ||
|
||
void AudioEncoder::flushBuffers() { | ||
// We flush the main FFmpeg buffers, but not swresample buffers. Flushing | ||
// swresample is only necessary when converting sample rates, which we don't | ||
// do for encoding. | ||
AutoAVPacket autoAVPacket; | ||
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr)); | ||
} | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -44,6 +44,7 @@ | |||||||||||||||||
from .utils import ( | ||||||||||||||||||
assert_frames_equal, | ||||||||||||||||||
cpu_and_cuda, | ||||||||||||||||||
get_ffmpeg_major_version, | ||||||||||||||||||
in_fbcode, | ||||||||||||||||||
NASA_AUDIO, | ||||||||||||||||||
NASA_AUDIO_MP3, | ||||||||||||||||||
|
@@ -1122,33 +1123,40 @@ def test_bad_input(self, tmp_path): | |||||||||||||||||
bit_rate=-1, # bad | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
def test_round_trip(self, tmp_path): | ||||||||||||||||||
# Check that decode(encode(samples)) == samples | ||||||||||||||||||
@pytest.mark.parametrize("output_format", ("wav", "flac")) | ||||||||||||||||||
def test_round_trip(self, output_format, tmp_path): | ||||||||||||||||||
# Check that decode(encode(samples)) == samples on lossless formats | ||||||||||||||||||
|
||||||||||||||||||
if get_ffmpeg_major_version() == 4 and output_format == "wav": | ||||||||||||||||||
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. FYI, we're hitting this: torchcodec/src/torchcodec/_core/FFMPEGCommon.cpp, lines 159 to 166 at commit 12cdaa8.
|
||||||||||||||||||
|
||||||||||||||||||
asset = NASA_AUDIO_MP3 | ||||||||||||||||||
source_samples = self.decode(asset) | ||||||||||||||||||
|
||||||||||||||||||
encoded_path = tmp_path / "output.mp3" | ||||||||||||||||||
encoded_path = tmp_path / f"output.{output_format}" | ||||||||||||||||||
encoder = create_audio_encoder( | ||||||||||||||||||
wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path) | ||||||||||||||||||
) | ||||||||||||||||||
encode_audio(encoder) | ||||||||||||||||||
|
||||||||||||||||||
# TODO-ENCODING: tol should be stricter. We probably need to encode | ||||||||||||||||||
# into a lossless format. | ||||||||||||||||||
rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None) | ||||||||||||||||||
torch.testing.assert_close( | ||||||||||||||||||
self.decode(encoded_path), source_samples, rtol=0, atol=0.07 | ||||||||||||||||||
self.decode(encoded_path), source_samples, rtol=rtol, atol=atol | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
# TODO-ENCODING: test more encoding formats | ||||||||||||||||||
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI") | ||||||||||||||||||
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) | ||||||||||||||||||
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) | ||||||||||||||||||
def test_against_cli(self, asset, bit_rate, tmp_path): | ||||||||||||||||||
@pytest.mark.parametrize("output_format", ("mp3", "wav", "flac")) | ||||||||||||||||||
def test_against_cli(self, asset, bit_rate, output_format, tmp_path): | ||||||||||||||||||
# Encodes samples with our encoder and with the FFmpeg CLI, and checks | ||||||||||||||||||
# that both decoded outputs are equal | ||||||||||||||||||
|
||||||||||||||||||
encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3" | ||||||||||||||||||
encoded_by_us = tmp_path / "our_output.mp3" | ||||||||||||||||||
if get_ffmpeg_major_version() == 4 and output_format == "wav": | ||||||||||||||||||
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") | ||||||||||||||||||
|
||||||||||||||||||
encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}" | ||||||||||||||||||
encoded_by_us = tmp_path / f"our_output.{output_format}" | ||||||||||||||||||
|
||||||||||||||||||
subprocess.run( | ||||||||||||||||||
["ffmpeg", "-i", str(asset.path)] | ||||||||||||||||||
|
@@ -1168,8 +1176,12 @@ def test_against_cli(self, asset, bit_rate, tmp_path): | |||||||||||||||||
) | ||||||||||||||||||
encode_audio(encoder) | ||||||||||||||||||
|
||||||||||||||||||
rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None) | ||||||||||||||||||
torch.testing.assert_close( | ||||||||||||||||||
self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us) | ||||||||||||||||||
self.decode(encoded_by_ffmpeg), | ||||||||||||||||||
self.decode(encoded_by_us), | ||||||||||||||||||
rtol=rtol, | ||||||||||||||||||
atol=atol, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My original comment was wrong: that's not the format of the input waveform. It's the format of the input AVFrame that we pass to `avcodec_send_frame()`, and it needs to be a format that the codec supports.