Skip to content

Add trio implementation #1628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/topics/keepalive.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ measured during the last exchange of Ping and Pong frames::
Alternatively, you can measure the latency at any time by calling
:attr:`~asyncio.connection.Connection.ping` and awaiting its result::

pong_waiter = await websocket.ping()
latency = await pong_waiter
pong_received = await websocket.ping()
latency = await pong_received

Latency between a client and a server may increase for two reasons:

Expand Down
70 changes: 35 additions & 35 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def __init__(
self.close_deadline: float | None = None

# Protect sending fragmented messages.
self.fragmented_send_waiter: asyncio.Future[None] | None = None
self.send_in_progress: asyncio.Future[None] | None = None

# Mapping of ping IDs to pong waiters, in chronological order.
self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {}
self.pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {}

self.latency: float = 0
"""
Expand Down Expand Up @@ -468,8 +468,8 @@ async def send(
"""
# While sending a fragmented message, prevent sending other messages
# until all fragments are sent.
while self.fragmented_send_waiter is not None:
await asyncio.shield(self.fragmented_send_waiter)
while self.send_in_progress is not None:
await asyncio.shield(self.send_in_progress)

# Unfragmented message -- this case must be handled first because
# strings and bytes-like objects are iterable.
Expand Down Expand Up @@ -502,8 +502,8 @@ async def send(
except StopIteration:
return

assert self.fragmented_send_waiter is None
self.fragmented_send_waiter = self.loop.create_future()
assert self.send_in_progress is None
self.send_in_progress = self.loop.create_future()
try:
# First fragment.
if isinstance(chunk, str):
Expand Down Expand Up @@ -549,8 +549,8 @@ async def send(
raise

finally:
self.fragmented_send_waiter.set_result(None)
self.fragmented_send_waiter = None
self.send_in_progress.set_result(None)
self.send_in_progress = None

# Fragmented message -- async iterator.

Expand All @@ -561,8 +561,8 @@ async def send(
except StopAsyncIteration:
return

assert self.fragmented_send_waiter is None
self.fragmented_send_waiter = self.loop.create_future()
assert self.send_in_progress is None
self.send_in_progress = self.loop.create_future()
try:
# First fragment.
if isinstance(chunk, str):
Expand Down Expand Up @@ -610,8 +610,8 @@ async def send(
raise

finally:
self.fragmented_send_waiter.set_result(None)
self.fragmented_send_waiter = None
self.send_in_progress.set_result(None)
self.send_in_progress = None

else:
raise TypeError("data must be str, bytes, iterable, or async iterable")
Expand All @@ -635,7 +635,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None:
# The context manager takes care of waiting for the TCP connection
# to terminate after calling a method that sends a close frame.
async with self.send_context():
if self.fragmented_send_waiter is not None:
if self.send_in_progress is not None:
self.protocol.fail(
CloseCode.INTERNAL_ERROR,
"close during fragmented message",
Expand Down Expand Up @@ -677,9 +677,9 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]:

::

pong_waiter = await ws.ping()
pong_received = await ws.ping()
# only if you want to wait for the corresponding pong
latency = await pong_waiter
latency = await pong_received

Raises:
ConnectionClosed: When the connection is closed.
Expand All @@ -696,19 +696,19 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]:

async with self.send_context():
# Protect against duplicates if a payload is explicitly set.
if data in self.pong_waiters:
if data in self.pending_pings:
raise ConcurrencyError("already waiting for a pong with the same data")

# Generate a unique random payload otherwise.
while data is None or data in self.pong_waiters:
while data is None or data in self.pending_pings:
data = struct.pack("!I", random.getrandbits(32))

pong_waiter = self.loop.create_future()
pong_received = self.loop.create_future()
# The event loop's default clock is time.monotonic(). Its resolution
# is a bit low on Windows (~16ms). This is improved in Python 3.13.
self.pong_waiters[data] = (pong_waiter, self.loop.time())
self.pending_pings[data] = (pong_received, self.loop.time())
self.protocol.send_ping(data)
return pong_waiter
return pong_received

async def pong(self, data: Data = b"") -> None:
"""
Expand Down Expand Up @@ -757,7 +757,7 @@ def acknowledge_pings(self, data: bytes) -> None:

"""
# Ignore unsolicited pong.
if data not in self.pong_waiters:
if data not in self.pending_pings:
return

pong_timestamp = self.loop.time()
Expand All @@ -766,20 +766,20 @@ def acknowledge_pings(self, data: bytes) -> None:
# Acknowledge all previous pings too in that case.
ping_id = None
ping_ids = []
for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items():
for ping_id, (pong_received, ping_timestamp) in self.pending_pings.items():
ping_ids.append(ping_id)
latency = pong_timestamp - ping_timestamp
if not pong_waiter.done():
pong_waiter.set_result(latency)
if not pong_received.done():
pong_received.set_result(latency)
if ping_id == data:
self.latency = latency
break
else:
raise AssertionError("solicited pong not found in pings")

# Remove acknowledged pings from self.pong_waiters.
# Remove acknowledged pings from self.pending_pings.
for ping_id in ping_ids:
del self.pong_waiters[ping_id]
del self.pending_pings[ping_id]

def abort_pings(self) -> None:
"""
Expand All @@ -791,16 +791,16 @@ def abort_pings(self) -> None:
assert self.protocol.state is CLOSED
exc = self.protocol.close_exc

for pong_waiter, _ping_timestamp in self.pong_waiters.values():
if not pong_waiter.done():
pong_waiter.set_exception(exc)
for pong_received, _ping_timestamp in self.pending_pings.values():
if not pong_received.done():
pong_received.set_exception(exc)
# If the exception is never retrieved, it will be logged when ping
# is garbage-collected. This is confusing for users.
# Given that ping is done (with an exception), canceling it does
# nothing, but it prevents logging the exception.
pong_waiter.cancel()
pong_received.cancel()

self.pong_waiters.clear()
self.pending_pings.clear()

async def keepalive(self) -> None:
"""
Expand All @@ -821,7 +821,7 @@ async def keepalive(self) -> None:
# connection to be closed before raising ConnectionClosed.
# However, connection_lost() cancels keepalive_task before
# it gets a chance to resume excuting.
pong_waiter = await self.ping()
pong_received = await self.ping()
if self.debug:
self.logger.debug("% sent keepalive ping")

Expand All @@ -830,9 +830,9 @@ async def keepalive(self) -> None:
async with asyncio_timeout(self.ping_timeout):
# connection_lost cancels keepalive immediately
# after setting a ConnectionClosed exception on
# pong_waiter. A CancelledError is raised here,
# pong_received. A CancelledError is raised here,
# not a ConnectionClosed exception.
latency = await pong_waiter
latency = await pong_received
self.logger.debug("% received keepalive pong")
except asyncio.TimeoutError:
if self.debug:
Expand Down Expand Up @@ -1201,7 +1201,7 @@ def broadcast(
if connection.protocol.state is not OPEN:
continue

if connection.fragmented_send_waiter is not None:
if connection.send_in_progress is not None:
if raise_exceptions:
exception = ConcurrencyError("sending a fragmented message")
exceptions.append(exception)
Expand Down
3 changes: 1 addition & 2 deletions src/websockets/asyncio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ class Assembler:

"""

# coverage reports incorrectly: "line NN didn't jump to the function exit"
def __init__( # pragma: no cover
def __init__(
self,
high: int | None = None,
low: int | None = None,
Expand Down
43 changes: 22 additions & 21 deletions src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(
self.send_in_progress = False

# Mapping of ping IDs to pong waiters, in chronological order.
self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {}
self.pending_pings: dict[bytes, tuple[threading.Event, float, bool]] = {}

self.latency: float = 0
"""
Expand Down Expand Up @@ -629,8 +629,9 @@ def ping(

::

pong_event = ws.ping()
pong_event.wait() # only if you want to wait for the pong
pong_received = ws.ping()
# only if you want to wait for the corresponding pong
pong_received.wait()

Raises:
ConnectionClosed: When the connection is closed.
Expand All @@ -647,17 +648,17 @@ def ping(

with self.send_context():
# Protect against duplicates if a payload is explicitly set.
if data in self.pong_waiters:
if data in self.pending_pings:
raise ConcurrencyError("already waiting for a pong with the same data")

# Generate a unique random payload otherwise.
while data is None or data in self.pong_waiters:
while data is None or data in self.pending_pings:
data = struct.pack("!I", random.getrandbits(32))

pong_waiter = threading.Event()
self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close)
pong_received = threading.Event()
self.pending_pings[data] = (pong_received, time.monotonic(), ack_on_close)
self.protocol.send_ping(data)
return pong_waiter
return pong_received

def pong(self, data: Data = b"") -> None:
"""
Expand Down Expand Up @@ -707,7 +708,7 @@ def acknowledge_pings(self, data: bytes) -> None:
"""
with self.protocol_mutex:
# Ignore unsolicited pong.
if data not in self.pong_waiters:
if data not in self.pending_pings:
return

pong_timestamp = time.monotonic()
Expand All @@ -717,21 +718,21 @@ def acknowledge_pings(self, data: bytes) -> None:
ping_id = None
ping_ids = []
for ping_id, (
pong_waiter,
pong_received,
ping_timestamp,
_ack_on_close,
) in self.pong_waiters.items():
) in self.pending_pings.items():
ping_ids.append(ping_id)
pong_waiter.set()
pong_received.set()
if ping_id == data:
self.latency = pong_timestamp - ping_timestamp
break
else:
raise AssertionError("solicited pong not found in pings")

# Remove acknowledged pings from self.pong_waiters.
# Remove acknowledged pings from self.pending_pings.
for ping_id in ping_ids:
del self.pong_waiters[ping_id]
del self.pending_pings[ping_id]

def acknowledge_pending_pings(self) -> None:
"""
Expand All @@ -740,11 +741,11 @@ def acknowledge_pending_pings(self) -> None:
"""
assert self.protocol.state is CLOSED

for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values():
for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values():
if ack_on_close:
pong_waiter.set()
pong_received.set()

self.pong_waiters.clear()
self.pending_pings.clear()

def keepalive(self) -> None:
"""
Expand All @@ -762,15 +763,14 @@ def keepalive(self) -> None:
break

try:
pong_waiter = self.ping(ack_on_close=True)
pong_received = self.ping(ack_on_close=True)
except ConnectionClosed:
break
if self.debug:
self.logger.debug("% sent keepalive ping")

if self.ping_timeout is not None:
#
if pong_waiter.wait(self.ping_timeout):
if pong_received.wait(self.ping_timeout):
if self.debug:
self.logger.debug("% received keepalive pong")
else:
Expand Down Expand Up @@ -804,7 +804,7 @@ def recv_events(self) -> None:

Run this method in a thread as long as the connection is alive.

``recv_events()`` exits immediately when the ``self.socket`` is closed.
``recv_events()`` exits immediately when ``self.socket`` is closed.

"""
try:
Expand Down Expand Up @@ -979,6 +979,7 @@ def send_context(
# Minor layering violation: we assume that the connection
# will be closing soon if it isn't in the expected state.
wait_for_close = True
# TODO: calculate close deadline if not set?
raise_close_exc = True

# To avoid a deadlock, release the connection lock by exiting the
Expand Down
Empty file added src/websockets/trio/__init__.py
Empty file.
Loading