Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async support for subscribing messages #375

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions launch_testing_ros/launch_testing_ros/repeater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2019 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import rclpy
from rclpy.node import Node

from std_msgs.msg import String


class Repeater(Node):

def __init__(self):
super().__init__('repeater')
self.count = 0
self.subscription = self.create_subscription(
String, 'input', self.callback, 10
)
self.publisher = self.create_publisher(String, 'output', 10)

def callback(self, input_msg):
self.get_logger().info('I heard: [%s]' % input_msg.data)
if input_msg.data == "Hello":
reply = input_msg.data + " World"
self.publish_data(reply)
elif input_msg.data == "Knock Knock":
reply = "Who's there?"
self.publish_data(reply)

def publish_data(self, output_msg_data):
self.get_logger().info('Publishing: "{0}"'.format(output_msg_data))
self.publisher.publish(String(data=output_msg_data))

def main(args=None):
rclpy.init(args=args)

node = Repeater()

try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()


if __name__ == '__main__':
main()
31 changes: 26 additions & 5 deletions launch_testing_ros/launch_testing_ros/wait_for_topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import random
import string
from threading import Condition
from threading import Event
from threading import Thread

Expand All @@ -26,7 +28,7 @@ class WaitForTopics:
"""
Wait to receive messages on supplied topics.

Example usage:
Example usage for periodic topics:
--------------

from std_msgs.msg import String
Expand All @@ -47,9 +49,15 @@ def method_2():
print(wait_for_topics.topics_not_received()) # Should be an empty set
print(wait_for_topics.topics_received()) # Should be {'topic_1', 'topic_2'}
wait_for_topics.shutdown()


Example usage for async topics:
--------------

See test/examples/translator_test.py
"""

def __init__(self, topic_tuples, timeout=5.0):
def __init__(self, topic_tuples, timeout=5.0, *, start_subscribers=False):
self.topic_tuples = topic_tuples
self.timeout = timeout
self.__ros_context = rclpy.Context()
Expand All @@ -63,18 +71,25 @@ def __init__(self, topic_tuples, timeout=5.0):
self.__ros_spin_thread = Thread(target=self._spin_function)
self.__ros_spin_thread.start()

if start_subscribers:
self._start_subscribers()

def _prepare_ros_node(self):
node_name = '_test_node_' +\
''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
self.__ros_node = _WaitForTopicsNode(name=node_name, node_context=self.__ros_context)
self.__ros_executor.add_node(self.__ros_node)

def _start_subscribers(self):
self.__ros_node.start_subscribers(self.topic_tuples)

def _spin_function(self):
while self.__running:
self.__ros_executor.spin_once(1.0)

def wait(self):
self.__ros_node.start_subscribers(self.topic_tuples)
def wait(self, *, start_subscribers=True):
if start_subscribers:
self._start_subscribers()
return self.__ros_node.msg_event_object.wait(self.timeout)

def shutdown(self):
Expand All @@ -87,6 +102,10 @@ def topics_received(self):
"""Topics that received at least one message."""
return self.__ros_node.received_topics

def messages_received(self):
"""List of messages that receive, keyed by topic."""
return self.__ros_node.received_messages

def topics_not_received(self):
"""Topics that did not receive any messages."""
return self.__ros_node.expected_topics - self.__ros_node.received_topics
Expand Down Expand Up @@ -114,6 +133,7 @@ def start_subscribers(self, topic_tuples):
self.subscriber_list = []
self.expected_topics = {name for name, _ in topic_tuples}
self.received_topics = set()
self.received_messages = defaultdict(list)

for topic_name, topic_type in topic_tuples:
# Create a subscriber
Expand All @@ -128,7 +148,8 @@ def start_subscribers(self, topic_tuples):

def callback_template(self, topic_name):

def topic_callback(data):
def topic_callback(msg):
self.received_messages[topic_name].append(msg)
if topic_name not in self.received_topics:
self.get_logger().debug('Message received for ' + topic_name)
self.received_topics.add(topic_name)
Expand Down
3 changes: 3 additions & 0 deletions launch_testing_ros/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
],
entry_points={
'pytest11': ['launch_ros = launch_testing_ros_pytest_entrypoint'],
'console_scripts': [
'repeater = launch_testing_ros.repeater:main'
]
},
install_requires=['setuptools'],
zip_safe=True,
Expand Down
110 changes: 110 additions & 0 deletions launch_testing_ros/test/examples/repeater_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2023 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import sys
import unittest

import launch
import launch_ros
import launch_testing.actions
from launch_testing_ros.wait_for_topics import WaitForTopics
from launch_testing_ros import repeater

import pytest

import rclpy

from std_msgs.msg import String


@pytest.mark.rostest
def generate_test_description():

# path_to_test = os.path.dirname(__file__)
repeater_node = launch_ros.actions.Node(
package="launch_testing_ros",
executable="repeater",
# arguments=[os.path.join(path_to_test, 'repeater.py')],
additional_env={'PYTHONUNBUFFERED': '1'}
)

return (
launch.LaunchDescription([
repeater_node,
# Start tests right away - no need to wait for anything
launch_testing.actions.ReadyToTest(),
]),
{
'repeater': repeater_node
}
)


class Testrepeater(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Initialize the ROS context for the test node
rclpy.init()

@classmethod
def tearDownClass(cls):
# Shutdown the ROS context
rclpy.shutdown()

def setUp(self):
# Create a ROS node for tests
self.node = rclpy.create_node('test_repeater')
self.publisher = self.node.create_publisher(String, "input", qos_profile=10)
self.topic_list = [('output', String)]
self.wait_for_topics = WaitForTopics(self.topic_list, timeout=10.0, start_subscribers=True)

def tearDown(self):
self.node.destroy_node()

def test_repeater_translates_hello(self, launch_service, repeater, proc_output):
# Expect the repeater to reply with "Hello World"
msg = String()
msg.data = "Hello"
self.assertTrue(self.wait_for_topics.wait(start_subscribers=False))
self.publisher.publish(msg)

received_output = self.wait_for_topics.messages_received()['output']
self.assertEqual(len(received_output), 1)
self.assertEqual(received_output[0].data, "Hello World")

def test_repeater_translates_knock(self, launch_service, repeater, proc_output):
# Expect the repeater to reply with "Who's There?"
msg = String()
msg.data = "Knock Knock"

self.assertTrue(self.wait_for_topics.wait(start_subscribers=False))
self.publisher.publish(msg)

received_output = self.wait_for_topics.messages_received()['output']
self.assertEqual(len(received_output), 1)
self.assertEqual(received_output[0].data, "Who's there")

def test_repeater_ignores_foobar(self, launch_service, repeater, proc_output):
# Expect the repeater to not reply"
msg = String()
msg.data = "FooBar"

self.assertFalse(self.wait_for_topics.wait(start_subscribers=False))
self.publisher.publish(msg)

self.assertFalse('output' in self.wait_for_topics.messages_received())
self.assertEqual(self.wait_for_topics.topics_not_received(), self.topic_list)