Skip to content

Commit 27fdbac

Browse files
committed
Better default heuristic
1 parent 485ee2e commit 27fdbac

File tree

2 files changed

+29
-16
lines changed

2 files changed

+29
-16
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -133,25 +133,38 @@ AudioEncoder::AudioEncoder(
133133
}
134134

135135
AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) {
136-
// Find a sample format that the encoder supports. If FLTP is supported then
137-
// we use that, since this is the expected format of the input waveform.
138-
// Otherwise, we'll need to convert the waveform before passing it to the
139-
// encoder. Right now, the output format we'll choose is just the first format
140-
// in the `sample_fmts` list that the AVCodec defines. Eventually, we may
141-
// allow the user to choose.
142-
// TODO-ENCODING: a better default would probably be to choose the highest
143-
// available precision
136+
// Find a sample format that the encoder supports. We prefer using FLT[P],
137+
// since this is the format of the input waveform. If FLTP isn't supported
138+
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
139+
// into the supported format with the highest precision (bit depth).
144140
if (avCodec.sample_fmts == nullptr) {
145141
// Can't really validate anything in this case, best we can do is hope that
146142
// FLTP is supported by the encoder. If not, FFmpeg will raise.
147143
return AV_SAMPLE_FMT_FLTP;
148144
}
149145

150-
for (auto i = 0; avCodec.sample_fmts[i] != -1; ++i) {
151-
if (avCodec.sample_fmts[i] == AV_SAMPLE_FMT_FLTP) {
152-
return AV_SAMPLE_FMT_FLTP;
146+
std::vector<AVSampleFormat> preferredFormatsOrder = {
147+
AV_SAMPLE_FMT_FLTP,
148+
AV_SAMPLE_FMT_FLT,
149+
AV_SAMPLE_FMT_DBLP,
150+
AV_SAMPLE_FMT_DBL,
151+
AV_SAMPLE_FMT_S64P,
152+
AV_SAMPLE_FMT_S64,
153+
AV_SAMPLE_FMT_S32P,
154+
AV_SAMPLE_FMT_S32,
155+
AV_SAMPLE_FMT_S16P,
156+
AV_SAMPLE_FMT_S16,
157+
AV_SAMPLE_FMT_U8P,
158+
AV_SAMPLE_FMT_U8};
159+
160+
for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
161+
for (auto i = 0; avCodec.sample_fmts[i] != -1; ++i) {
162+
if (avCodec.sample_fmts[i] == preferredFormat) {
163+
return preferredFormat;
164+
}
153165
}
154166
}
167+
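// Should never happen (the list above is meant to cover every sample format), but just in case, fall back to the codec's first listed format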
155168
return avCodec.sample_fmts[0];
156169
}
157170

test/test_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,8 +1139,9 @@ def test_round_trip(self, output_format, tmp_path):
11391139
)
11401140
encode_audio(encoder)
11411141

1142+
rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
11421143
torch.testing.assert_close(
1143-
self.decode(encoded_path), source_samples, rtol=0, atol=1e-4
1144+
self.decode(encoded_path), source_samples, rtol=rtol, atol=atol
11441145
)
11451146

11461147
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
@@ -1157,8 +1158,6 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
11571158
encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}"
11581159
encoded_by_us = tmp_path / f"our_output.{output_format}"
11591160

1160-
# Note: output format may be different from ours, e.g. FFmpeg CLI would
1161-
# choose s32 while our current heuristic may choose s16.
11621161
subprocess.run(
11631162
["ffmpeg", "-i", str(asset.path)]
11641163
+ (["-b:a", f"{bit_rate}"] if bit_rate is not None else [])
@@ -1177,11 +1176,12 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
11771176
)
11781177
encode_audio(encoder)
11791178

1179+
rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
11801180
torch.testing.assert_close(
11811181
self.decode(encoded_by_ffmpeg),
11821182
self.decode(encoded_by_us),
1183-
rtol=0,
1184-
atol=1e-4,
1183+
rtol=rtol,
1184+
atol=atol,
11851185
)
11861186

11871187

0 commit comments

Comments
 (0)