diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp
index 99628f90..08af3402 100644
--- a/src/torchcodec/_core/Encoder.cpp
+++ b/src/torchcodec/_core/Encoder.cpp
@@ -33,6 +33,44 @@ void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
       supportedRates.str());
 }
 
+static const std::vector<AVSampleFormat> preferredFormatsOrder = {
+    AV_SAMPLE_FMT_FLTP,
+    AV_SAMPLE_FMT_FLT,
+    AV_SAMPLE_FMT_DBLP,
+    AV_SAMPLE_FMT_DBL,
+    AV_SAMPLE_FMT_S64P,
+    AV_SAMPLE_FMT_S64,
+    AV_SAMPLE_FMT_S32P,
+    AV_SAMPLE_FMT_S32,
+    AV_SAMPLE_FMT_S16P,
+    AV_SAMPLE_FMT_S16,
+    AV_SAMPLE_FMT_U8P,
+    AV_SAMPLE_FMT_U8};
+
+AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
+  // Find a sample format that the encoder supports. We prefer using FLT[P],
+  // since this is the format of the input waveform. If FLTP isn't supported
+  // then we'll need to convert the AVFrame's format. Our heuristic is to encode
+  // into the format with the highest resolution.
+  if (avCodec.sample_fmts == nullptr) {
+    // Can't really validate anything in this case, best we can do is hope that
+    // FLTP is supported by the encoder. If not, FFmpeg will raise.
+    return AV_SAMPLE_FMT_FLTP;
+  }
+
+  for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
+    for (int i = 0; avCodec.sample_fmts[i] != -1; ++i) {
+      if (avCodec.sample_fmts[i] == preferredFormat) {
+        return preferredFormat;
+      }
+    }
+  }
+  // We should always find a match in preferredFormatsOrder, so we should always
+  // return earlier. But in the event that a future FFmpeg version defines an
+  // additional sample format that isn't in preferredFormatsOrder, we fall back:
+  return avCodec.sample_fmts[0];
+}
+
 } // namespace
 
 AudioEncoder::~AudioEncoder() {}
@@ -47,6 +85,8 @@ AudioEncoder::AudioEncoder(
       wf_.dtype() == torch::kFloat32,
       "waveform must have float32 dtype, got ",
       wf_.dtype());
+  // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
+  // planar (fltp).
   TORCH_CHECK(
       wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());
 
@@ -92,14 +132,10 @@ AudioEncoder::AudioEncoder(
   validateSampleRate(*avCodec, sampleRate);
   avCodecContext_->sample_rate = sampleRate;
 
-  // Note: This is the format of the **input** waveform. This doesn't determine
-  // the output.
-  // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
-  // planar.
-  // TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will
-  // raise. We need to handle this, probably converting the format with
-  // libswresample.
-  avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;
+  // Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
+  // may need to convert the wf into a supported output sample format, which is
+  // what the `.sample_fmt` defines.
+  avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
 
   int numChannels = static_cast<int>(wf_.sizes()[0]);
   TORCH_CHECK(
@@ -120,12 +156,6 @@ AudioEncoder::AudioEncoder(
       "avcodec_open2 failed: ",
       getFFMPEGErrorStringFromErrorCode(status));
 
-  TORCH_CHECK(
-      avCodecContext_->frame_size > 0,
-      "frame_size is ",
-      avCodecContext_->frame_size,
-      ". Cannot encode. This should probably never happen?");
-
   // We're allocating the stream here. Streams are meant to be freed by
   // avformat_free_context(avFormatContext), which we call in the
   // avFormatContext_'s destructor.
@@ -143,8 +173,11 @@ AudioEncoder::AudioEncoder(
 void AudioEncoder::encode() {
   UniqueAVFrame avFrame(av_frame_alloc());
   TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
-  avFrame->nb_samples = avCodecContext_->frame_size;
-  avFrame->format = avCodecContext_->sample_fmt;
+  // Default to 256 like in torchaudio
+  int numSamplesAllocatedPerFrame =
+      avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
+  avFrame->nb_samples = numSamplesAllocatedPerFrame;
+  avFrame->format = AV_SAMPLE_FMT_FLTP;
   avFrame->sample_rate = avCodecContext_->sample_rate;
   avFrame->pts = 0;
   setChannelLayout(avFrame, avCodecContext_);
@@ -160,7 +193,6 @@ void AudioEncoder::encode() {
   uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
   int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
   int numEncodedSamples = 0; // per channel
-  int numSamplesPerFrame = avCodecContext_->frame_size; // per channel
   int numBytesPerSample = static_cast<int>(wf_.element_size());
   int numBytesPerChannel = numSamples * numBytesPerSample;
 
@@ -178,7 +210,7 @@ void AudioEncoder::encode() {
         getFFMPEGErrorStringFromErrorCode(status));
 
     int numSamplesToEncode =
-        std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
+        std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
     int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
 
     for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
@@ -211,7 +243,37 @@ void AudioEncoder::encode() {
 
 void AudioEncoder::encodeInnerLoop(
     AutoAVPacket& autoAVPacket,
-    const UniqueAVFrame& avFrame) {
+    const UniqueAVFrame& srcAVFrame) {
+  bool mustConvert =
+      (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP &&
+       srcAVFrame != nullptr);
+  UniqueAVFrame convertedAVFrame;
+  if (mustConvert) {
+    if (!swrContext_) {
+      swrContext_.reset(createSwrContext(
+          avCodecContext_,
+          AV_SAMPLE_FMT_FLTP,
+          avCodecContext_->sample_fmt,
+          srcAVFrame->sample_rate, // No sample rate conversion
+          srcAVFrame->sample_rate));
+    }
+    convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
+        swrContext_,
+        srcAVFrame,
+        avCodecContext_->sample_fmt,
+        srcAVFrame->sample_rate, // No sample rate conversion
+        srcAVFrame->sample_rate);
+    TORCH_CHECK(
+        convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
+        "convertedAVFrame->nb_samples=",
+        convertedAVFrame->nb_samples,
+        " differs from ",
+        "srcAVFrame->nb_samples=",
+        srcAVFrame->nb_samples,
+        ". This is unexpected, please report on the TorchCodec bug tracker.");
+  }
+  const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
+
   auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
   TORCH_CHECK(
       status == AVSUCCESS,
@@ -248,6 +310,9 @@ void AudioEncoder::encodeInnerLoop(
 }
 
 void AudioEncoder::flushBuffers() {
+  // We flush the main FFmpeg buffers, but not swresample buffers. Flushing
+  // swresample is only necessary when converting sample rates, which we don't
+  // do for encoding.
   AutoAVPacket autoAVPacket;
   encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
 }
diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h
index cccab2b2..3e1abeac 100644
--- a/src/torchcodec/_core/Encoder.h
+++ b/src/torchcodec/_core/Encoder.h
@@ -26,12 +26,13 @@ class AudioEncoder {
  private:
   void encodeInnerLoop(
       AutoAVPacket& autoAVPacket,
-      const UniqueAVFrame& avFrame);
+      const UniqueAVFrame& srcAVFrame);
   void flushBuffers();
 
   UniqueEncodingAVFormatContext avFormatContext_;
   UniqueAVCodecContext avCodecContext_;
   int streamIndex_;
+  UniqueSwrContext swrContext_;
 
   const torch::Tensor wf_;
 };
diff --git a/test/test_ops.py b/test/test_ops.py
index fce7a3d8..781a8724 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -44,6 +44,7 @@ from .utils import (
     assert_frames_equal,
     cpu_and_cuda,
+    get_ffmpeg_major_version,
     in_fbcode,
     NASA_AUDIO,
     NASA_AUDIO_MP3,
@@ -1122,33 +1123,40 @@ def test_bad_input(self, tmp_path):
                 bit_rate=-1,  # bad
             )
 
-    def test_round_trip(self, tmp_path):
-        # Check that decode(encode(samples)) == samples
+    @pytest.mark.parametrize("output_format", ("wav", "flac"))
+    def test_round_trip(self, output_format, tmp_path):
+        # Check that decode(encode(samples)) == samples on lossless formats
+
+        if get_ffmpeg_major_version() == 4 and output_format == "wav":
+            pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
+
         asset = NASA_AUDIO_MP3
         source_samples = self.decode(asset)
 
-        encoded_path = tmp_path / "output.mp3"
+        encoded_path = tmp_path / f"output.{output_format}"
         encoder = create_audio_encoder(
             wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path)
         )
        encode_audio(encoder)
 
-        # TODO-ENCODING: tol should be stricter. We probably need to encode
-        # into a lossless format.
+        rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
         torch.testing.assert_close(
-            self.decode(encoded_path), source_samples, rtol=0, atol=0.07
+            self.decode(encoded_path), source_samples, rtol=rtol, atol=atol
         )
 
-    # TODO-ENCODING: test more encoding formats
     @pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
     @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
     @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
-    def test_against_cli(self, asset, bit_rate, tmp_path):
+    @pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
+    def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
         # Encodes samples with our encoder and with the FFmpeg CLI, and checks
         # that both decoded outputs are equal
-        encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3"
-        encoded_by_us = tmp_path / "our_output.mp3"
+        if get_ffmpeg_major_version() == 4 and output_format == "wav":
+            pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
+
+        encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}"
+        encoded_by_us = tmp_path / f"our_output.{output_format}"
 
         subprocess.run(
             ["ffmpeg", "-i", str(asset.path)]
@@ -1168,8 +1176,12 @@ def test_against_cli(self, asset, bit_rate, tmp_path):
         )
         encode_audio(encoder)
 
+        rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
         torch.testing.assert_close(
-            self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us)
+            self.decode(encoded_by_ffmpeg),
+            self.decode(encoded_by_us),
+            rtol=rtol,
+            atol=atol,
         )
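
Note on the conversion step in encodeInnerLoop: createSwrContext and convertAudioAVFrameSampleFormatAndSampleRate are helpers defined elsewhere in the codebase, so their signatures and bodies are not part of this diff. As a rough, hypothetical sketch only (not the helpers' actual implementation), the work they are expected to do here, a sample-format conversion with no resampling, can be expressed with raw libswresample calls as below. The function name convertToSampleFormat is made up for illustration, and the sketch assumes FFmpeg >= 5.1 for the AVChannelLayout API (the PR itself also supports FFmpeg 4, which still uses the older channel-layout fields).

// Hypothetical sketch (not the repo's implementation): convert an FLTP frame
// to the encoder's sample format with libswresample, keeping the sample rate
// and channel layout unchanged. Assumes FFmpeg >= 5.1 (AVChannelLayout API).
extern "C" {
#include <libavutil/channel_layout.h>
#include <libavutil/frame.h>
#include <libswresample/swresample.h>
}

AVFrame* convertToSampleFormat(const AVFrame* src, AVSampleFormat outFmt) {
  SwrContext* swr = nullptr;
  // Same layout and rate on both sides: only the sample format changes, so
  // swresample never buffers samples across calls.
  if (swr_alloc_set_opts2(
          &swr,
          &src->ch_layout, outFmt, src->sample_rate, // output
          &src->ch_layout, AV_SAMPLE_FMT_FLTP, src->sample_rate, // input
          0, nullptr) < 0 ||
      swr_init(swr) < 0) {
    swr_free(&swr);
    return nullptr;
  }

  AVFrame* dst = av_frame_alloc();
  dst->format = outFmt;
  dst->sample_rate = src->sample_rate;
  av_channel_layout_copy(&dst->ch_layout, &src->ch_layout);

  // swr_convert_frame() allocates dst's data buffers and converts in one call.
  // With identical input/output rates, dst->nb_samples should equal
  // src->nb_samples, which is the invariant the TORCH_CHECK in
  // encodeInnerLoop asserts.
  if (swr_convert_frame(swr, dst, src) < 0) {
    av_frame_free(&dst); // sets dst to nullptr
  }
  swr_free(&swr);
  return dst; // nullptr on conversion failure
}

Because the input and output sample rates are identical, swresample holds no samples back between calls. That is why the converted frame is expected to have the same nb_samples as the source, and why flushBuffers() only needs to drain the encoder, not the SwrContext.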