|
| 1 | +#include <sstream> |
| 2 | + |
1 | 3 | #include "src/torchcodec/_core/Encoder.h"
|
2 | 4 | #include "torch/types.h"
|
3 | 5 |
|
4 | 6 | namespace facebook::torchcodec {
|
5 | 7 |
|
| 8 | +namespace { |
| 9 | + |
| 10 | +void validateSampleRate(const AVCodec& avCodec, int sampleRate) { |
| 11 | + if (avCodec.supported_samplerates == nullptr) { |
| 12 | + return; |
| 13 | + } |
| 14 | + |
| 15 | + for (auto i = 0; avCodec.supported_samplerates[i] != 0; ++i) { |
| 16 | + if (sampleRate == avCodec.supported_samplerates[i]) { |
| 17 | + return; |
| 18 | + } |
| 19 | + } |
| 20 | + std::stringstream supportedRates; |
| 21 | + for (auto i = 0; avCodec.supported_samplerates[i] != 0; ++i) { |
| 22 | + if (i > 0) { |
| 23 | + supportedRates << ", "; |
| 24 | + } |
| 25 | + supportedRates << avCodec.supported_samplerates[i]; |
| 26 | + } |
| 27 | + |
| 28 | + TORCH_CHECK( |
| 29 | + false, |
| 30 | + "invalid sample rate=", |
| 31 | + sampleRate, |
| 32 | + ". Supported sample rate values are: ", |
| 33 | + supportedRates.str()); |
| 34 | +} |
| 35 | + |
| 36 | +} // namespace |
| 37 | + |
6 | 38 | AudioEncoder::~AudioEncoder() {}
|
7 | 39 |
|
8 | 40 | AudioEncoder::AudioEncoder(
|
9 | 41 | const torch::Tensor wf,
|
10 | 42 | int sampleRate,
|
11 | 43 | std::string_view fileName,
|
12 | 44 | std::optional<int64_t> bit_rate)
|
13 |
| - : wf_(wf), sampleRate_(sampleRate) { |
| 45 | + : wf_(wf) { |
14 | 46 | TORCH_CHECK(
|
15 | 47 | wf_.dtype() == torch::kFloat32,
|
16 | 48 | "waveform must have float32 dtype, got ",
|
@@ -57,7 +89,8 @@ AudioEncoder::AudioEncoder(
|
57 | 89 | // well when "-b:a" isn't specified.
|
58 | 90 | avCodecContext_->bit_rate = bit_rate.value_or(0);
|
59 | 91 |
|
60 |
| - avCodecContext_->sample_rate = sampleRate_; |
| 92 | + validateSampleRate(*avCodec, sampleRate); |
| 93 | + avCodecContext_->sample_rate = sampleRate; |
61 | 94 |
|
62 | 95 | // Note: This is the format of the **input** waveform. This doesn't determine
|
63 | 96 | // the output.
|
|
0 commit comments