diff --git a/test/__init__.py b/test/__init__.py index d3a63db2d5..b49eee99ac 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -17,6 +17,7 @@ import asyncio import gc +import inspect import logging import multiprocessing import os @@ -30,28 +31,6 @@ import unittest import warnings from asyncio import iscoroutinefunction -from test.helpers import ( - COMPRESSORS, - IS_SRV, - MONGODB_API_VERSION, - MULTI_MONGOS_LB_URI, - TEST_LOADBALANCER, - TEST_SERVERLESS, - TLS_OPTIONS, - SystemCertsPatcher, - client_knobs, - db_pwd, - db_user, - global_knobs, - host, - is_server_resolvable, - port, - print_running_topology, - print_thread_stacks, - print_thread_tracebacks, - sanitize_cmd, - sanitize_reply, -) from pymongo.uri_parser import parse_uri @@ -63,7 +42,6 @@ HAVE_IPADDRESS = False from contextlib import contextmanager from functools import partial, wraps -from test.version import Version from typing import Any, Callable, Dict, Generator, overload from unittest import SkipTest from urllib.parse import quote_plus @@ -78,6 +56,32 @@ from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient +sys.path[0:0] = [""] + +from test.helpers import ( + COMPRESSORS, + IS_SRV, + MONGODB_API_VERSION, + MULTI_MONGOS_LB_URI, + TEST_LOADBALANCER, + TEST_SERVERLESS, + TLS_OPTIONS, + SystemCertsPatcher, + client_knobs, + db_pwd, + db_user, + global_knobs, + host, + is_server_resolvable, + port, + print_running_topology, + print_thread_stacks, + print_thread_tracebacks, + sanitize_cmd, + sanitize_reply, +) +from test.version import Version + _IS_SYNC = True @@ -863,18 +867,66 @@ def max_message_size_bytes(self): # Reusable client context client_context = ClientContext() +# Global event loop for async tests. +LOOP = None -def reset_client_context(): - if _IS_SYNC: - # sync tests don't need to reset a client context - return - elif client_context.client is not None: - client_context.client.close() - client_context.client = None - client_context._init_client() + +def get_loop() -> asyncio.AbstractEventLoop: + """Get the test suite's global event loop.""" + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP class PyMongoTestCase(unittest.TestCase): + if not _IS_SYNC: + # An async TestCase that uses a single event loop for all tests. + # Inspired by TestCase. + def setUp(self): + pass + + def tearDown(self): + pass + + def addCleanup(self, func, /, *args, **kwargs): + self.addCleanup(*(func, *args), **kwargs) + + def _callSetUp(self): + self.setUp() + self._callAsync(self.setUp) + + def _callTestMethod(self, method): + self._callMaybeAsync(method) + + def _callTearDown(self): + self._callAsync(self.tearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" + return get_loop().run_until_complete(func(*args, **kwargs)) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return get_loop().run_until_complete(func(*args, **kwargs)) + else: + return func(*args, **kwargs) + def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1136,8 +1188,6 @@ class IntegrationTest(PyMongoTestCase): @client_context.require_connection def setUp(self) -> None: - if not _IS_SYNC: - reset_client_context() if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): @@ -1186,6 +1236,9 @@ def tearDown(self) -> None: def setup(): + if not _IS_SYNC: + # Set up the event loop. + get_loop() client_context.init() warnings.resetwarnings() warnings.simplefilter("always") diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 73e2824742..76fae407da 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -17,6 +17,7 @@ import asyncio import gc +import inspect import logging import multiprocessing import os @@ -30,28 +31,6 @@ import unittest import warnings from asyncio import iscoroutinefunction -from test.helpers import ( - COMPRESSORS, - IS_SRV, - MONGODB_API_VERSION, - MULTI_MONGOS_LB_URI, - TEST_LOADBALANCER, - TEST_SERVERLESS, - TLS_OPTIONS, - SystemCertsPatcher, - client_knobs, - db_pwd, - db_user, - global_knobs, - host, - is_server_resolvable, - port, - print_running_topology, - print_thread_stacks, - print_thread_tracebacks, - sanitize_cmd, - sanitize_reply, -) from pymongo.uri_parser import parse_uri @@ -63,7 +42,6 @@ HAVE_IPADDRESS = False from contextlib import asynccontextmanager, contextmanager from functools import partial, wraps -from test.version import Version from typing import Any, Callable, Dict, Generator, overload from unittest import SkipTest from urllib.parse import quote_plus @@ -78,6 +56,32 @@ from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] +sys.path[0:0] = [""] + +from test.helpers import ( + COMPRESSORS, + IS_SRV, + MONGODB_API_VERSION, + MULTI_MONGOS_LB_URI, + TEST_LOADBALANCER, + TEST_SERVERLESS, + TLS_OPTIONS, + SystemCertsPatcher, + client_knobs, + db_pwd, + db_user, + global_knobs, + host, + is_server_resolvable, + port, + print_running_topology, + print_thread_stacks, + print_thread_tracebacks, + sanitize_cmd, + sanitize_reply, +) +from test.version import Version + _IS_SYNC = False @@ -865,18 +869,66 @@ async def max_message_size_bytes(self): # Reusable client context async_client_context = AsyncClientContext() +# Global event loop for async tests. +LOOP = None + + +def get_loop() -> asyncio.AbstractEventLoop: + """Get the test suite's global event loop.""" + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP + + +class AsyncPyMongoTestCase(unittest.TestCase): + if not _IS_SYNC: + # An async TestCase that uses a single event loop for all tests. + # Inspired by IsolatedAsyncioTestCase. + async def asyncSetUp(self): + pass -async def reset_client_context(): - if _IS_SYNC: - # sync tests don't need to reset a client context - return - elif async_client_context.client is not None: - await async_client_context.client.close() - async_client_context.client = None - await async_client_context._init_client() + async def asyncTearDown(self): + pass + def addAsyncCleanup(self, func, /, *args, **kwargs): + self.addCleanup(*(func, *args), **kwargs) + + def _callSetUp(self): + self.setUp() + self._callAsync(self.asyncSetUp) + + def _callTestMethod(self, method): + self._callMaybeAsync(method) + + def _callTearDown(self): + self._callAsync(self.asyncTearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" + return get_loop().run_until_complete(func(*args, **kwargs)) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return get_loop().run_until_complete(func(*args, **kwargs)) + else: + return func(*args, **kwargs) -class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): @async_client_context.require_connection async def asyncSetUp(self) -> None: - if not _IS_SYNC: - await reset_client_context() if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): @@ -1204,6 +1254,9 @@ async def asyncTearDown(self) -> None: async def async_setup(): + if not _IS_SYNC: + # Set up the event loop. + get_loop() await async_client_context.init() warnings.resetwarnings() warnings.simplefilter("always")