Skip to content

Commit 2e33674

Browse files
committed
ENH: more binary functions with arrays and python scalars
- equal, not_equal, greater, greater_equal, less, less_equal - add, subtract, multiply, divide
1 parent 3650a6a commit 2e33674

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from hypothesis import assume, reject
1111
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
12-
integers, just, lists, none, one_of,
12+
integers, complex_numbers, just, lists, none, one_of,
1313
sampled_from, shared, builds, nothing)
1414

1515
from . import _array_module as xp, api_version
@@ -19,7 +19,7 @@
1919
from . import xps
2020
from ._array_module import _UndefinedStub
2121
from ._array_module import bool as bool_dtype
22-
from ._array_module import broadcast_to, eye, float32, float64, full
22+
from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128
2323
from .stubs import category_to_funcs
2424
from .pytest_helpers import nargs
2525
from .typing import Array, DataType, Scalar, Shape
@@ -462,6 +462,14 @@ def scalars(draw, dtypes, finite=False):
462462
if finite:
463463
return draw(floats(width=32, allow_nan=False, allow_infinity=False))
464464
return draw(floats(width=32))
465+
elif dtype == complex64:
466+
if finite:
467+
return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False))
468+
return draw(complex_numbers(width=32))
469+
elif dtype == complex128:
470+
if finite:
471+
return draw(complex_numbers(allow_nan=False, allow_infinity=False))
472+
return draw(complex_numbers())
465473
else:
466474
raise ValueError(f"Unrecognized dtype {dtype}")
467475

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,6 +1794,8 @@ def _check_binary_with_scalars(func_data, x1x2):
17941794
# xp_func, name, refimpl, kwargs
17951795
[
17961796
(xp.atan2, "atan2", math.atan2, {}),
1797+
(xp.copysign, "copysign", math.copysign, {}),
1798+
(xp.divide, "divide", operator.truediv, {"filter_": lambda s: s != 0}),
17971799
(xp.hypot, "hypot", math.hypot, {}),
17981800
(xp.logaddexp, "logaddexp", logaddexp_refimpl, {}),
17991801
(xp.maximum, "maximum", max, {'strict_check': True}),
@@ -1813,9 +1815,33 @@ def test_binary_with_scalars_real(func_data, x1x2):
18131815
(xp.logical_and, "logical_and", operator.and_, {"expr_template": "({} or {})={}"}),
18141816
(xp.logical_or, "logical_or", operator.or_, {"expr_template": "({} or {})={}"}),
18151817
(xp.logical_xor, "logical_xor", operator.xor, {"expr_template": "({} or {})={}"}),
1818+
(xp.equal, "equal", operator.eq, {}),
1819+
(xp.not_equal, "neq", operator.ne, {}),
1820+
(xp.less, "less", operator.lt, {}),
1821+
(xp.less_equal, "les_equal", operator.le, {}),
1822+
(xp.greater, "greater", operator.gt, {}),
1823+
(xp.greater_equal, "greater_equal", operator.ge, {}),
18161824
],
18171825
ids=lambda func_data: func_data[1] # use names for test IDs
18181826
)
18191827
@given(x1x2=hh.array_and_py_scalar([xp.bool]))
18201828
def test_binary_with_scalars_bool(func_data, x1x2):
18211829
_check_binary_with_scalars(func_data, x1x2)
1830+
1831+
1832+
1833+
@pytest.mark.min_version("2024.12")
1834+
@pytest.mark.parametrize('func_data',
1835+
# xp_func, name, refimpl, kwargs
1836+
[
1837+
(xp.add, "add", operator.add, {}),
1838+
(xp.subtract, "sub", operator.sub, {}),
1839+
(xp.multiply, "mul", operator.mul, {}),
1840+
# divide is in the "real" listing to avoid int/int -> float
1841+
],
1842+
ids=lambda func_data: func_data[1] # use names for test IDs
1843+
)
1844+
@given(x1x2=hh.array_and_py_scalar(dh.numeric_dtypes))
1845+
def test_binary_with_scalars_numeric(func_data, x1x2):
1846+
_check_binary_with_scalars(func_data, x1x2)
1847+

0 commit comments

Comments
 (0)