Skip to content

Commit 76900f6

Browse files
Add more cuda tests (#326)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
1 parent aa9b872 commit 76900f6

File tree

3 files changed

+156
-120
lines changed

3 files changed

+156
-120
lines changed

test/conftest.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,52 @@
1+
import os
12
import random
23

34
import pytest
45
import torch
56

67

8+
def pytest_configure(config):
9+
# register an additional marker (see pytest_collection_modifyitems)
10+
config.addinivalue_line(
11+
"markers", "needs_cuda: mark for tests that rely on a CUDA device"
12+
)
13+
14+
15+
def pytest_collection_modifyitems(items):
16+
# This hook is called by pytest after it has collected the tests (google its
17+
# name to check out its doc!). We can ignore some tests as we see fit here,
18+
# or add marks, such as a skip mark.
19+
20+
out_items = []
21+
for item in items:
22+
# The needs_cuda mark will exist if the test was explicitly decorated
23+
# with the @needs_cuda decorator. It will also exist if it was
24+
# parametrized with a parameter that has the mark: for example if a test
25+
# is parametrized with
26+
# @pytest.mark.parametrize('device', cpu_and_cuda())
27+
# the "instances" of the tests where device == 'cuda' will have the
28+
# 'needs_cuda' mark, and the ones with device == 'cpu' won't have the
29+
# mark.
30+
needs_cuda = item.get_closest_marker("needs_cuda") is not None
31+
32+
if (
33+
needs_cuda
34+
and not torch.cuda.is_available()
35+
and os.environ.get("FAIL_WITHOUT_CUDA") is None
36+
):
37+
# We skip CUDA tests on non-CUDA machines, but only if the
38+
# FAIL_WITHOUT_CUDA env var wasn't set. If it's set, the test will
39+
# typically fail with a "Unsupported device: cuda" error. This is
40+
# normal and desirable: this env var is set on CI jobs that are
41+
# supposed to run the CUDA tests, so if CUDA isn't available on
42+
# those for whatever reason, we need to know.
43+
item.add_marker(pytest.mark.skip(reason="CUDA not available."))
44+
45+
out_items.append(item)
46+
47+
items[:] = out_items
48+
49+
750
@pytest.fixture(autouse=True)
851
def prevent_leaking_rng():
952
# Prevent each test from leaking the rng to all other test when they call

0 commit comments

Comments
 (0)