Skip to content

Commit ceeeb11

Browse files
committed
Use psnr to compare frames
This commit brings dependency from torcheval in tests. Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent b4619d7 commit ceeeb11

File tree

4 files changed

+23
-26
lines changed

4 files changed

+23
-26
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dev = [
2525
"numpy",
2626
"pytest",
2727
"pillow",
28+
"torcheval",
2829
]
2930

3031
[tool.usort]

test/test_ops.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,13 @@ def test_get_frame_at_pts(self, device):
105105
frame6, _, _ = get_frame_at_pts(decoder, 6.02)
106106
assert_frames_equal(frame6, reference_frame6.to(device))
107107
frame6, _, _ = get_frame_at_pts(decoder, 6.039366)
108-
assert_frames_equal(frame6, reference_frame6.to(device))
108+
prev_frame_psnr = assert_frames_equal(frame6, reference_frame6.to(device))
109109
# Note that this timestamp is exactly on a frame boundary, so it should
110110
# return the next frame since the right boundary of the interval is
111111
# open.
112112
next_frame, _, _ = get_frame_at_pts(decoder, 6.039367)
113-
if device == "cpu":
114-
# We can only compare exact equality on CPU.
115-
with pytest.raises(AssertionError):
116-
assert_frames_equal(next_frame, reference_frame6.to(device))
113+
with pytest.raises(AssertionError):
114+
assert_frames_equal(next_frame, reference_frame6.to(device), psnr=prev_frame_psnr)
117115

118116
@pytest.mark.parametrize("device", cpu_and_cuda())
119117
def test_get_frame_at_index(self, device):

test/test_samplers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,11 +250,11 @@ def test_sampling_range(
250250
cm = (
251251
contextlib.nullcontext()
252252
if assert_all_equal
253-
else pytest.raises(AssertionError, match="Tensor-likes are not")
253+
else pytest.raises(AssertionError, match="low psnr")
254254
)
255255
with cm:
256256
for clip in clips:
257-
assert_frames_equal(clip.data, clips[0].data)
257+
assert_frames_equal(clip.data, clips[0].data, psnr=float("inf"))
258258

259259

260260
@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
@@ -447,7 +447,7 @@ def test_random_sampler_randomness(sampler):
447447
# Call with a different seed, expect different results
448448
torch.manual_seed(1)
449449
clips_3 = sampler(decoder, num_clips=num_clips)
450-
with pytest.raises(AssertionError, match="Tensor-likes are not"):
450+
with pytest.raises(AssertionError, match="low psnr"):
451451
assert_frames_equal(clips_1[0].data, clips_3[0].data)
452452

453453
# Make sure we didn't alter the builtin Python RNG

test/utils.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,23 @@ def get_ffmpeg_major_version():
3030
return int(get_ffmpeg_library_versions()["ffmpeg_version"].split(".")[0])
3131

3232

33-
# For use with decoded data frames. On CPU Linux, we expect exact, bit-for-bit
34-
# equality. On CUDA Linux, we expect a small tolerance.
35-
# On other platforms (e.g. MacOS), we also allow a small tolerance. FFmpeg does
36-
# not guarantee bit-for-bit equality across systems and architectures, so we
37-
# also cannot. We currently use Linux on x86_64 as our reference system.
38-
def assert_frames_equal(*args, **kwargs):
39-
if sys.platform == "linux":
40-
if args[0].device.type == "cuda":
41-
atol = 2
42-
if get_ffmpeg_major_version() == 4:
43-
assert_tensor_close_on_at_least(
44-
args[0], args[1], percentage=95, atol=atol
45-
)
46-
else:
47-
torch.testing.assert_close(*args, **kwargs, atol=atol, rtol=0)
48-
else:
49-
torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0)
33+
# For use with decoded data frames. `psnr` sets the PSNR threshold when
34+
# frames are considered equal. `float("inf")` correspond to bit-to-bit
35+
# identical frames. Function returns calculated psnr value.
36+
def assert_frames_equal(input, other, psnr=40, msg=None):
37+
if torch.allclose(input, other, atol=0, rtol=0):
38+
return float("inf")
5039
else:
51-
torch.testing.assert_close(*args, **kwargs, atol=3, rtol=0)
40+
from torcheval.metrics import PeakSignalNoiseRatio
41+
42+
metric = PeakSignalNoiseRatio()
43+
metric.update(input, other)
44+
m = metric.compute()
45+
message = f"low psnr: {m} < {psnr}"
46+
if (msg):
47+
message += f" ({msg})"
48+
assert m >= psnr, message
49+
return m
5250

5351

5452
# Asserts that at least `percentage`% of the values are within the absolute tolerance.

0 commit comments

Comments
 (0)