Skip to content

Commit a5734c1

Browse files
authored
Merge pull request #371 from ev-br/test_pow_with_scalars
test pow() and bitwise shifts with scalars
1 parent 188b1e9 commit a5734c1

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,12 @@ def scalars(draw, dtypes, finite=False, **kwds):
456456
dtypes should be one of the shared_* dtypes strategies.
457457
"""
458458
dtype = draw(dtypes)
459+
mM = kwds.pop('mM', None)
459460
if dh.is_int_dtype(dtype):
460-
m, M = dh.dtype_ranges[dtype]
461+
if mM is None:
462+
m, M = dh.dtype_ranges[dtype]
463+
else:
464+
m, M = mM
461465
return draw(integers(m, M))
462466
elif dtype == bool_dtype:
463467
return draw(booleans())
@@ -588,18 +592,20 @@ def two_mutual_arrays(
588592

589593

590594
@composite
591-
def array_and_py_scalar(draw, dtypes):
595+
def array_and_py_scalar(draw, dtypes, mM=None, positive=False):
592596
"""Draw a pair: (array, scalar) or (scalar, array)."""
593597
dtype = draw(sampled_from(dtypes))
594598

595-
scalar_var = draw(scalars(just(dtype), finite=True,
596-
**{'min_value': 1/ (2<<5), 'max_value': 2<<5}
597-
))
599+
scalar_var = draw(scalars(just(dtype), finite=True, mM=mM))
600+
if positive:
601+
assume (scalar_var > 0)
598602

599603
elements={}
600604
if dtype in dh.real_float_dtypes:
601605
elements = {'allow_nan': False, 'allow_infinity': False,
602606
'min_value': 1.0 / (2<<5), 'max_value': 2<<5}
607+
if positive:
608+
elements = {'min_value': 0}
603609
array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements))
604610

605611
if draw(booleans()):

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,9 @@ def _assert_correctness_binary(
724724
x1a, x2a = in_arrs
725725
ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype, expected=expected_dtype)
726726
ph.assert_result_shape(name, in_shapes=in_shapes, out_shape=out.shape)
727-
binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs)
727+
check_values = kwargs.pop('check_values', None)
728+
if check_values:
729+
binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs)
728730

729731

730732
@pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes))
@@ -1845,6 +1847,7 @@ def _filter_zero(x):
18451847
("less_equal", operator.le, {}, xp.bool),
18461848
("greater", operator.gt, {}, xp.bool),
18471849
("greater_equal", operator.ge, {}, xp.bool),
1850+
("pow", operator.pow, {'check_values': False}, None) # value tests are too finicky for pow
18481851
],
18491852
ids=lambda func_data: func_data[0] # use names for test IDs
18501853
)
@@ -1902,6 +1905,23 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
19021905
_check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
19031906

19041907

1908+
@pytest.mark.min_version("2024.12")
1909+
@pytest.mark.parametrize('func_data',
1910+
# func_name, refimpl, kwargs, expected_dtype
1911+
[
1912+
("bitwise_left_shift", operator.lshift, {}, None),
1913+
("bitwise_right_shift", operator.rshift, {}, None),
1914+
],
1915+
ids=lambda func_data: func_data[0] # use names for test IDs
1916+
)
1917+
@given(x1x2=hh.array_and_py_scalar([xp.int32], positive=True, mM=(1, 3)))
1918+
def test_binary_with_scalars_bitwise_shifts(func_data, x1x2):
1919+
func_name, refimpl, kwargs, expected = func_data
1920+
# repack the refimpl
1921+
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
1922+
_check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
1923+
1924+
19051925
@pytest.mark.unvectorized
19061926
@given(
19071927
x1x2=hh.array_and_py_scalar([xp.int32]),
@@ -1931,3 +1951,4 @@ def test_where_with_scalars(x1x2, data):
19311951
else:
19321952
assert out[idx] == x2_arr[idx]
19331953

1954+

0 commit comments

Comments
 (0)