The `rosnav-rl` framework provides a flexible interface for implementing custom reinforcement learning policies with Stable-Baselines3. This guide explains how to create your own policy implementations using the `StableBaselinesPolicyDescription` base class.

The `StableBaselinesPolicyDescription` class serves as an abstract base class for defining custom policies. It provides a standardized interface that integrates seamlessly with Stable-Baselines3 algorithms while adding functionality specific to robotic navigation tasks.
To create a custom policy, extend the `StableBaselinesPolicyDescription` class and implement all required abstract properties:
```python
from stable_baselines3 import PPO
from torch import nn

# AgentFactory, StableBaselinesPolicyDescription, spaces, and your feature
# extractor are assumed to be importable from rosnav-rl / your own module.

@AgentFactory.register("CustomAgent")
class CustomAgent(StableBaselinesPolicyDescription):
    algorithm_class = PPO  # or any other SB3 algorithm

    observation_spaces = [
        spaces.StackedLaserMapSpace,
        spaces.DistAngleToSubgoalSpace,
        spaces.LastActionSpace,
    ]

    observation_space_kwargs = {
        "roi_in_m": 20,
        "feature_map_size": 80,
        "normalize": True,
    }

    features_extractor_class = CustomFeatureExtractor
    features_extractor_kwargs = {
        "features_dim": 256,
    }

    net_arch = dict(pi=[64, 64], vf=[64, 64])
    activation_fn = nn.ReLU
```
Specify which Stable-Baselines3 algorithm to use:

```python
algorithm_class: Type[BaseAlgorithm] = PPO
```
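Any of the standard Stable-Baselines3 algorithm classes can be assigned here. For reference, these are the usual import locations (note that `RecurrentPPO` ships in the separate `sb3-contrib` package, not in core SB3):

```python
from typing import Type

from stable_baselines3 import A2C, PPO, SAC  # core SB3 algorithms
from stable_baselines3.common.base_class import BaseAlgorithm
from sb3_contrib import RecurrentPPO  # recurrent variant from sb3-contrib

algorithm_class: Type[BaseAlgorithm] = PPO  # swap in any of the above
```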
Define which observation spaces your agent requires:

```python
observation_spaces = [
    spaces.StackedLaserMapSpace,
    spaces.DistAngleToSubgoalSpace,
    # Add more observation spaces as needed
]
```
Configure the policy and value function networks:

```python
net_arch = dict(
    pi=[64, 64],  # policy network architecture
    vf=[64, 64],  # value function network architecture
)
```
Specify how raw observations should be processed:

```python
features_extractor_class = CustomFeatureExtractor
features_extractor_kwargs = {
    "features_dim": 256,
    # Additional extractor parameters
}
```
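The `CustomFeatureExtractor` referenced above is user-defined code. As a rough, hypothetical sketch of what such a class can look like when built on Stable-Baselines3's `BaseFeaturesExtractor` (assuming a single `Box` observation for brevity; with several observation spaces the extractor receives a dict of tensors, and the exact interface rosnav-rl expects may differ):

```python
import math

import torch
from gymnasium import spaces as gym_spaces
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from torch import nn


class CustomFeatureExtractor(BaseFeaturesExtractor):
    """Illustrative sketch: flatten the observation, project to features_dim."""

    def __init__(self, observation_space: gym_spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        n_input = math.prod(observation_space.shape)  # total input size
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_input, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.net(observations)  # shape: (batch, features_dim)
```

A real extractor for stacked laser maps would typically use convolutional layers instead of a single linear projection; the flatten-then-MLP structure here only illustrates the required interface.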
Customize observation space parameters:

```python
observation_space_kwargs = {
    "roi_in_m": 20,          # region of interest, in meters
    "feature_map_size": 80,  # size of the generated feature map
    "normalize": True,       # normalize observation values
}
```
Define how many consecutive observations should be stacked. Stacking gives the policy short-term temporal context, e.g. for inferring motion from consecutive laser scans:

```python
@property
def stack_size(self) -> int:
    return 4  # default is 1
```
Here's a more sophisticated example using `RecurrentPPO` (provided by the separate `sb3-contrib` package):

```python
from sb3_contrib import RecurrentPPO
from torch import nn

@AgentFactory.register("RecurrentNavigationAgent")
class RecurrentNavigationAgent(StableBaselinesPolicyDescription):
    algorithm_class = RecurrentPPO

    observation_spaces = [
        spaces.StackedLaserMapSpace,
        spaces.PedestrianVelXSpace,
        spaces.PedestrianVelYSpace,
        spaces.DistAngleToSubgoalSpace,
    ]

    observation_space_kwargs = {
        "roi_in_m": 20,
        "feature_map_size": 80,
        "laser_stack_size": 10,
        "normalize": True,
        "goal_max_dist": 10,
    }

    features_extractor_class = CustomRecurrentExtractor
    features_extractor_kwargs = {
        "features_dim": 512,
        "lstm_hidden_size": 128,
        "n_lstm_layers": 2,
    }

    net_arch = dict(pi=[256, 128], vf=[256, 64])
    activation_fn = nn.GELU
```
Once implemented, your custom policy can be used in the Arena-Rosnav framework by:

- Registering it with the `@AgentFactory.register` decorator
- Specifying it in your training configuration

The framework will automatically handle the integration with Stable-Baselines3. Your policy will then be ready for training and evaluation in various navigation scenarios within the Arena-Rosnav environment.
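For intuition, the description's attributes map onto standard Stable-Baselines3 constructor arguments. The following is a hypothetical sketch of that mapping, not rosnav-rl's actual integration code; `build_model`, `env`, and the policy string are illustrative placeholders:

```python
from typing import Type

# Hypothetical helper: NOT rosnav-rl's actual integration code, just an
# illustration of how the class attributes correspond to SB3 arguments.
def build_model(desc: Type[StableBaselinesPolicyDescription], env):
    return desc.algorithm_class(
        policy="MultiInputPolicy",  # dict observations; RecurrentPPO would use "MultiInputLstmPolicy"
        env=env,
        policy_kwargs=dict(
            net_arch=desc.net_arch,
            activation_fn=desc.activation_fn,
            features_extractor_class=desc.features_extractor_class,
            features_extractor_kwargs=desc.features_extractor_kwargs,
        ),
    )

# model = build_model(CustomAgent, env)  # env: your vectorized navigation environment
# model.learn(total_timesteps=1_000_000)
```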