Skip to content

Commit 3c273cd

Browse files
authored
Merge pull request #373 from ev-br/skip_xfails
ENH: add ARRAY_API_TESTS_XFAIL_MARK to turn xfails into skips
2 parents c716de1 + 0da7010 commit 3c273cd

File tree

4 files changed

+41
-4
lines changed

4 files changed

+41
-4
lines changed

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,26 @@ ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 pytest array_api_tests/
305305
Note that skipping certain essential dtypes such as `bool` and the default
306306
floating-point dtype is not supported.
307307
308+
#### Turning xfails into skips
309+
310+
Keeping a large number of ``xfails`` can have drastic effects on the run time. This is due
311+
to the way `hypothesis` works: when it detects a failure, it does a large amount
312+
of work to simplify the failing example.
313+
If the run time of the test suite becomes a problem, you can use the
314+
``ARRAY_API_TESTS_XFAIL_MARK`` environment variable: setting it to ``skip`` skips the
315+
entries from the ``xfail.txt`` file instead of xfailing them. Anecdotally, we saw
316+
speed-ups by a factor of 4-5---which allowed us to use 4-5 larger values of
317+
``--max-examples`` within the same time budget.
318+
319+
#### Limiting the array sizes
320+
321+
The test suite generates random arrays as inputs to functions it tests. "unvectorized"
322+
tests iterate over elements of arrays, which might be slow. If the run time becomes
323+
a problem, you can limit the maximum number of elements in generated arrays by
324+
setting the environment variable ``ARRAY_API_TESTS_MAX_ARRAY_SIZE`` to the
325+
desired value. By default, it is set to 1024.
326+
327+
308328
## Contributing
309329
310330
### Remain in-scope

array_api_tests/hypothesis_helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
import re
45
from contextlib import contextmanager
56
from functools import wraps
@@ -232,7 +233,7 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
232233
lambda i: getattr(xp, i))
233234

234235
# Limit the total size of an array shape
235-
MAX_ARRAY_SIZE = 10000
236+
MAX_ARRAY_SIZE = int(os.environ.get("ARRAY_API_TESTS_MAX_ARRAY_SIZE", 1024))
236237
# Size to use for 2-dim arrays
237238
SQRT_MAX_ARRAY_SIZE = int(math.sqrt(MAX_ARRAY_SIZE))
238239

array_api_tests/test_creation_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def test_meshgrid(dtype, data):
499499
shapes = data.draw(
500500
st.integers(1, 5).flatmap(
501501
lambda n: hh.mutually_broadcastable_shapes(
502-
n, min_dims=1, max_dims=1, max_side=5
502+
n, min_dims=1, max_dims=1, max_side=4
503503
)
504504
),
505505
label="shapes",
@@ -509,7 +509,7 @@ def test_meshgrid(dtype, data):
509509
x = data.draw(hh.arrays(dtype=dtype, shape=shape), label=f"x{i}")
510510
arrays.append(x)
511511
# sanity check
512-
assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE
512+
# assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE
513513
out = xp.meshgrid(*arrays)
514514
for i, x in enumerate(out):
515515
ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype")

conftest.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,20 @@ def check_id_match(id_, pattern):
144144
return False
145145

146146

147+
def get_xfail_mark():
148+
"""Skip or xfail tests from the xfails-file.txt."""
149+
m = os.environ.get("ARRAY_API_TESTS_XFAIL_MARK", "xfail")
150+
if m == "xfail":
151+
return mark.xfail
152+
elif m == "skip":
153+
return mark.skip
154+
else:
155+
raise ValueError(
156+
f'ARRAY_API_TESTS_XFAIL_MARK value should be one of "skip" or "xfail" '
157+
f'got {m} instead.'
158+
)
159+
160+
147161
def pytest_collection_modifyitems(config, items):
148162
# 1. Prepare for iterating over items
149163
# -----------------------------------
@@ -187,6 +201,8 @@ def pytest_collection_modifyitems(config, items):
187201
# 2. Iterate through items and apply markers accordingly
188202
# ------------------------------------------------------
189203

204+
xfail_mark = get_xfail_mark()
205+
190206
for item in items:
191207
markers = list(item.iter_markers())
192208
# skip if specified in skips file
@@ -198,7 +214,7 @@ def pytest_collection_modifyitems(config, items):
198214
# xfail if specified in xfails file
199215
for id_ in xfail_ids:
200216
if check_id_match(item.nodeid, id_):
201-
item.add_marker(mark.xfail(reason=f"--xfails-file ({xfails_file})"))
217+
item.add_marker(xfail_mark(reason=f"--xfails-file ({xfails_file})"))
202218
xfail_id_matched[id_] = True
203219
break
204220
# skip if disabled or non-existent extension

0 commit comments

Comments
 (0)