From 9d48222d03396f836e7079b3924dc3ce42ebe122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20B=C3=B6hm?= Date: Sun, 2 Jun 2024 02:49:03 +0200 Subject: [PATCH] feat: Draft router class --- aiomqtt/__init__.py | 2 ++ aiomqtt/client.py | 16 ++++++++++++++++ aiomqtt/router.py | 11 +++++++++++ aiomqtt/topic.py | 43 ++++++++++++++++++++++++++++--------------- 4 files changed, 57 insertions(+), 15 deletions(-) create mode 100644 aiomqtt/router.py diff --git a/aiomqtt/__init__.py b/aiomqtt/__init__.py index 24aeded..d166455 100644 --- a/aiomqtt/__init__.py +++ b/aiomqtt/__init__.py @@ -8,6 +8,7 @@ ) from .exceptions import MqttCodeError, MqttError, MqttReentrantError from .message import Message +from .router import Router from .topic import Topic, TopicLike, Wildcard, WildcardLike # These are placeholders that are managed by poetry-dynamic-versioning @@ -19,6 +20,7 @@ "__version_tuple__", "Client", "Message", + "Router", "ProtocolVersion", "ProxySettings", "TLSParameters", diff --git a/aiomqtt/client.py b/aiomqtt/client.py index 4e52517..af780a8 100644 --- a/aiomqtt/client.py +++ b/aiomqtt/client.py @@ -34,6 +34,7 @@ from .exceptions import MqttCodeError, MqttConnectError, MqttError, MqttReentrantError from .message import Message +from .router import Router from .types import ( P, PayloadType, @@ -134,6 +135,7 @@ class Client: password: The password to authenticate with. logger: Custom logger instance. identifier: The client identifier. Generated automatically if ``None``. + routers: A list of routers to route messages to. queue_type: The class to use for the queue. The default is ``asyncio.Queue``, which stores messages in FIFO order. For LIFO order, you can use ``asyncio.LifoQueue``; For priority order you can subclass @@ -186,6 +188,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 password: str | None = None, logger: logging.Logger | None = None, identifier: str | None = None, + routers: list[Router] | None = None, queue_type: type[asyncio.Queue[Message]] | None = None, protocol: ProtocolVersion | None = None, will: Will | None = None, @@ -250,6 +253,11 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 if protocol is None: protocol = ProtocolVersion.V311 + # List of routers with message handlers + if routers is None: + routers = [] + self._routers = routers + # Create the underlying paho-mqtt client instance self._client: mqtt.Client = mqtt.Client( callback_api_version=CallbackAPIVersion.VERSION1, @@ -453,6 +461,14 @@ async def publish( # noqa: PLR0913 # Wait for confirmation await self._wait_for(confirmation.wait(), timeout=timeout) + async def route(self, message: Message) -> None: + """Route a message to the appropriate handler.""" + for router in self._routers: + for wildcard, handler in router._handlers.items(): + with contextlib.suppress(ValueError): + # If we get a ValueError, we know that the topic doesn't match + await handler(message, self, *message.topic.extract(wildcard)) + async def _messages(self) -> AsyncGenerator[Message, None]: """Async generator that yields messages from the underlying message queue.""" while True: diff --git a/aiomqtt/router.py b/aiomqtt/router.py new file mode 100644 index 0000000..9965bbc --- /dev/null +++ b/aiomqtt/router.py @@ -0,0 +1,11 @@ +class Router: + def __init__(self) -> None: + self._handlers = {} + + def match(self, *args: str): + """Add a new handler with one or multiple wildcards to the router.""" + def decorator(func): + for wildcard in args: + self._handlers[wildcard] = func + return func + return decorator diff --git a/aiomqtt/topic.py b/aiomqtt/topic.py index b111bd3..28a119e 100644 --- a/aiomqtt/topic.py +++ b/aiomqtt/topic.py @@ -89,6 +89,21 @@ def matches(self, wildcard: WildcardLike) -> bool: Returns: True if the topic matches the wildcard, False otherwise. """ + try: + self.extract(wildcard) + return True + except ValueError: + return False + + def extract(self, wildcard: WildcardLike) -> list[str]: + """Extract the wildcard values from the topic. + + Args: + wildcard: The wildcard to match against. + + Returns: + A list of wildcard values extracted from the topic. + """ if not isinstance(wildcard, Wildcard): wildcard = Wildcard(wildcard) # Split topics into levels to compare them one by one @@ -98,21 +113,19 @@ def matches(self, wildcard: WildcardLike) -> bool: # Shared subscriptions use the topic structure: $share// wildcard_levels = wildcard_levels[2:] - def recurse(tl: list[str], wl: list[str]) -> bool: - """Recursively match topic levels with wildcard levels.""" - if not tl: - if not wl or wl[0] == "#": - return True - return False - if not wl: - return False - if wl[0] == "#": - return True - if tl[0] == wl[0] or wl[0] == "+": - return recurse(tl[1:], wl[1:]) - return False - - return recurse(topic_levels, wildcard_levels) + # Extract wildcard values from the topic + arguments = [] + for index, level in enumerate(wildcard_levels): + if level == "#": + return arguments + topic_levels[index:] + if len(topic_levels) == index: + raise ValueError("Topic does not match wildcard") + if level != "+" and level != topic_levels[index]: + raise ValueError("Topic does not match wildcard") + arguments.append(topic_levels[index]) + if len(topic_levels) > index + 1: + raise ValueError("Topic does not match wildcard") + return arguments TopicLike: TypeAlias = "str | Topic"