diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index b56bdecc..84e6f34c 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -24,7 +24,12 @@ def float32(n: Union[int, float]) -> float: def _float_match_complex(complex_dtype): - return xp.float32 if complex_dtype == xp.complex64 else xp.float64 + if complex_dtype == xp.complex64: + return xp.float32 + elif complex_dtype == xp.complex128: + return xp.float64 + else: + return dh.default_float @given(