Skip to content

Commit

Permalink
fix: Now monitor sockets work as expected
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol committed Sep 12, 2023
1 parent 7997eeb commit a0e3b76
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 43 deletions.
7 changes: 4 additions & 3 deletions examples/simple-client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def single_client() -> None:
peer = Peer(
connect=ZeroMQAddress("tcp://localhost:5020"),
transport=ZeroMQRPCTransport,
authenticator=None,
transport_opts={"attach_monitor": True},
serializer=lambda o: json.dumps(o).encode("utf8"),
deserializer=lambda b: json.loads(b),
invoke_timeout=2.0,
Expand All @@ -110,6 +110,7 @@ async def overlapped_requests() -> None:
peer = Peer(
connect=ZeroMQAddress("tcp://localhost:5020"),
transport=ZeroMQRPCTransport,
transport_opts={"attach_monitor": True},
serializer=lambda o: json.dumps(o).encode("utf8"),
deserializer=lambda b: json.loads(b),
invoke_timeout=2.0,
Expand All @@ -135,7 +136,7 @@ async def multi_clients() -> None:
peer = Peer(
connect=ZeroMQAddress("tcp://localhost:5020"),
transport=ZeroMQRPCTransport,
authenticator=None,
transport_opts={"attach_monitor": True},
serializer=lambda o: json.dumps(o).encode("utf8"),
deserializer=lambda b: json.loads(b),
invoke_timeout=2.0,
Expand All @@ -162,7 +163,7 @@ async def multi_clients() -> None:
if __name__ == "__main__":
logging.basicConfig(
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
level=logging.INFO,
level=logging.DEBUG,
)
log = logging.getLogger()

Expand Down
4 changes: 2 additions & 2 deletions examples/simple-server.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def serve(scheduler_type: str) -> None:
peer = Peer(
bind=ZeroMQAddress("tcp://127.0.0.1:5020"),
transport=ZeroMQRPCTransport,
authenticator=None,
transport_opts={"attach_monitor": True},
scheduler=scheduler,
serializer=lambda o: json.dumps(o).encode("utf8"),
deserializer=lambda b: json.loads(b),
Expand Down Expand Up @@ -156,7 +156,7 @@ def main(scheduler_type):
if __name__ == "__main__":
logging.basicConfig(
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
level=logging.INFO,
level=logging.DEBUG,
)
log = logging.getLogger()
main()
56 changes: 18 additions & 38 deletions src/callosum/lower/zeromq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
import logging
import secrets
import time
import warnings
from typing import Any, AsyncGenerator, ClassVar, Mapping, Optional, Type

import attrs
import zmq
import zmq.asyncio
import zmq.constants
import zmq.utils.monitor
from zmq.utils import z85

from callosum.exceptions import AuthenticationError

from ..abc import RawHeaderBody
from ..auth import (
AbstractClientAuthenticator,
Expand Down Expand Up @@ -261,21 +263,14 @@ class ZeroMQMonitorMixin:

_monitor_sock: Optional[zmq.asyncio.Socket]

# FIXME: Upon release pyzmq 23.0 or 22.4, take the constant declarations
# from the zmq.constants.Event enum class, instead of doing dir().
EVENT_MAP = {
getattr(zmq.constants, name): name[6:].replace("_", "-").lower()
for name in dir(zmq.constants)
if name.startswith("EVENT_")
}

async def _monitor(self) -> None:
assert self._monitor_sock is not None
log = logging.getLogger("callosum.lower.zeromq.monitor")
try:
while await self._monitor_sock.poll():
raw_msg = await self._monitor_sock.recv_multipart()
msg = zmq.utils.monitor.parse_monitor_message(raw_msg)
while True:
msg = await zmq.utils.monitor.recv_monitor_message(
self._monitor_sock
)
log.debug("monitor[%s] event: %r", self.addr, msg)
if msg["event"] == zmq.EVENT_MONITOR_STOPPED:
break
Expand Down Expand Up @@ -314,12 +309,6 @@ def __init__(
super().__init__(transport, addr)
self._attach_monitor = attach_monitor
self._zsock_opts = {**_default_zsock_opts, **(zsock_opts or {})}
if attach_monitor:
warnings.warn(
"ZeroMQ async monitor socket support is buggy "
"and not recommended to use.",
RuntimeWarning,
)

async def ping(self, ping_timeout: int = 1000) -> bool:
assert self._main_sock is not None
Expand All @@ -340,9 +329,7 @@ async def __aenter__(self):
server_sock.setsockopt(key, value)
if self._attach_monitor:
monitor_addr = f"inproc://monitor-{secrets.token_hex(16)}"
server_sock.get_monitor_socket(addr=monitor_addr)
self._monitor_sock = self.transport._zctx.socket(zmq.PAIR)
self._monitor_sock.connect(monitor_addr)
self._monitor_sock = server_sock.get_monitor_socket(addr=monitor_addr)
self._monitor_task = asyncio.create_task(self._monitor())
else:
self._monitor_sock = None
Expand Down Expand Up @@ -389,21 +376,13 @@ def __init__(
super().__init__(transport, addr)
self._attach_monitor = attach_monitor
self._zsock_opts = {**_default_zsock_opts, **(zsock_opts or {})}
if attach_monitor:
warnings.warn(
"ZeroMQ async monitor socket support is buggy "
"and not recommended to use.",
RuntimeWarning,
)

async def ping(self, ping_timeout: int = 100) -> bool:
async def ping(self, ping_timeout: int = 1000) -> bool:
assert self._main_sock is not None
sock: zmq.asyncio.Socket = self._main_sock
await sock.send_multipart([b"PING", b"", b""])
ret = await sock.poll(ping_timeout)
if ret == 0:
return False
response = await sock.recv_multipart()
async with asyncio.timeout(ping_timeout / 1000):
response = await sock.recv_multipart()
return response[0] == b"PONG"

async def __aenter__(self) -> ZeroMQRPCConnection:
Expand All @@ -416,30 +395,31 @@ async def __aenter__(self) -> ZeroMQRPCConnection:
client_sock.setsockopt(key, value)
if self._attach_monitor:
monitor_addr = f"inproc://monitor-{secrets.token_hex(16)}"
client_sock.get_monitor_socket(addr=monitor_addr)
self._monitor_sock = self.transport._zctx.socket(zmq.PAIR)
self._monitor_sock.connect(monitor_addr)
self._monitor_sock = client_sock.get_monitor_socket(addr=monitor_addr)
self._monitor_task = asyncio.create_task(self._monitor())
else:
self._monitor_sock = None
self._monitor_task = None
client_sock.connect(self.addr.uri)
self._main_sock = client_sock
self.transport._sock = client_sock
# if not await self.ping():
# raise AuthenticationError
try:
await self.ping()
except asyncio.TimeoutError:
raise AuthenticationError
handshake_done = time.perf_counter()
log.debug(
"ZeroMQ connector handshake latency: %.3f sec",
handshake_done - handshake_begin,
)
return ZeroMQRPCConnection(self.transport)

async def __aexit__(self, exc_type, exc_obj, exc_tb):
async def __aexit__(self, exc_type, exc_obj, exc_tb) -> Optional[bool]:
assert self._main_sock is not None
if self._monitor_task is not None:
self._main_sock.disable_monitor()
await self._monitor_task
return None


class ZeroMQRouterBinder(ZeroMQBaseBinder):
Expand Down

0 comments on commit a0e3b76

Please sign in to comment.