@@ -724,7 +724,9 @@ def _assert_correctness_binary(
724
724
x1a , x2a = in_arrs
725
725
ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype , expected = expected_dtype )
726
726
ph .assert_result_shape (name , in_shapes = in_shapes , out_shape = out .shape )
727
- binary_assert_against_refimpl (name , x1a , x2a , out , func , ** kwargs )
727
+ check_values = kwargs .pop ('check_values' , None )
728
+ if check_values :
729
+ binary_assert_against_refimpl (name , x1a , x2a , out , func , ** kwargs )
728
730
729
731
730
732
@pytest .mark .parametrize ("ctx" , make_unary_params ("abs" , dh .numeric_dtypes ))
@@ -1845,6 +1847,7 @@ def _filter_zero(x):
1845
1847
("less_equal" , operator .le , {}, xp .bool ),
1846
1848
("greater" , operator .gt , {}, xp .bool ),
1847
1849
("greater_equal" , operator .ge , {}, xp .bool ),
1850
+ ("pow" , operator .pow , {'check_values' : False }, None ) # value tests are too finicky for pow
1848
1851
],
1849
1852
ids = lambda func_data : func_data [0 ] # use names for test IDs
1850
1853
)
@@ -1902,6 +1905,23 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
1902
1905
_check_binary_with_scalars ((func_name , refimpl_ , kwargs , expected ), x1x2 )
1903
1906
1904
1907
1908
+ @pytest .mark .min_version ("2024.12" )
1909
+ @pytest .mark .parametrize ('func_data' ,
1910
+ # func_name, refimpl, kwargs, expected_dtype
1911
+ [
1912
+ ("bitwise_left_shift" , operator .lshift , {}, None ),
1913
+ ("bitwise_right_shift" , operator .rshift , {}, None ),
1914
+ ],
1915
+ ids = lambda func_data : func_data [0 ] # use names for test IDs
1916
+ )
1917
+ @given (x1x2 = hh .array_and_py_scalar ([xp .int32 ], positive = True , mM = (1 , 3 )))
1918
+ def test_binary_with_scalars_bitwise_shifts (func_data , x1x2 ):
1919
+ func_name , refimpl , kwargs , expected = func_data
1920
+ # repack the refimpl
1921
+ refimpl_ = lambda l , r : mock_int_dtype (refimpl (l , r ), xp .int32 )
1922
+ _check_binary_with_scalars ((func_name , refimpl_ , kwargs , expected ), x1x2 )
1923
+
1924
+
1905
1925
@pytest .mark .unvectorized
1906
1926
@given (
1907
1927
x1x2 = hh .array_and_py_scalar ([xp .int32 ]),
@@ -1931,3 +1951,4 @@ def test_where_with_scalars(x1x2, data):
1931
1951
else :
1932
1952
assert out [idx ] == x2_arr [idx ]
1933
1953
1954
+
0 commit comments