-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathbase_server.py
122 lines (95 loc) · 4.03 KB
/
base_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from abc import ABC, abstractmethod
from typing import Protocol
import numpy as np
import rospy
from std_msgs.msg import Int16
from rosnav_rl.rl_agent import RL_Agent
from rosnav_rl.srv import GetAction, GetActionResponse
from rosnav_rl.utils.rostopic import Namespace
from rosnav_rl.utils.type_aliases import ObservationDict
class ObservationCollector(Protocol):
    """Structural interface for objects that can supply observations.

    Any object exposing a compatible ``get_observations`` method satisfies
    this protocol; explicit inheritance is not required.
    """

    def get_observations(self, *args, **kwargs) -> ObservationDict:
        """Return the current observations as an ``ObservationDict``."""
        ...
class ActionServer(ABC):
    """Abstract base class exposing an RL agent's action selection over ROS.

    The server registers a ROS service (``rosnav_rl/get_action`` inside the
    configured namespace) that answers each request with the agent's action
    for the current observations, and subscribes to ``/scenario_reset`` to
    reset the agent's model.

    Subclasses provide the concrete agent and observation collector by
    implementing ``_initialize_agent`` and
    ``_initialize_observation_collector``.

    Attributes:
        agent (RL_Agent): The reinforcement learning agent. ``None`` until
            ``start`` has initialized it.
        observation_collector (ObservationCollector): The source of
            observations passed to the agent. ``None`` until ``start`` has
            initialized it.
    """

    agent: RL_Agent = None
    observation_collector: ObservationCollector = None

    def __init__(self, agent_name: str, namespace: str = "") -> None:
        """
        Initializes the ActionServer.

        Only stores the configuration; no ROS communication is set up and
        neither the agent nor the observation collector is created until
        ``start`` is called.

        Args:
            agent_name (str): The name of the agent. Stored as
                ``self.agent_name``; not used by this base class itself
                (presumably consumed by subclasses when building the agent).
            namespace (str, optional): The namespace for the server, wrapped
                in a ``Namespace`` object. Defaults to an empty string.
        """
        self.agent_name = agent_name
        self.namespace = Namespace(namespace)

    @abstractmethod
    def _initialize_agent(self) -> RL_Agent:
        """Create and return the RL agent served by this server."""
        ...

    @abstractmethod
    def _initialize_observation_collector(self) -> ObservationCollector:
        """Create and return the observation collector used for requests."""
        ...

    def _initialize_ros(self):
        """
        Initializes ROS services and subscribers for the action server.

        This method sets up the following ROS components:
            - A service to get the next action, which is handled by
              `__handle_next_action_srv`.
            - A subscriber to the "/scenario_reset" topic, which calls
              `__on_scene_reset`.

        Returns:
            None
        """
        self._get_next_action_srv = rospy.Service(
            str(self.namespace("rosnav_rl/get_action")),
            GetAction,
            self.__handle_next_action_srv,
        )
        self._sub_reset_stacked_obs = rospy.Subscriber(
            "/scenario_reset", Int16, self.__on_scene_reset
        )

    def __handle_next_action_srv(self, request: GetAction):
        """
        Handles the service request to get the next action.

        The service is registered before the agent and the observation
        collector are created (see ``start``), so a request can arrive while
        either of them is still ``None``. In that case a zero action is
        returned instead of raising.

        Args:
            request (GetAction): The service request. Its content is not
                inspected.

        Returns:
            GetActionResponse: The service response containing the next
                action, or ``[0, 0, 0]`` while initialization is incomplete.
        """
        response = GetActionResponse()

        # Read each attribute once so the None checks and the calls below
        # operate on the same objects.
        agent = self.agent
        collector = self.observation_collector

        if agent is None or collector is None:
            if agent is None:
                rospy.loginfo("Agent not initialized yet.")
            else:
                rospy.loginfo("Observation collector not initialized yet.")
            response.action = np.array([0, 0, 0])
            return response

        response.action = agent.get_action(collector.get_observations())
        return response

    def __on_scene_reset(self, request: Int16):
        """
        Resets the agent's model when a scenario reset is announced.

        Calls ``reset()`` on ``self.agent.model``. Does nothing (besides
        logging) if the agent has not been initialized yet.

        Args:
            request (Int16): The reset message. Its value is not inspected.

        Returns:
            None
        """
        agent = self.agent
        if agent is None:
            rospy.loginfo("Agent not initialized yet.")
            return
        agent.model.reset()

    def start(self):
        """
        Initializes the ROS interfaces, the agent and the observation
        collector, then blocks until ROS is shut down.

        This method performs the following steps:
            1. Initializes ROS-related components.
            2. Initializes the agent.
            3. Initializes the observation collector.
            4. Spins until ROS is shut down.

        Note:
            This method does not call ``rospy.init_node``; the node is
            expected to be initialized by the caller.
        """
        self._initialize_ros()
        rospy.loginfo("[Rosnav-RL | Action Server] ROS services initialized.")

        self.agent = self._initialize_agent()
        rospy.loginfo("[Rosnav-RL | Action Server] Agent initialized.")

        self.observation_collector = self._initialize_observation_collector()
        rospy.loginfo("[Rosnav-RL | Action Server] Observation collector initialized.")

        rospy.loginfo("[Rosnav-RL | Action Server] Spinning...")
        # rospy.spin() itself blocks until the node is shut down, so no
        # surrounding is_shutdown() loop is needed.
        rospy.spin()