Skip to content

Commit d717d6c

Browse files
authored
Ensure correct encoding for non-contiguous WF (#666)
1 parent 83f3e79 commit d717d6c

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

src/torchcodec/_core/Encoder.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@ torch::Tensor validateWf(torch::Tensor wf) {
1313
wf.dtype() == torch::kFloat32,
1414
"waveform must have float32 dtype, got ",
1515
wf.dtype());
16-
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
17-
// planar (fltp).
1816
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());
19-
return wf;
17+
return wf.contiguous();
2018
}
2119

2220
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {

test/test_ops.py

+33
Original file line numberDiff line numberDiff line change
@@ -1267,6 +1267,39 @@ def test_encode_to_tensor_long_output(self):
12671267

12681268
torch.testing.assert_close(self.decode(encoded_tensor), samples)
12691269

1270+
def test_contiguity(self):
1271+
# Ensure that 2 waveforms with the same values are encoded in the same
1272+
# way, regardless of their memory layout. Here we encode 2 equal
1273+
# waveforms, one is row-aligned while the other is column-aligned.
1274+
1275+
num_samples = 10_000 # per channel
1276+
contiguous_samples = torch.rand(2, num_samples).contiguous()
1277+
assert contiguous_samples.stride() == (num_samples, 1)
1278+
1279+
encoded_from_contiguous = encode_audio_to_tensor(
1280+
wf=contiguous_samples,
1281+
sample_rate=16_000,
1282+
format="flac",
1283+
bit_rate=44_000,
1284+
)
1285+
non_contiguous_samples = contiguous_samples.T.contiguous().T
1286+
assert non_contiguous_samples.stride() == (1, 2)
1287+
1288+
torch.testing.assert_close(
1289+
contiguous_samples, non_contiguous_samples, rtol=0, atol=0
1290+
)
1291+
1292+
encoded_from_non_contiguous = encode_audio_to_tensor(
1293+
wf=non_contiguous_samples,
1294+
sample_rate=16_000,
1295+
format="flac",
1296+
bit_rate=44_000,
1297+
)
1298+
1299+
torch.testing.assert_close(
1300+
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
1301+
)
1302+
12701303

12711304
if __name__ == "__main__":
12721305
pytest.main()

0 commit comments

Comments
 (0)