Skip to content

Commit

Permalink
Subscriptions are message feeds (#99)
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <[email protected]>
  • Loading branch information
mhidalgo-bdai authored May 31, 2024
1 parent ac29b5d commit e618797
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 79 deletions.
101 changes: 101 additions & 0 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

from typing import Any, Callable, Iterator, List, Optional

from message_filters import SimpleFilter
from rclpy.node import Node
from rclpy.task import Future

import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.utilities import Tape


class MessageFeed:
"""An ergonomic wrapper for generic message filters."""

def __init__(
self,
link: SimpleFilter,
*,
history_length: Optional[int] = None,
node: Optional[Node] = None,
) -> None:
"""Initializes the message feed.
Args:
link: Wrapped message filter, connecting this message feed with its source.
history_length: optional historic data size, defaults to 1
node: optional node for the underlying native subscription, defaults to
the current process node.
"""
if node is None:
node = scope.ensure_node()
if history_length is None:
history_length = 1
self._link = link
self._tape = Tape(history_length)
self._link.registerCallback(self._tape.write)
node.context.on_shutdown(self._tape.close)

@property
def link(self) -> SimpleFilter:
"""Gets the underlying message connection."""
return self._link

@property
def history(self) -> List[Any]:
"""Gets the entire history of messages received so far."""
return list(self._tape.content())

@property
def latest(self) -> Optional[Any]:
"""Gets the latest message received, if any."""
return next(self._tape.content(), None)

@property
def update(self) -> Future:
"""Gets the future to the next message yet to be received."""
return self._tape.future_write

def matching_update(self, matching_predicate: Callable[[Any], bool]) -> Future:
"""Gets a future to the next matching message yet to be received.
Args:
matching_predicate: a boolean predicate to match incoming messages.
Returns:
a future.
"""
return self._tape.future_matching_write(matching_predicate)

def stream(
self,
*,
forward_only: bool = False,
buffer_size: Optional[int] = None,
timeout_sec: Optional[float] = None,
) -> Iterator[Any]:
"""Iterates over messages as they come.
Iteration stops when the given timeout expires or when the associated context
is shutdown. Note that iterating over the message stream is a blocking operation.
Args:
forward_only: whether to ignore previosuly received messages.
buffer_size: optional maximum size for the incoming messages buffer.
If none is provided, the buffer will be grow unbounded.
timeout_sec: optional timeout, in seconds, for a new message to be received.
Returns:
a lazy iterator over messages.
"""
return self._tape.content(
follow=True,
forward_only=forward_only,
buffer_size=buffer_size,
timeout_sec=timeout_sec,
)

def close(self) -> None:
"""Closes the message feed."""
self._tape.close()
117 changes: 38 additions & 79 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,30 @@
# Copyright (c) 2023 Boston Dynamics AI Institute Inc. All rights reserved.

from collections.abc import Sequence
from typing import Any, Callable, Iterator, Optional, TypeVar, Union
from typing import Any, Optional, Type, Union, cast

import message_filters
from message_filters import ApproximateTimeSynchronizer, Subscriber
from rclpy.callback_groups import CallbackGroup
from rclpy.node import Node
from rclpy.qos import QoSProfile
from rclpy.task import Future

import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.feeds import MessageFeed
from bdai_ros2_wrappers.futures import wait_for_future
from bdai_ros2_wrappers.utilities import Tape
from bdai_ros2_wrappers.type_hints import Msg as MessageT

MessageT = TypeVar("MessageT")

class Subscription(MessageFeed):
"""An ergonomic interface for topic subscriptions in ROS 2.
class Subscription:
"""An ergonomic interface to for topic subscription in ROS 2.
Subscription instances wrap `rclpy.Subscription` instances and allow
synchronous and asynchronous iteration and fetching of published data.
Subscription instances are `MessageFeed` instances wrapping `message_filters.Subscriber`
instances and thus allow synchronous and asynchronous iteration and fetching of published data.
"""

def __init__(
self,
message_type: MessageT,
message_type: Type[MessageT],
topic_name: str,
qos_profile: Optional[Union[QoSProfile, int]] = None,
history_length: Optional[int] = None,
Expand All @@ -47,84 +46,45 @@ def __init__(
"""
if node is None:
node = scope.ensure_node()
self._node = node
if history_length is None:
history_length = 1
if qos_profile is None:
qos_profile = 1
super().__init__(
Subscriber(
node,
message_type,
topic_name,
qos_profile=qos_profile,
**kwargs,
),
history_length=history_length,
node=node,
)
self._message_type = message_type
self._topic_name = topic_name
self._message_tape = Tape(history_length)
self._topic_subscription = self._node.create_subscription(
message_type,
topic_name,
self._message_tape.write,
qos_profile,
**kwargs,
)
self._node.context.on_shutdown(self._message_tape.close)

@property
def history(self) -> Sequence[Any]:
"""Gets the entire history of messages received so far."""
return list(self._message_tape.content())
self._node = node

@property
def latest(self) -> Optional[Any]:
"""Gets the latest message received, if any."""
return next(self._message_tape.content(), None)
def message_type(self) -> Type[MessageT]:
"""Gets the type of the message subscribed."""
return self._message_type

@property
def update(self) -> Future:
"""Gets the future to the next message yet to be received."""
return self._message_tape.future_write

def matching_update(self, matching_predicate: Callable[[Any], bool]) -> Future:
"""Gets a future to the next matching message yet to be received.
Args:
matching_predicate: a boolean predicate to match incoming messages.
Returns:
a future.
"""
return self._message_tape.future_matching_write(matching_predicate)

def stream(
self,
*,
forward_only: bool = False,
buffer_size: Optional[int] = None,
timeout_sec: Optional[float] = None,
) -> Iterator[Any]:
"""Iterates over messages as they come.
Iteration stops when the given timeout expires or when the associated context
is shutdown. Note that iterating over the message stream is a blocking operation.
Args:
forward_only: whether to ignore previosuly received messages.
buffer_size: optional maximum size for the incoming messages buffer.
If none is provided, the buffer will be grow unbounded.
timeout_sec: optional timeout, in seconds, for a new message to be received.
Returns:
a lazy iterator over messages.
"""
return self._message_tape.content(
follow=True,
forward_only=forward_only,
buffer_size=buffer_size,
timeout_sec=timeout_sec,
def topic_name(self) -> str:
"""Gets the name of the topic subscribed."""
return self._topic_name

def close(self) -> None:
"""Closes the subscription."""
self._node.destroy_subscription(
cast(Subscriber, self.link).sub,
)
super().close()

def cancel(self) -> None:
"""Cancels the message subscription if not cancelled already."""
self._node.destroy_subscription(self._topic_subscription)
self._message_tape.close()

# Alias for improved readability
unsubscribe = cancel
# Aliases for improved readability
cancel = MessageFeed.close
unsubscribe = MessageFeed.close


def wait_for_message_async(
Expand Down Expand Up @@ -259,12 +219,11 @@ def wait_for_messages_async(
"""
if node is None:
node = scope.ensure_node()

if qos_profiles is None:
qos_profiles = [None] * len(topic_names)

subscribers = [
message_filters.Subscriber(
Subscriber(
node,
message_type,
topic_name,
Expand All @@ -281,7 +240,7 @@ def callback(*messages: Sequence[Any]) -> None:
if not future.done():
future.set_result(messages)

sync = message_filters.ApproximateTimeSynchronizer(
sync = ApproximateTimeSynchronizer(
subscribers,
queue_size,
delay,
Expand Down

0 comments on commit e618797

Please sign in to comment.