From 79215589a1e771423ceeeab54da604d22a7fbff5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Apr 2025 16:15:54 +0100 Subject: [PATCH 01/14] Disable FFmpeg logs for encoder --- src/torchcodec/_core/Encoder.cpp | 4 +-- src/torchcodec/_core/FFMPEGCommon.cpp | 34 ++++++++++++++++++++ src/torchcodec/_core/FFMPEGCommon.h | 2 ++ src/torchcodec/_core/SingleStreamDecoder.cpp | 34 -------------------- src/torchcodec/_core/SingleStreamDecoder.h | 1 - 5 files changed, 38 insertions(+), 37 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 9d5c1dea..b397a16d 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -5,8 +5,6 @@ namespace facebook::torchcodec { AudioEncoder::~AudioEncoder() {} -// TODO-ENCODING: disable ffmpeg logs by default - AudioEncoder::AudioEncoder( const torch::Tensor wf, int sampleRate, @@ -18,6 +16,8 @@ AudioEncoder::AudioEncoder( wf_.dtype()); TORCH_CHECK( wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim()); + + setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; auto status = avformat_alloc_output_context2( &avFormatContext, nullptr, nullptr, fileName.data()); diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 64e4da70..8a3116d3 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -5,6 +5,7 @@ // LICENSE file in the root directory of this source tree. #include "src/torchcodec/_core/FFMPEGCommon.h" +#include #include @@ -158,4 +159,37 @@ SwrContext* allocateSwrContext( return swrContext; } +void setFFmpegLogLevel() { + auto logLevel = AV_LOG_QUIET; + const char* logLevelEnv = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL"); + if (logLevelEnv != nullptr) { + if (std::strcmp(logLevelEnv, "QUIET") == 0) { + logLevel = AV_LOG_QUIET; + } else if (std::strcmp(logLevelEnv, "PANIC") == 0) { + logLevel = AV_LOG_PANIC; + } else if (std::strcmp(logLevelEnv, "FATAL") == 0) { + logLevel = AV_LOG_FATAL; + } else if (std::strcmp(logLevelEnv, "ERROR") == 0) { + logLevel = AV_LOG_ERROR; + } else if (std::strcmp(logLevelEnv, "WARNING") == 0) { + logLevel = AV_LOG_WARNING; + } else if (std::strcmp(logLevelEnv, "INFO") == 0) { + logLevel = AV_LOG_INFO; + } else if (std::strcmp(logLevelEnv, "VERBOSE") == 0) { + logLevel = AV_LOG_VERBOSE; + } else if (std::strcmp(logLevelEnv, "DEBUG") == 0) { + logLevel = AV_LOG_DEBUG; + } else if (std::strcmp(logLevelEnv, "TRACE") == 0) { + logLevel = AV_LOG_TRACE; + } else { + TORCH_CHECK( + false, + "Invalid TORCHCODEC_FFMPEG_LOG_LEVEL: ", + logLevelEnv, + ". Use e.g. 'QUIET', 'PANIC', 'VERBOSE', etc."); + } + } + av_log_set_level(logLevel); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index fdb30962..81a9fb8f 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -168,4 +168,6 @@ SwrContext* allocateSwrContext( // Returns true if sws_scale can handle unaligned data. bool canSwsScaleHandleUnalignedData(); +void setFFmpegLogLevel(); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index c7c714da..7ffdda7b 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -7,7 +7,6 @@ #include "src/torchcodec/_core/SingleStreamDecoder.h" #include #include -#include #include #include #include @@ -185,39 +184,6 @@ void SingleStreamDecoder::initializeDecoder() { initialized_ = true; } -void SingleStreamDecoder::setFFmpegLogLevel() { - auto logLevel = AV_LOG_QUIET; - const char* logLevelEnv = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL"); - if (logLevelEnv != nullptr) { - if (std::strcmp(logLevelEnv, "QUIET") == 0) { - logLevel = AV_LOG_QUIET; - } else if (std::strcmp(logLevelEnv, "PANIC") == 0) { - logLevel = AV_LOG_PANIC; - } else if (std::strcmp(logLevelEnv, "FATAL") == 0) { - logLevel = AV_LOG_FATAL; - } else if (std::strcmp(logLevelEnv, "ERROR") == 0) { - logLevel = AV_LOG_ERROR; - } else if (std::strcmp(logLevelEnv, "WARNING") == 0) { - logLevel = AV_LOG_WARNING; - } else if (std::strcmp(logLevelEnv, "INFO") == 0) { - logLevel = AV_LOG_INFO; - } else if (std::strcmp(logLevelEnv, "VERBOSE") == 0) { - logLevel = AV_LOG_VERBOSE; - } else if (std::strcmp(logLevelEnv, "DEBUG") == 0) { - logLevel = AV_LOG_DEBUG; - } else if (std::strcmp(logLevelEnv, "TRACE") == 0) { - logLevel = AV_LOG_TRACE; - } else { - TORCH_CHECK( - false, - "Invalid TORCHCODEC_FFMPEG_LOG_LEVEL: ", - logLevelEnv, - ". Use e.g. 'QUIET', 'PANIC', 'VERBOSE', etc."); - } - } - av_log_set_level(logLevel); -} - int SingleStreamDecoder::getBestStreamIndex(AVMediaType mediaType) { AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr; int streamIndex = diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 4879a3b7..15548bb5 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -363,7 +363,6 @@ class SingleStreamDecoder { // -------------------------------------------------------------------------- void initializeDecoder(); - void setFFmpegLogLevel(); // -------------------------------------------------------------------------- // DECODING APIS AND RELATED UTILS // -------------------------------------------------------------------------- From 73bdc85ca3871412c6982d1606238f8f46d20830 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 8 Apr 2025 11:14:38 +0100 Subject: [PATCH 02/14] Use c++ strings --- src/torchcodec/_core/FFMPEGCommon.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 8a3116d3..aad3c23c 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -5,7 +5,6 @@ // LICENSE file in the root directory of this source tree. #include "src/torchcodec/_core/FFMPEGCommon.h" -#include #include @@ -161,25 +160,26 @@ SwrContext* allocateSwrContext( void setFFmpegLogLevel() { auto logLevel = AV_LOG_QUIET; - const char* logLevelEnv = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL"); - if (logLevelEnv != nullptr) { - if (std::strcmp(logLevelEnv, "QUIET") == 0) { + const char* logLevelEnvPtr = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL"); + if (logLevelEnvPtr != nullptr) { + std::string logLevelEnv(logLevelEnvPtr); + if (logLevelEnv == "QUIET") { logLevel = AV_LOG_QUIET; - } else if (std::strcmp(logLevelEnv, "PANIC") == 0) { + } else if (logLevelEnv == "PANIC") { logLevel = AV_LOG_PANIC; - } else if (std::strcmp(logLevelEnv, "FATAL") == 0) { + } else if (logLevelEnv == "FATAL") { logLevel = AV_LOG_FATAL; - } else if (std::strcmp(logLevelEnv, "ERROR") == 0) { + } else if (logLevelEnv == "ERROR") { logLevel = AV_LOG_ERROR; - } else if (std::strcmp(logLevelEnv, "WARNING") == 0) { + } else if (logLevelEnv == "WARNING") { logLevel = AV_LOG_WARNING; - } else if (std::strcmp(logLevelEnv, "INFO") == 0) { + } else if (logLevelEnv == "INFO") { logLevel = AV_LOG_INFO; - } else if (std::strcmp(logLevelEnv, "VERBOSE") == 0) { + } else if (logLevelEnv == "VERBOSE") { logLevel = AV_LOG_VERBOSE; - } else if (std::strcmp(logLevelEnv, "DEBUG") == 0) { + } else if (logLevelEnv == "DEBUG") { logLevel = AV_LOG_DEBUG; - } else if (std::strcmp(logLevelEnv, "TRACE") == 0) { + } else if (logLevelEnv == "TRACE") { logLevel = AV_LOG_TRACE; } else { TORCH_CHECK( From 24842b6cbf071261f3a9e84bce186dc99b1b33bc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 8 Apr 2025 13:52:53 +0100 Subject: [PATCH 03/14] Account for frame_size being 0 --- src/torchcodec/_core/Encoder.cpp | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 711be8b7..23737a84 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -100,6 +100,7 @@ AudioEncoder::AudioEncoder( // raise. We need to handle this, probably converting the format with // libswresample. avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; + // avCodecContext_->sample_fmt = AV_SAMPLE_FMT_S16; int numChannels = static_cast(wf_.sizes()[0]); TORCH_CHECK( @@ -120,12 +121,6 @@ AudioEncoder::AudioEncoder( "avcodec_open2 failed: ", getFFMPEGErrorStringFromErrorCode(status)); - TORCH_CHECK( - avCodecContext_->frame_size > 0, - "frame_size is ", - avCodecContext_->frame_size, - ". Cannot encode. This should probably never happen?"); - // We're allocating the stream here. Streams are meant to be freed by // avformat_free_context(avFormatContext), which we call in the // avFormatContext_'s destructor. @@ -143,7 +138,10 @@ AudioEncoder::AudioEncoder( void AudioEncoder::encode() { UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); - avFrame->nb_samples = avCodecContext_->frame_size; + // Default to 256 like in torchaudio + int numSamplesAllocatedPerFrame = + avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256; + avFrame->nb_samples = numSamplesAllocatedPerFrame; avFrame->format = avCodecContext_->sample_fmt; avFrame->sample_rate = avCodecContext_->sample_rate; avFrame->pts = 0; @@ -160,7 +158,6 @@ void AudioEncoder::encode() { uint8_t* pwf = static_cast(wf_.data_ptr()); int numSamples = static_cast(wf_.sizes()[1]); // per channel int numEncodedSamples = 0; // per channel - int numSamplesPerFrame = avCodecContext_->frame_size; // per channel int numBytesPerSample = static_cast(wf_.element_size()); int numBytesPerChannel = numSamples * numBytesPerSample; @@ -178,7 +175,7 @@ void AudioEncoder::encode() { getFFMPEGErrorStringFromErrorCode(status)); int numSamplesToEncode = - std::min(numSamplesPerFrame, numSamples - numEncodedSamples); + std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples); int numBytesToEncode = numSamplesToEncode * numBytesPerSample; for (int ch = 0; ch < wf_.sizes()[0]; ch++) { From 5b39c8f3f7d18f1463e1bdfa871f647b6b8dd4a7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Apr 2025 10:46:34 +0100 Subject: [PATCH 04/14] WIP --- src/torchcodec/_core/SingleStreamDecoder.cpp | 15 ++++++++------- src/torchcodec/_core/SingleStreamDecoder.h | 4 ++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 02ea5ee1..53860e1a 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1402,12 +1402,12 @@ UniqueAVFrame SingleStreamDecoder::convertAudioAVFrameSampleFormatAndSampleRate( auto& streamInfo = streamInfos_[activeStreamIndex_]; if (!streamInfo.swrContext) { - createSwrContext( - streamInfo, + streamInfo.swrContext.reset(createSwrContext( + streamInfo.codecContext, sourceSampleFormat, desiredSampleFormat, sourceSampleRate, - desiredSampleRate); + desiredSampleRate)); } UniqueAVFrame convertedAVFrame(av_frame_alloc()); @@ -1735,14 +1735,14 @@ void SingleStreamDecoder::createSwsContext( streamInfo.swsContext.reset(swsContext); } -void SingleStreamDecoder::createSwrContext( - StreamInfo& streamInfo, +SwrContext* SingleStreamDecoder::createSwrContext( + UniqueAVCodecContext& avCodecContext, AVSampleFormat sourceSampleFormat, AVSampleFormat desiredSampleFormat, int sourceSampleRate, int desiredSampleRate) { auto swrContext = allocateSwrContext( - streamInfo.codecContext, + avCodecContext, sourceSampleFormat, desiredSampleFormat, sourceSampleRate, @@ -1756,7 +1756,8 @@ void SingleStreamDecoder::createSwrContext( ". If the error says 'Invalid argument', it's likely that you are using " "a buggy FFmpeg version. FFmpeg4 is known to fail here in some " "valid scenarios. Try to upgrade FFmpeg?"); - streamInfo.swrContext.reset(swrContext); + // streamInfo.swrContext.reset(swrContext); + return swrContext; } // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index d532675e..785e1ccf 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -310,8 +310,8 @@ class SingleStreamDecoder { const DecodedFrameContext& frameContext, const enum AVColorSpace colorspace); - void createSwrContext( - StreamInfo& streamInfo, + SwrContext* createSwrContext( + UniqueAVCodecContext& avCodecContext, AVSampleFormat sourceSampleFormat, AVSampleFormat desiredSampleFormat, int sourceSampleRate, From 1f9f904fbc83b5538daf22ea938987d7aab1d1ba Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Apr 2025 10:52:47 +0100 Subject: [PATCH 05/14] Move createSwrContext in ffmpeg file --- src/torchcodec/_core/FFMPEGCommon.cpp | 13 ++++++++-- src/torchcodec/_core/FFMPEGCommon.h | 2 +- src/torchcodec/_core/SingleStreamDecoder.cpp | 25 -------------------- src/torchcodec/_core/SingleStreamDecoder.h | 7 ------ 4 files changed, 12 insertions(+), 35 deletions(-) diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index aad3c23c..341822e7 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -116,16 +116,17 @@ void setChannelLayout( #endif } -SwrContext* allocateSwrContext( +SwrContext* createSwrContext( UniqueAVCodecContext& avCodecContext, AVSampleFormat sourceSampleFormat, AVSampleFormat desiredSampleFormat, int sourceSampleRate, int desiredSampleRate) { SwrContext* swrContext = nullptr; + int status = AVSUCCESS; #if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 AVChannelLayout layout = avCodecContext->ch_layout; - auto status = swr_alloc_set_opts2( + status = swr_alloc_set_opts2( &swrContext, &layout, desiredSampleFormat, @@ -155,6 +156,14 @@ SwrContext* allocateSwrContext( #endif TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext"); + status = swr_init(swrContext); + TORCH_CHECK( + status == AVSUCCESS, + "Couldn't initialize SwrContext: ", + getFFMPEGErrorStringFromErrorCode(status), + ". If the error says 'Invalid argument', it's likely that you are using " + "a buggy FFmpeg version. FFmpeg4 is known to fail here in some " + "valid scenarios. Try to upgrade FFmpeg?"); return swrContext; } diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 81a9fb8f..3ebb4291 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -158,7 +158,7 @@ void setChannelLayout( void setChannelLayout( UniqueAVFrame& dstAVFrame, const UniqueAVFrame& srcAVFrame); -SwrContext* allocateSwrContext( +SwrContext* createSwrContext( UniqueAVCodecContext& avCodecContext, AVSampleFormat sourceSampleFormat, AVSampleFormat desiredSampleFormat, diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 53860e1a..03b0a787 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1735,31 +1735,6 @@ void SingleStreamDecoder::createSwsContext( streamInfo.swsContext.reset(swsContext); } -SwrContext* SingleStreamDecoder::createSwrContext( - UniqueAVCodecContext& avCodecContext, - AVSampleFormat sourceSampleFormat, - AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate) { - auto swrContext = allocateSwrContext( - avCodecContext, - sourceSampleFormat, - desiredSampleFormat, - sourceSampleRate, - desiredSampleRate); - - auto status = swr_init(swrContext); - TORCH_CHECK( - status == AVSUCCESS, - "Couldn't initialize SwrContext: ", - getFFMPEGErrorStringFromErrorCode(status), - ". If the error says 'Invalid argument', it's likely that you are using " - "a buggy FFmpeg version. FFmpeg4 is known to fail here in some " - "valid scenarios. Try to upgrade FFmpeg?"); - // streamInfo.swrContext.reset(swrContext); - return swrContext; -} - // -------------------------------------------------------------------------- // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 785e1ccf..325edb0e 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -310,13 +310,6 @@ class SingleStreamDecoder { const DecodedFrameContext& frameContext, const enum AVColorSpace colorspace); - SwrContext* createSwrContext( - UniqueAVCodecContext& avCodecContext, - AVSampleFormat sourceSampleFormat, - AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate); - // -------------------------------------------------------------------------- // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- From f525848b95d5459e19ccc8227d2c9541ae1384e0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Apr 2025 11:04:28 +0100 Subject: [PATCH 06/14] WIP --- src/torchcodec/_core/SingleStreamDecoder.cpp | 32 +++++++++----------- src/torchcodec/_core/SingleStreamDecoder.h | 2 +- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 03b0a787..ea173308 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1345,10 +1345,10 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( static_cast(srcAVFrame->format); AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP; + StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; int sourceSampleRate = srcAVFrame->sample_rate; int desiredSampleRate = - streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or( - sourceSampleRate); + streamInfo.audioStreamOptions.sampleRate.value_or(sourceSampleRate); bool mustConvert = (sourceSampleFormat != desiredSampleFormat || @@ -1356,9 +1356,18 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( UniqueAVFrame convertedAVFrame; if (mustConvert) { + if (!streamInfo.swrContext) { + streamInfo.swrContext.reset(createSwrContext( + streamInfo.codecContext, + sourceSampleFormat, + desiredSampleFormat, + sourceSampleRate, + desiredSampleRate)); + } + convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate( + streamInfo.swrContext, srcAVFrame, - sourceSampleFormat, desiredSampleFormat, sourceSampleRate, desiredSampleRate); @@ -1394,22 +1403,11 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( } UniqueAVFrame SingleStreamDecoder::convertAudioAVFrameSampleFormatAndSampleRate( + const UniqueSwrContext& swrContext, const UniqueAVFrame& srcAVFrame, - AVSampleFormat sourceSampleFormat, AVSampleFormat desiredSampleFormat, int sourceSampleRate, int desiredSampleRate) { - auto& streamInfo = streamInfos_[activeStreamIndex_]; - - if (!streamInfo.swrContext) { - streamInfo.swrContext.reset(createSwrContext( - streamInfo.codecContext, - sourceSampleFormat, - desiredSampleFormat, - sourceSampleRate, - desiredSampleRate)); - } - UniqueAVFrame convertedAVFrame(av_frame_alloc()); TORCH_CHECK( convertedAVFrame, @@ -1428,7 +1426,7 @@ UniqueAVFrame SingleStreamDecoder::convertAudioAVFrameSampleFormatAndSampleRate( // output samples, but empirically `av_rescale_rnd()` seems to provide a // tighter bound. convertedAVFrame->nb_samples = av_rescale_rnd( - swr_get_delay(streamInfo.swrContext.get(), sourceSampleRate) + + swr_get_delay(swrContext.get(), sourceSampleRate) + srcAVFrame->nb_samples, desiredSampleRate, sourceSampleRate, @@ -1444,7 +1442,7 @@ UniqueAVFrame SingleStreamDecoder::convertAudioAVFrameSampleFormatAndSampleRate( getFFMPEGErrorStringFromErrorCode(status)); auto numConvertedSamples = swr_convert( - streamInfo.swrContext.get(), + swrContext.get(), convertedAVFrame->data, convertedAVFrame->nb_samples, static_cast( diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 325edb0e..7eea35f3 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -288,8 +288,8 @@ class SingleStreamDecoder { torch::Tensor& outputTensor); UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate( + const UniqueSwrContext& swrContext, const UniqueAVFrame& srcAVFrame, - AVSampleFormat sourceSampleFormat, AVSampleFormat desiredSampleFormat, int sourceSampleRate, int desiredSampleRate); From 9150137cdda0dada519d49dcd2c4e535507baa7e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Apr 2025 11:09:37 +0100 Subject: [PATCH 07/14] Move convertAudioAVFrameSampleFormatAndSampleRate in ffmpeg file --- src/torchcodec/_core/FFMPEGCommon.cpp | 60 ++++++++++++++++++++ src/torchcodec/_core/FFMPEGCommon.h | 7 +++ src/torchcodec/_core/SingleStreamDecoder.cpp | 60 -------------------- src/torchcodec/_core/SingleStreamDecoder.h | 7 --- 4 files changed, 67 insertions(+), 67 deletions(-) diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 341822e7..19722108 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -167,6 +167,66 @@ SwrContext* createSwrContext( return swrContext; } +UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate( + const UniqueSwrContext& swrContext, + const UniqueAVFrame& srcAVFrame, + AVSampleFormat desiredSampleFormat, + int sourceSampleRate, + int desiredSampleRate) { + UniqueAVFrame convertedAVFrame(av_frame_alloc()); + TORCH_CHECK( + convertedAVFrame, + "Could not allocate frame for sample format conversion."); + + setChannelLayout(convertedAVFrame, srcAVFrame); + convertedAVFrame->format = static_cast(desiredSampleFormat); + convertedAVFrame->sample_rate = desiredSampleRate; + if (sourceSampleRate != desiredSampleRate) { + // Note that this is an upper bound on the number of output samples. + // `swr_convert()` will likely not fill convertedAVFrame with that many + // samples if sample rate conversion is needed. It will buffer the last few + // ones because those require future samples. That's also why we reset + // nb_samples after the call to `swr_convert()`. + // We could also use `swr_get_out_samples()` to determine the number of + // output samples, but empirically `av_rescale_rnd()` seems to provide a + // tighter bound. + convertedAVFrame->nb_samples = av_rescale_rnd( + swr_get_delay(swrContext.get(), sourceSampleRate) + + srcAVFrame->nb_samples, + desiredSampleRate, + sourceSampleRate, + AV_ROUND_UP); + } else { + convertedAVFrame->nb_samples = srcAVFrame->nb_samples; + } + + auto status = av_frame_get_buffer(convertedAVFrame.get(), 0); + TORCH_CHECK( + status == AVSUCCESS, + "Could not allocate frame buffers for sample format conversion: ", + getFFMPEGErrorStringFromErrorCode(status)); + + auto numConvertedSamples = swr_convert( + swrContext.get(), + convertedAVFrame->data, + convertedAVFrame->nb_samples, + static_cast( + const_cast(srcAVFrame->data)), + srcAVFrame->nb_samples); + // numConvertedSamples can be 0 if we're downsampling by a great factor and + // the first frame doesn't contain a lot of samples. It should be handled + // properly by the caller. + TORCH_CHECK( + numConvertedSamples >= 0, + "Error in swr_convert: ", + getFFMPEGErrorStringFromErrorCode(numConvertedSamples)); + + // See comment above about nb_samples + convertedAVFrame->nb_samples = numConvertedSamples; + + return convertedAVFrame; +} + void setFFmpegLogLevel() { auto logLevel = AV_LOG_QUIET; const char* logLevelEnvPtr = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL"); diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 3ebb4291..8c4abd13 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -165,6 +165,13 @@ SwrContext* createSwrContext( int sourceSampleRate, int desiredSampleRate); +UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate( + const UniqueSwrContext& swrContext, + const UniqueAVFrame& srcAVFrame, + AVSampleFormat desiredSampleFormat, + int sourceSampleRate, + int desiredSampleRate); + // Returns true if sws_scale can handle unaligned data. bool canSwsScaleHandleUnalignedData(); diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index ea173308..17e1301d 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1402,66 +1402,6 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( } } -UniqueAVFrame SingleStreamDecoder::convertAudioAVFrameSampleFormatAndSampleRate( - const UniqueSwrContext& swrContext, - const UniqueAVFrame& srcAVFrame, - AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate) { - UniqueAVFrame convertedAVFrame(av_frame_alloc()); - TORCH_CHECK( - convertedAVFrame, - "Could not allocate frame for sample format conversion."); - - setChannelLayout(convertedAVFrame, srcAVFrame); - convertedAVFrame->format = static_cast(desiredSampleFormat); - convertedAVFrame->sample_rate = desiredSampleRate; - if (sourceSampleRate != desiredSampleRate) { - // Note that this is an upper bound on the number of output samples. - // `swr_convert()` will likely not fill convertedAVFrame with that many - // samples if sample rate conversion is needed. It will buffer the last few - // ones because those require future samples. That's also why we reset - // nb_samples after the call to `swr_convert()`. - // We could also use `swr_get_out_samples()` to determine the number of - // output samples, but empirically `av_rescale_rnd()` seems to provide a - // tighter bound. - convertedAVFrame->nb_samples = av_rescale_rnd( - swr_get_delay(swrContext.get(), sourceSampleRate) + - srcAVFrame->nb_samples, - desiredSampleRate, - sourceSampleRate, - AV_ROUND_UP); - } else { - convertedAVFrame->nb_samples = srcAVFrame->nb_samples; - } - - auto status = av_frame_get_buffer(convertedAVFrame.get(), 0); - TORCH_CHECK( - status == AVSUCCESS, - "Could not allocate frame buffers for sample format conversion: ", - getFFMPEGErrorStringFromErrorCode(status)); - - auto numConvertedSamples = swr_convert( - swrContext.get(), - convertedAVFrame->data, - convertedAVFrame->nb_samples, - static_cast( - const_cast(srcAVFrame->data)), - srcAVFrame->nb_samples); - // numConvertedSamples can be 0 if we're downsampling by a great factor and - // the first frame doesn't contain a lot of samples. It should be handled - // properly by the caller. - TORCH_CHECK( - numConvertedSamples >= 0, - "Error in swr_convert: ", - getFFMPEGErrorStringFromErrorCode(numConvertedSamples)); - - // See comment above about nb_samples - convertedAVFrame->nb_samples = numConvertedSamples; - - return convertedAVFrame; -} - std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { // When sample rate conversion is involved, swresample buffers some of the // samples in-between calls to swr_convert (see the libswresample docs). diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 7eea35f3..d8515111 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -287,13 +287,6 @@ class SingleStreamDecoder { const UniqueAVFrame& avFrame, torch::Tensor& outputTensor); - UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate( - const UniqueSwrContext& swrContext, - const UniqueAVFrame& srcAVFrame, - AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate); - std::optional maybeFlushSwrBuffers(); // -------------------------------------------------------------------------- From 872b5691871a0e3a79914ce334e395a0122bbdce Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Apr 2025 11:59:02 +0100 Subject: [PATCH 08/14] Automatically find output sample format --- src/torchcodec/_core/Encoder.cpp | 38 +++++++++++++++++++++++++------- src/torchcodec/_core/Encoder.h | 1 + 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 23737a84..a9a50c74 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -92,15 +92,16 @@ AudioEncoder::AudioEncoder( validateSampleRate(*avCodec, sampleRate); avCodecContext_->sample_rate = sampleRate; - // Note: This is the format of the **input** waveform. This doesn't determine - // the output. + // Input waveform is expected to be FLTP. Not all encoders support FLTP, so we + // may need to convert the wf into a supported output sample format, which is + // what the `.sample_fmt` defines. + avCodecContext_->sample_fmt = findOutputSampleFormat(*avCodec); + printf( + "Will be using: %s\n", + av_get_sample_fmt_name(avCodecContext_->sample_fmt)); + // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed - // planar. - // TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will - // raise. We need to handle this, probably converting the format with - // libswresample. - avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; - // avCodecContext_->sample_fmt = AV_SAMPLE_FMT_S16; + // planar (fltp). int numChannels = static_cast(wf_.sizes()[0]); TORCH_CHECK( @@ -135,6 +136,27 @@ AudioEncoder::AudioEncoder( streamIndex_ = avStream->index; } +AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) { + // Find a sample format that the encoder supports. If FLTP is supported then + // we use that, since this is the expected format of the input waveform. + // Otherwise, we'll need to convert the waveform before passing it to the + // encoder. Right now, the output format we'll choose is just the first format + // in the `sample_fmts` list that the AVCodec defines. Eventually, we may + // allow the user to choose. + if (avCodec.sample_fmts == nullptr) { + // Can't really validate anything in this case, best we can do is hope that + // FLTP is supported by the encoder. If not, FFmpeg will raise. + return AV_SAMPLE_FMT_FLTP; + } + + for (auto i = 0; avCodec.sample_fmts[i] != -1; ++i) { + if (avCodec.sample_fmts[i] == AV_SAMPLE_FMT_FLTP) { + return AV_SAMPLE_FMT_FLTP; + } + } + return avCodec.sample_fmts[0]; +} + void AudioEncoder::encode() { UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 830c3552..02dcbf91 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -28,6 +28,7 @@ class AudioEncoder { AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame); void flushBuffers(); + AVSampleFormat findOutputSampleFormat(const AVCodec& avCodec); UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; From a0dcafd829cc7fc80812d3c876008c335b4384cb Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Apr 2025 12:24:36 +0100 Subject: [PATCH 09/14] Convert sample format, update tests --- src/torchcodec/_core/Encoder.cpp | 34 +++++++++++++++++++++++++++----- src/torchcodec/_core/Encoder.h | 3 ++- test/test_ops.py | 26 +++++++++++++----------- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index a9a50c74..9db23411 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -96,9 +96,6 @@ AudioEncoder::AudioEncoder( // may need to convert the wf into a supported output sample format, which is // what the `.sample_fmt` defines. avCodecContext_->sample_fmt = findOutputSampleFormat(*avCodec); - printf( - "Will be using: %s\n", - av_get_sample_fmt_name(avCodecContext_->sample_fmt)); // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed // planar (fltp). @@ -143,6 +140,8 @@ AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) { // encoder. Right now, the output format we'll choose is just the first format // in the `sample_fmts` list that the AVCodec defines. Eventually, we may // allow the user to choose. + // TODO-ENCODING: a better default would probably be to choose the highest + // available precision if (avCodec.sample_fmts == nullptr) { // Can't really validate anything in this case, best we can do is hope that // FLTP is supported by the encoder. If not, FFmpeg will raise. @@ -164,7 +163,7 @@ void AudioEncoder::encode() { int numSamplesAllocatedPerFrame = avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256; avFrame->nb_samples = numSamplesAllocatedPerFrame; - avFrame->format = avCodecContext_->sample_fmt; + avFrame->format = AV_SAMPLE_FMT_FLTP; avFrame->sample_rate = avCodecContext_->sample_rate; avFrame->pts = 0; setChannelLayout(avFrame, avCodecContext_); @@ -230,7 +229,29 @@ void AudioEncoder::encode() { void AudioEncoder::encodeInnerLoop( AutoAVPacket& autoAVPacket, - const UniqueAVFrame& avFrame) { + const UniqueAVFrame& srcAVFrame) { + bool mustConvert = + (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP && + srcAVFrame != nullptr); + UniqueAVFrame convertedAVFrame; + if (mustConvert) { + if (!swrContext_) { + swrContext_.reset(createSwrContext( + avCodecContext_, + AV_SAMPLE_FMT_FLTP, + avCodecContext_->sample_fmt, + srcAVFrame->sample_rate, // No sample rate conversion + srcAVFrame->sample_rate)); + } + convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate( + swrContext_, + srcAVFrame, + avCodecContext_->sample_fmt, + srcAVFrame->sample_rate, // No sample rate conversion + srcAVFrame->sample_rate); + } + const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; + auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); TORCH_CHECK( status == AVSUCCESS, @@ -267,6 +288,9 @@ void AudioEncoder::encodeInnerLoop( } void AudioEncoder::flushBuffers() { + // We flush the main FFmpeg buffers, but not swresample buffers. Flushing + // swresample is only necessary when converting sample rates, which we don't + // do for encoding. AutoAVPacket autoAVPacket; encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr)); } diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 02dcbf91..e21aa7d8 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -26,13 +26,14 @@ class AudioEncoder { private: void encodeInnerLoop( AutoAVPacket& autoAVPacket, - const UniqueAVFrame& avFrame); + const UniqueAVFrame& srcAVFrame); void flushBuffers(); AVSampleFormat findOutputSampleFormat(const AVCodec& avCodec); UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; int streamIndex_; + UniqueSwrContext swrContext_; const torch::Tensor wf_; }; diff --git a/test/test_ops.py b/test/test_ops.py index fce7a3d8..3473ca01 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1122,34 +1122,35 @@ def test_bad_input(self, tmp_path): bit_rate=-1, # bad ) - def test_round_trip(self, tmp_path): - # Check that decode(encode(samples)) == samples + @pytest.mark.parametrize("output_format", ("wav", "flac")) + def test_round_trip(self, output_format, tmp_path): + # Check that decode(encode(samples)) == samples on lossless formats asset = NASA_AUDIO_MP3 source_samples = self.decode(asset) - encoded_path = tmp_path / "output.mp3" + encoded_path = tmp_path / f"output.{output_format}" encoder = create_audio_encoder( wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path) ) encode_audio(encoder) - # TODO-ENCODING: tol should be stricter. We probably need to encode - # into a lossless format. torch.testing.assert_close( - self.decode(encoded_path), source_samples, rtol=0, atol=0.07 + self.decode(encoded_path), source_samples, rtol=0, atol=1e-4 ) - # TODO-ENCODING: test more encoding formats @pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI") @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) - def test_against_cli(self, asset, bit_rate, tmp_path): + @pytest.mark.parametrize("output_format", ("mp3", "wav", "flac")) + def test_against_cli(self, asset, bit_rate, output_format, tmp_path): # Encodes samples with our encoder and with the FFmpeg CLI, and checks # that both decoded outputs are equal - encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3" - encoded_by_us = tmp_path / "our_output.mp3" + encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}" + encoded_by_us = tmp_path / f"our_output.{output_format}" + # Note: output format may be different from ours, e.g. FFmpeg CLI would + # choose s32 while our current heuristic may choose s16. subprocess.run( ["ffmpeg", "-i", str(asset.path)] + (["-b:a", f"{bit_rate}"] if bit_rate is not None else []) @@ -1169,7 +1170,10 @@ def test_against_cli(self, asset, bit_rate, tmp_path): encode_audio(encoder) torch.testing.assert_close( - self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us) + self.decode(encoded_by_ffmpeg), + self.decode(encoded_by_us), + rtol=0, + atol=1e-4, ) From f49d50756742f5afb0e18e2fa9329b5ecc5c0207 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Apr 2025 13:34:17 +0100 Subject: [PATCH 10/14] Skip wav on FFmpeg4 --- test/test_ops.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 3473ca01..c7a2540d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -44,6 +44,7 @@ from .utils import ( assert_frames_equal, cpu_and_cuda, + get_ffmpeg_major_version, in_fbcode, NASA_AUDIO, NASA_AUDIO_MP3, @@ -1125,6 +1126,10 @@ def test_bad_input(self, tmp_path): @pytest.mark.parametrize("output_format", ("wav", "flac")) def test_round_trip(self, output_format, tmp_path): # Check that decode(encode(samples)) == samples on lossless formats + + if get_ffmpeg_major_version() == 4 and output_format == "wav": + pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") + asset = NASA_AUDIO_MP3 source_samples = self.decode(asset) @@ -1146,6 +1151,9 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path): # Encodes samples with our encoder and with the FFmpeg CLI, and checks # that both decoded outputs are equal + if get_ffmpeg_major_version() == 4 and output_format == "wav": + pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") + encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}" encoded_by_us = tmp_path / f"our_output.{output_format}" From ee3a199d1c93fe6b3a1a6385a2fe833a0c315476 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Apr 2025 13:40:16 +0100 Subject: [PATCH 11/14] Add assertion --- src/torchcodec/_core/Encoder.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 9db23411..62a61a6c 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -249,6 +249,14 @@ void AudioEncoder::encodeInnerLoop( avCodecContext_->sample_fmt, srcAVFrame->sample_rate, // No sample rate conversion srcAVFrame->sample_rate); + TORCH_CHECK( + convertedAVFrame->nb_samples == srcAVFrame->nb_samples, + "convertedAVFrame->nb_samples=", + convertedAVFrame->nb_samples, + " differs from ", + "srcAVFrame->nb_samples=", + srcAVFrame->nb_samples, + "This is unexpected, please report on the TorchCodec bug tracker."); } const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; From 485ee2e9c5f419b717c47364a6166dc0b433e4e8 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Apr 2025 13:49:04 +0100 Subject: [PATCH 12/14] Move comment --- src/torchcodec/_core/Encoder.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 62a61a6c..adac6edd 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -47,6 +47,8 @@ AudioEncoder::AudioEncoder( wf_.dtype() == torch::kFloat32, "waveform must have float32 dtype, got ", wf_.dtype()); + // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed + // planar (fltp). TORCH_CHECK( wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim()); @@ -97,9 +99,6 @@ AudioEncoder::AudioEncoder( // what the `.sample_fmt` defines. avCodecContext_->sample_fmt = findOutputSampleFormat(*avCodec); - // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed - // planar (fltp). - int numChannels = static_cast(wf_.sizes()[0]); TORCH_CHECK( // TODO-ENCODING is this even true / needed? We can probably support more From 27fdbac28b82df94f720add5fcbb42fc7adab280 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Apr 2025 14:40:57 +0100 Subject: [PATCH 13/14] Better default heuristic --- src/torchcodec/_core/Encoder.cpp | 35 ++++++++++++++++++++++---------- test/test_ops.py | 10 ++++----- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index adac6edd..a6c0a52a 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -133,25 +133,38 @@ AudioEncoder::AudioEncoder( } AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) { - // Find a sample format that the encoder supports. If FLTP is supported then - // we use that, since this is the expected format of the input waveform. - // Otherwise, we'll need to convert the waveform before passing it to the - // encoder. Right now, the output format we'll choose is just the first format - // in the `sample_fmts` list that the AVCodec defines. Eventually, we may - // allow the user to choose. - // TODO-ENCODING: a better default would probably be to choose the highest - // available precision + // Find a sample format that the encoder supports. We prefer using FLT[P], + // since this is the format of the input waveform. If FLTP isn't supported + // then we'll need to convert the AVFrame's format. Our heuristic is to encode + // into the format with the highest resolution. if (avCodec.sample_fmts == nullptr) { // Can't really validate anything in this case, best we can do is hope that // FLTP is supported by the encoder. If not, FFmpeg will raise. return AV_SAMPLE_FMT_FLTP; } - for (auto i = 0; avCodec.sample_fmts[i] != -1; ++i) { - if (avCodec.sample_fmts[i] == AV_SAMPLE_FMT_FLTP) { - return AV_SAMPLE_FMT_FLTP; + std::vector preferredFormatsOrder = { + AV_SAMPLE_FMT_FLTP, + AV_SAMPLE_FMT_FLT, + AV_SAMPLE_FMT_DBLP, + AV_SAMPLE_FMT_DBL, + AV_SAMPLE_FMT_S64P, + AV_SAMPLE_FMT_S64, + AV_SAMPLE_FMT_S32P, + AV_SAMPLE_FMT_S32, + AV_SAMPLE_FMT_S16P, + AV_SAMPLE_FMT_S16, + AV_SAMPLE_FMT_U8P, + AV_SAMPLE_FMT_U8}; + + for (AVSampleFormat preferredFormat : preferredFormatsOrder) { + for (auto i = 0; avCodec.sample_fmts[i] != -1; ++i) { + if (avCodec.sample_fmts[i] == preferredFormat) { + return preferredFormat; + } } } + // Should never happen, but just in case return avCodec.sample_fmts[0]; } diff --git a/test/test_ops.py b/test/test_ops.py index c7a2540d..781a8724 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1139,8 +1139,9 @@ def test_round_trip(self, output_format, tmp_path): ) encode_audio(encoder) + rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None) torch.testing.assert_close( - self.decode(encoded_path), source_samples, rtol=0, atol=1e-4 + self.decode(encoded_path), source_samples, rtol=rtol, atol=atol ) @pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI") @@ -1157,8 +1158,6 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path): encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}" encoded_by_us = tmp_path / f"our_output.{output_format}" - # Note: output format may be different from ours, e.g. FFmpeg CLI would - # choose s32 while our current heuristic may choose s16. subprocess.run( ["ffmpeg", "-i", str(asset.path)] + (["-b:a", f"{bit_rate}"] if bit_rate is not None else []) @@ -1177,11 +1176,12 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path): ) encode_audio(encoder) + rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None) torch.testing.assert_close( self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us), - rtol=0, - atol=1e-4, + rtol=rtol, + atol=atol, ) From 67dba5ac203ec1f0633800dda4306899d8b70c00 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 14 Apr 2025 13:26:34 +0100 Subject: [PATCH 14/14] Address comments --- src/torchcodec/_core/Encoder.cpp | 76 ++++++++++++++++---------------- src/torchcodec/_core/Encoder.h | 1 - 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index bcae2017..08af3402 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -33,6 +33,44 @@ void validateSampleRate(const AVCodec& avCodec, int sampleRate) { supportedRates.str()); } +static const std::vector preferredFormatsOrder = { + AV_SAMPLE_FMT_FLTP, + AV_SAMPLE_FMT_FLT, + AV_SAMPLE_FMT_DBLP, + AV_SAMPLE_FMT_DBL, + AV_SAMPLE_FMT_S64P, + AV_SAMPLE_FMT_S64, + AV_SAMPLE_FMT_S32P, + AV_SAMPLE_FMT_S32, + AV_SAMPLE_FMT_S16P, + AV_SAMPLE_FMT_S16, + AV_SAMPLE_FMT_U8P, + AV_SAMPLE_FMT_U8}; + +AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) { + // Find a sample format that the encoder supports. We prefer using FLT[P], + // since this is the format of the input waveform. If FLTP isn't supported + // then we'll need to convert the AVFrame's format. Our heuristic is to encode + // into the format with the highest resolution. + if (avCodec.sample_fmts == nullptr) { + // Can't really validate anything in this case, best we can do is hope that + // FLTP is supported by the encoder. If not, FFmpeg will raise. + return AV_SAMPLE_FMT_FLTP; + } + + for (AVSampleFormat preferredFormat : preferredFormatsOrder) { + for (int i = 0; avCodec.sample_fmts[i] != -1; ++i) { + if (avCodec.sample_fmts[i] == preferredFormat) { + return preferredFormat; + } + } + } + // We should always find a match in preferredFormatsOrder, so we should always + // return earlier. But in the event that a future FFmpeg version defines an + // additional sample format that isn't in preferredFormatsOrder, we fallback: + return avCodec.sample_fmts[0]; +} + } // namespace AudioEncoder::~AudioEncoder() {} @@ -97,7 +135,7 @@ AudioEncoder::AudioEncoder( // Input waveform is expected to be FLTP. Not all encoders support FLTP, so we // may need to convert the wf into a supported output sample format, which is // what the `.sample_fmt` defines. - avCodecContext_->sample_fmt = findOutputSampleFormat(*avCodec); + avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec); int numChannels = static_cast(wf_.sizes()[0]); TORCH_CHECK( @@ -132,42 +170,6 @@ AudioEncoder::AudioEncoder( streamIndex_ = avStream->index; } -AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) { - // Find a sample format that the encoder supports. We prefer using FLT[P], - // since this is the format of the input waveform. If FLTP isn't supported - // then we'll need to convert the AVFrame's format. Our heuristic is to encode - // into the format with the highest resolution. - if (avCodec.sample_fmts == nullptr) { - // Can't really validate anything in this case, best we can do is hope that - // FLTP is supported by the encoder. If not, FFmpeg will raise. - return AV_SAMPLE_FMT_FLTP; - } - - std::vector preferredFormatsOrder = { - AV_SAMPLE_FMT_FLTP, - AV_SAMPLE_FMT_FLT, - AV_SAMPLE_FMT_DBLP, - AV_SAMPLE_FMT_DBL, - AV_SAMPLE_FMT_S64P, - AV_SAMPLE_FMT_S64, - AV_SAMPLE_FMT_S32P, - AV_SAMPLE_FMT_S32, - AV_SAMPLE_FMT_S16P, - AV_SAMPLE_FMT_S16, - AV_SAMPLE_FMT_U8P, - AV_SAMPLE_FMT_U8}; - - for (AVSampleFormat preferredFormat : preferredFormatsOrder) { - for (auto i = 0; avCodec.sample_fmts[i] != -1; ++i) { - if (avCodec.sample_fmts[i] == preferredFormat) { - return preferredFormat; - } - } - } - // Should never happen, but just in case - return avCodec.sample_fmts[0]; -} - void AudioEncoder::encode() { UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 7c94ce2a..3e1abeac 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -28,7 +28,6 @@ class AudioEncoder { AutoAVPacket& autoAVPacket, const UniqueAVFrame& srcAVFrame); void flushBuffers(); - AVSampleFormat findOutputSampleFormat(const AVCodec& avCodec); UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_;