diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py index c8c7494..30a162a 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py @@ -8,7 +8,7 @@ from rclpy.task import Future import bdai_ros2_wrappers.scope as scope -from bdai_ros2_wrappers.filters import SimpleAdapter, TransformFilter +from bdai_ros2_wrappers.filters import SimpleAdapter, TransformFilter, Tunnel from bdai_ros2_wrappers.utilities import Tape @@ -69,6 +69,16 @@ def matching_update(self, matching_predicate: Callable[[Any], bool]) -> Future: """ return self._tape.future_matching_write(matching_predicate) + def recall(self, callback: Callable) -> Tunnel: + """Adds a callback for message recalling. + + Returns: + the underlying connection, which can be closed to stop future callbacks. + """ + tunnel = Tunnel(self.link) + tunnel.registerCallback(callback) + return tunnel + def stream( self, *, diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py index e58b4c2..493d343 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py @@ -25,7 +25,7 @@ class TransformFilter(SimpleFilter): def __init__( self, - f: SimpleFilter, + upstream: SimpleFilter, target_frame_id: str, tf_buffer: tf2_ros.Buffer, tolerance_sec: float, @@ -34,7 +34,7 @@ def __init__( """Initializes the transform filter. Args: - f: the upstream message filter. + upstream: the upstream message filter. target_frame_id: the target frame ID for transforms. tf_buffer: a buffer of transforms to look up. tolerance_sec: a tolerance, in seconds, to wait for late transforms @@ -50,7 +50,7 @@ def __init__( self.target_frame_id = target_frame_id self.tf_buffer = tf_buffer self.tolerance = Duration(seconds=tolerance_sec) - self.incoming_connection = f.registerCallback(self.add) + self.connection = upstream.registerCallback(self.add) def _wait_callback(self, messages: Sequence[Any], future: Future) -> None: if future.cancelled(): @@ -120,17 +120,35 @@ def add(self, *messages: Any) -> None: class SimpleAdapter(SimpleFilter): """A message filter for data adaptation.""" - def __init__(self, f: SimpleFilter, fn: Callable) -> None: + def __init__(self, upstream: SimpleFilter, fn: Callable) -> None: """Initializes the adapter. Args: - f: the upstream message filter. + upstream: the upstream message filter. fn: adapter implementation as a callable. """ super().__init__() self.do_adapt = fn - self.incoming_connection = f.registerCallback(self.add) + self.connection = upstream.registerCallback(self.add) def add(self, *messages: Any) -> None: """Adds new `messages` to the adapter.""" self.signalMessage(self.do_adapt(*messages)) + + +class Tunnel(SimpleFilter): + """A message filter that simply forwards messages but can be detached.""" + + def __init__(self, upstream: SimpleFilter) -> None: + """Initializes the tunnel. + + Args: + upstream: the upstream message filter. + """ + super().__init__() + self.upstream = upstream + self.connection = upstream.registerCallback(self.signalMessage) + + def close(self) -> None: + """Closes the tunnel, disconnecting it from upstream.""" + del self.upstream.callbacks[self.connection] diff --git a/bdai_ros2_wrappers/test/test_feeds.py b/bdai_ros2_wrappers/test/test_feeds.py index 6450c7b..bdd71f5 100644 --- a/bdai_ros2_wrappers/test/test_feeds.py +++ b/bdai_ros2_wrappers/test/test_feeds.py @@ -1,5 +1,6 @@ # Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved. +from typing import Optional import tf2_ros from geometry_msgs.msg import ( @@ -94,3 +95,31 @@ def test_adapted_message_feed(ros: ROSAwareScope) -> None: position_message = ensure(position_message_feed.latest) # no copies are expected, thus an identity check is valid assert position_message is expected_pose_message.pose.position + + +def test_message_feed_recalls(ros: ROSAwareScope) -> None: + pose_message_feed = MessageFeed(SimpleFilter()) + + latest_message: Optional[PoseStamped] = None + + def callback(message: PoseStamped) -> None: + nonlocal latest_message + latest_message = message + + conn = pose_message_feed.recall(callback) + + first_pose_message = PoseStamped() + first_pose_message.header.stamp.sec = 1 + pose_message_feed.link.signalMessage(first_pose_message) + + assert latest_message is not None + assert latest_message.header.stamp.sec == first_pose_message.header.stamp.sec + + conn.close() + + second_pose_message = PoseStamped() + second_pose_message.header.stamp.sec = 2 + pose_message_feed.link.signalMessage(second_pose_message) + + assert latest_message.header.stamp.sec != second_pose_message.header.stamp.sec + assert latest_message.header.stamp.sec == first_pose_message.header.stamp.sec