Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add publisher matching checks to Subscriptions #106

Merged
merged 3 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,56 @@ def __init__(
self._topic_name = topic_name
self._node = node

@property
def subscriber(self) -> Subscriber:
"""Gets the underlying subscriber.

Type-casted alias of `Subscription.link`.
"""
return cast(Subscriber, self.link)

def publisher_matches(self, num_publishers: int) -> Future:
"""Gets a future to next publisher matching status update.

Note that in ROS 2 Humble and ealier distributions, this method relies on
polling the number of known publishers for the topic subscribed, as subscription
matching events are missing.

Args:
num_publishers: lower bound on the number of publishers to match.

Returns:
a future, done if the current number of publishers already matches
the specified lower bound.
"""
future_match = Future()
num_matched_publishers = self._node.count_publishers(self._topic_name)
if num_matched_publishers < num_publishers:

def _poll_publisher_matches() -> None:
nonlocal future_match, num_publishers
if future_match.cancelled():
return
num_matched_publishers = self._node.count_publishers(self._topic_name)
if num_publishers <= num_matched_publishers:
future_match.set_result(num_matched_publishers)

timer = self._node.create_timer(0.1, _poll_publisher_matches)
future_match.add_done_callback(lambda _: self._node.destroy_timer(timer))
else:
future_match.set_result(num_matched_publishers)
return future_match

@property
def matched_publishers(self) -> int:
"""Gets the number publishers matched and linked to.

Note that in ROS 2 Humble and earlier distributions, this property
relies on the number of known publishers for the topic subscribed
as subscription matching status info is missing.
"""
return self._node.count_publishers(self._topic_name)

@property
def message_type(self) -> Type[MessageT]:
"""Gets the type of the message subscribed."""
Expand All @@ -77,9 +127,7 @@ def topic_name(self) -> str:

def close(self) -> None:
"""Closes the subscription."""
self._node.destroy_subscription(
cast(Subscriber, self.link).sub,
)
self._node.destroy_subscription(self.subscriber.sub)
super().close()

# Aliases for improved readability
Expand Down
21 changes: 6 additions & 15 deletions bdai_ros2_wrappers/test/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copyright (c) 2023 Boston Dynamics AI Institute Inc. All rights reserved.
import logging
import threading
from typing import List, Optional
from typing import List

from rcl_interfaces.msg import Log
from rclpy.clock import ROSClock
from rclpy.task import Future
from rclpy.time import Time

from bdai_ros2_wrappers.futures import wait_for_future
from bdai_ros2_wrappers.futures import unwrap_future
from bdai_ros2_wrappers.logging import LoggingSeverity, as_memoizing_logger, logs_to_ros
from bdai_ros2_wrappers.scope import ROSAwareScope
from bdai_ros2_wrappers.subscription import Subscription


def test_memoizing_logger(verbose_ros: ROSAwareScope) -> None:
Expand Down Expand Up @@ -76,26 +76,17 @@ def all_messages_arrived() -> bool:


def test_log_forwarding(verbose_ros: ROSAwareScope) -> None:
future: Optional[Future] = None

def callback(message: Log) -> None:
nonlocal future
if future and not future.done():
future.set_result(message)

assert verbose_ros.node is not None
verbose_ros.node.create_subscription(Log, "/rosout", callback, 10)
rosout = Subscription(Log, "/rosout", 10, node=verbose_ros.node)
assert unwrap_future(rosout.publisher_matches(1), timeout_sec=5.0) > 0

future = Future()
with logs_to_ros(verbose_ros.node):
logger = logging.getLogger("my_logger")
logger.setLevel(logging.INFO)
logger.propagate = True # ensure propagation is enabled
logger.info("test")

assert wait_for_future(future, timeout_sec=10)
assert future.done()
log = future.result()
log = unwrap_future(rosout.update, timeout_sec=5.0)
# NOTE(hidmic) why are log levels of bytestring type !?
assert log.level == int.from_bytes(Log.INFO, byteorder="little")
assert log.name == verbose_ros.node.get_logger().name
Expand Down
22 changes: 22 additions & 0 deletions bdai_ros2_wrappers/test/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,27 @@ def deferred_publish() -> None:
assert message.data == 1


def test_subscription_matching_publishers(ros: ROSAwareScope) -> None:
"""Asserts that checking for publisher matching on a subscription works as expected."""
assert ros.node is not None
sequence = Subscription(Int8, "sequence", DEFAULT_QOS_PROFILE, node=ros.node)
assert sequence.matched_publishers == 0
future = sequence.publisher_matches(1)
assert not future.done()
future.cancel()

ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1


def test_subscription_future_wait(ros: ROSAwareScope) -> None:
"""Asserts that waiting for a subscription update works as expected."""
assert ros.node is not None
pub = ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE)
sequence = Subscription(Int8, "sequence", DEFAULT_QOS_PROFILE, node=ros.node)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1

pub.publish(Int8(data=1))

Expand All @@ -53,6 +69,8 @@ def test_subscription_matching_future_wait(ros: ROSAwareScope) -> None:
assert ros.node is not None
pub = ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE)
sequence = Subscription(Int8, "sequence", DEFAULT_QOS_PROFILE, node=ros.node)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1

def deferred_publish() -> None:
time.sleep(0.5)
Expand Down Expand Up @@ -84,6 +102,8 @@ def test_subscription_iteration(ros: ROSAwareScope) -> None:
history_length=3,
node=ros.node,
)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1

expected_sequence_numbers = [1, 10, 100]

Expand All @@ -108,6 +128,8 @@ def test_subscription_cancelation(ros: ROSAwareScope) -> None:
assert ros.node is not None
pub = ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE)
sequence = Subscription(Int8, "sequence", DEFAULT_QOS_PROFILE, node=ros.node)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1

pub.publish(Int8(data=1))

Expand Down
Loading