|
| 1 | +import os |
1 | 2 | import random
|
2 | 3 |
|
3 | 4 | import pytest
|
4 | 5 | import torch
|
5 | 6 |
|
6 | 7 |
|
| 8 | +def pytest_configure(config): |
| 9 | + # register an additional marker (see pytest_collection_modifyitems) |
| 10 | + config.addinivalue_line( |
| 11 | + "markers", "needs_cuda: mark for tests that rely on a CUDA device" |
| 12 | + ) |
| 13 | + |
| 14 | + |
| 15 | +def pytest_collection_modifyitems(items): |
| 16 | + # This hook is called by pytest after it has collected the tests (google its |
| 17 | + # name to check out its doc!). We can ignore some tests as we see fit here, |
| 18 | + # or add marks, such as a skip mark. |
| 19 | + |
| 20 | + out_items = [] |
| 21 | + for item in items: |
| 22 | + # The needs_cuda mark will exist if the test was explicitly decorated |
| 23 | + # with the @needs_cuda decorator. It will also exist if it was |
| 24 | + # parametrized with a parameter that has the mark: for example if a test |
| 25 | + # is parametrized with |
| 26 | + # @pytest.mark.parametrize('device', cpu_and_cuda()) |
| 27 | + # the "instances" of the tests where device == 'cuda' will have the |
| 28 | + # 'needs_cuda' mark, and the ones with device == 'cpu' won't have the |
| 29 | + # mark. |
| 30 | + needs_cuda = item.get_closest_marker("needs_cuda") is not None |
| 31 | + |
| 32 | + if ( |
| 33 | + needs_cuda |
| 34 | + and not torch.cuda.is_available() |
| 35 | + and os.environ.get("FAIL_WITHOUT_CUDA") is None |
| 36 | + ): |
| 37 | + # We skip CUDA tests on non-CUDA machines, but only if the |
| 38 | + # FAIL_WITHOUT_CUDA env var wasn't set. If it's set, the test will |
| 39 | + # typically fail with a "Unsupported device: cuda" error. This is |
| 40 | + # normal and desirable: this env var is set on CI jobs that are |
| 41 | + # supposed to run the CUDA tests, so if CUDA isn't available on |
| 42 | + # those for whatever reason, we need to know. |
| 43 | + item.add_marker(pytest.mark.skip(reason="CUDA not available.")) |
| 44 | + |
| 45 | + out_items.append(item) |
| 46 | + |
| 47 | + items[:] = out_items |
| 48 | + |
| 49 | + |
7 | 50 | @pytest.fixture(autouse=True)
|
8 | 51 | def prevent_leaking_rng():
|
9 | 52 | # Prevent each test from leaking the rng to all other test when they call
|
|
0 commit comments