Skip to content

PYTHON-5071 Use one event loop for all asyncio tests #2086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 30, 2025
119 changes: 86 additions & 33 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import asyncio
import gc
import inspect
import logging
import multiprocessing
import os
Expand All @@ -30,28 +31,6 @@
import unittest
import warnings
from asyncio import iscoroutinefunction
from test.helpers import (
COMPRESSORS,
IS_SRV,
MONGODB_API_VERSION,
MULTI_MONGOS_LB_URI,
TEST_LOADBALANCER,
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
client_knobs,
db_pwd,
db_user,
global_knobs,
host,
is_server_resolvable,
port,
print_running_topology,
print_thread_stacks,
print_thread_tracebacks,
sanitize_cmd,
sanitize_reply,
)

from pymongo.uri_parser import parse_uri

Expand All @@ -63,7 +42,6 @@
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
Expand All @@ -78,6 +56,32 @@
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient

sys.path[0:0] = [""]

from test.helpers import (
COMPRESSORS,
IS_SRV,
MONGODB_API_VERSION,
MULTI_MONGOS_LB_URI,
TEST_LOADBALANCER,
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
client_knobs,
db_pwd,
db_user,
global_knobs,
host,
is_server_resolvable,
port,
print_running_topology,
print_thread_stacks,
print_thread_tracebacks,
sanitize_cmd,
sanitize_reply,
)
from test.version import Version

_IS_SYNC = True


Expand Down Expand Up @@ -863,18 +867,66 @@ def max_message_size_bytes(self):
# Reusable client context
client_context = ClientContext()

# Global event loop for async tests.
LOOP = None

def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif client_context.client is not None:
client_context.client.close()
client_context.client = None
client_context._init_client()

def get_loop() -> asyncio.AbstractEventLoop:
"""Get the test suite's global event loop."""
global LOOP
if LOOP is None:
try:
LOOP = asyncio.get_running_loop()
except RuntimeError:
# no running event loop, fallback to get_event_loop.
try:
# Ignore DeprecationWarning: There is no current event loop
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
LOOP = asyncio.get_event_loop()
except RuntimeError:
LOOP = asyncio.new_event_loop()
asyncio.set_event_loop(LOOP)
return LOOP


class PyMongoTestCase(unittest.TestCase):
if not _IS_SYNC:
# An async TestCase that uses a single event loop for all tests.
# Inspired by TestCase.
def setUp(self):
pass

def tearDown(self):
pass

def addCleanup(self, func, /, *args, **kwargs):
self.addCleanup(*(func, *args), **kwargs)

def _callSetUp(self):
self.setUp()
self._callAsync(self.setUp)

def _callTestMethod(self, method):
self._callMaybeAsync(method)

def _callTearDown(self):
self._callAsync(self.tearDown)
self.tearDown()

def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)

def _callAsync(self, func, /, *args, **kwargs):
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
return get_loop().run_until_complete(func(*args, **kwargs))

def _callMaybeAsync(self, func, /, *args, **kwargs):
if inspect.iscoroutinefunction(func):
return get_loop().run_until_complete(func(*args, **kwargs))
else:
return func(*args, **kwargs)

def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)

Expand Down Expand Up @@ -1136,8 +1188,6 @@ class IntegrationTest(PyMongoTestCase):

@client_context.require_connection
def setUp(self) -> None:
if not _IS_SYNC:
reset_client_context()
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
Expand Down Expand Up @@ -1186,6 +1236,9 @@ def tearDown(self) -> None:


def setup():
if not _IS_SYNC:
# Set up the event loop.
get_loop()
client_context.init()
warnings.resetwarnings()
warnings.simplefilter("always")
Expand Down
121 changes: 87 additions & 34 deletions test/asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import asyncio
import gc
import inspect
import logging
import multiprocessing
import os
Expand All @@ -30,28 +31,6 @@
import unittest
import warnings
from asyncio import iscoroutinefunction
from test.helpers import (
COMPRESSORS,
IS_SRV,
MONGODB_API_VERSION,
MULTI_MONGOS_LB_URI,
TEST_LOADBALANCER,
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
client_knobs,
db_pwd,
db_user,
global_knobs,
host,
is_server_resolvable,
port,
print_running_topology,
print_thread_stacks,
print_thread_tracebacks,
sanitize_cmd,
sanitize_reply,
)

from pymongo.uri_parser import parse_uri

Expand All @@ -63,7 +42,6 @@
HAVE_IPADDRESS = False
from contextlib import asynccontextmanager, contextmanager
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
Expand All @@ -78,6 +56,32 @@
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]

sys.path[0:0] = [""]

from test.helpers import (
COMPRESSORS,
IS_SRV,
MONGODB_API_VERSION,
MULTI_MONGOS_LB_URI,
TEST_LOADBALANCER,
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
client_knobs,
db_pwd,
db_user,
global_knobs,
host,
is_server_resolvable,
port,
print_running_topology,
print_thread_stacks,
print_thread_tracebacks,
sanitize_cmd,
sanitize_reply,
)
from test.version import Version

_IS_SYNC = False


Expand Down Expand Up @@ -865,18 +869,66 @@ async def max_message_size_bytes(self):
# Reusable client context
async_client_context = AsyncClientContext()

# Global event loop for async tests.
LOOP = None


def get_loop() -> asyncio.AbstractEventLoop:
"""Get the test suite's global event loop."""
global LOOP
if LOOP is None:
try:
LOOP = asyncio.get_running_loop()
except RuntimeError:
# no running event loop, fallback to get_event_loop.
try:
# Ignore DeprecationWarning: There is no current event loop
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
LOOP = asyncio.get_event_loop()
except RuntimeError:
LOOP = asyncio.new_event_loop()
asyncio.set_event_loop(LOOP)
return LOOP


class AsyncPyMongoTestCase(unittest.TestCase):
if not _IS_SYNC:
# An async TestCase that uses a single event loop for all tests.
# Inspired by IsolatedAsyncioTestCase.
async def asyncSetUp(self):
pass

async def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif async_client_context.client is not None:
await async_client_context.client.close()
async_client_context.client = None
await async_client_context._init_client()
async def asyncTearDown(self):
pass

def addAsyncCleanup(self, func, /, *args, **kwargs):
self.addCleanup(*(func, *args), **kwargs)

def _callSetUp(self):
self.setUp()
self._callAsync(self.asyncSetUp)

def _callTestMethod(self, method):
self._callMaybeAsync(method)

def _callTearDown(self):
self._callAsync(self.asyncTearDown)
self.tearDown()

def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)

def _callAsync(self, func, /, *args, **kwargs):
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
return get_loop().run_until_complete(func(*args, **kwargs))

def _callMaybeAsync(self, func, /, *args, **kwargs):
if inspect.iscoroutinefunction(func):
return get_loop().run_until_complete(func(*args, **kwargs))
else:
return func(*args, **kwargs)

class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)

Expand Down Expand Up @@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):

@async_client_context.require_connection
async def asyncSetUp(self) -> None:
if not _IS_SYNC:
await reset_client_context()
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
Expand Down Expand Up @@ -1204,6 +1254,9 @@ async def asyncTearDown(self) -> None:


async def async_setup():
if not _IS_SYNC:
# Set up the event loop.
get_loop()
await async_client_context.init()
warnings.resetwarnings()
warnings.simplefilter("always")
Expand Down
Loading