 class ContinuousQVAC(nn.Module):
     """
     Overview:
-        The neural network and computation graph of algorithms related to Actor-Critic that have both Q-value and V-value critic, such as \
-        IQL. This model now supports continuous and hybrid action space. The ContinuousQVAC is composed of \
-        four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders are used to \
-        extract the feature from various observation. Heads are used to predict corresponding Q-value and V-value or action logit. \
+        The neural network and computation graph of algorithms related to Actor-Critic that have both Q-value and \
+        V-value critic, such as IQL. This model now supports continuous and hybrid action space. The ContinuousQVAC is \
+        composed of four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders \
+        are used to extract the feature. Heads are used to predict corresponding value or action logit.
         In high-dimensional observation space like 2D image, we often use a shared encoder for both ``actor_encoder`` \
         and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often use different encoders.
     Interfaces:
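The shared-versus-separate encoder choice described in this docstring can be illustrated with a minimal sketch. `TinyQVAC`, `make_encoder` and `n_params` are hypothetical names invented for illustration, not part of the patched file, and the single-layer encoder is an assumption (the real class takes `encoder_hidden_size_list`).

```python
import torch
import torch.nn as nn


def make_encoder(obs_dim: int, hidden: int) -> nn.Module:
    # Hypothetical one-layer encoder, used here only to keep the sketch short.
    return nn.Sequential(nn.Linear(obs_dim, hidden), nn.SiLU())


class TinyQVAC(nn.Module):
    """Simplified stand-in: two encoders, an actor head and a V-value head."""

    def __init__(self, obs_dim: int, action_dim: int, hidden: int = 64, share_encoder: bool = False):
        super().__init__()
        self.actor_encoder = make_encoder(obs_dim, hidden)
        # Sharing reuses the same module object, so both branches update one set of weights.
        self.critic_encoder = self.actor_encoder if share_encoder else make_encoder(obs_dim, hidden)
        self.actor_head = nn.Linear(hidden, action_dim)
        self.v_head = nn.Linear(hidden, 1)

    def forward(self, obs: torch.Tensor) -> dict:
        action = self.actor_head(self.actor_encoder(obs))
        # Drop the trailing singleton axis so v_value has shape (B,).
        v_value = self.v_head(self.critic_encoder(obs)).squeeze(-1)
        return {'action': action, 'v_value': v_value}


def n_params(module: nn.Module) -> int:
    # parameters() yields each shared tensor once, so a shared encoder is not double counted.
    return sum(p.numel() for p in module.parameters())


shared = TinyQVAC(obs_dim=8, action_dim=2, share_encoder=True)
separate = TinyQVAC(obs_dim=8, action_dim=2, share_encoder=False)
out = shared(torch.randn(4, 8))
```

In this sketch the separate variant carries exactly one extra encoder's worth of parameters, which is the trade-off the docstring alludes to: sharing saves capacity on expensive image encoders, while cheap vector encoders can afford to stay independent.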
@@ -34,7 +34,7 @@ def __init__(
             actor_head_layer_num: int = 1,
             critic_head_hidden_size: int = 64,
             critic_head_layer_num: int = 1,
-            activation: Optional[nn.Module] = nn.SiLU(), #nn.ReLU(),
+            activation: Optional[nn.Module] = nn.SiLU(),
             norm_type: Optional[str] = None,
             encoder_hidden_size_list: Optional[SequenceType] = None,
             share_encoder: Optional[bool] = False,
@@ -319,7 +319,7 @@ def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten
             - logit (:obj:`torch.Tensor`): Discrete action logit, only in hybrid action_space.
             - action_args (:obj:`torch.Tensor`): Continuous action arguments, only in hybrid action_space.
         Returns:
-            - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC's forward computation graph for critic, \
+            - outputs (:obj:`Dict[str, torch.Tensor]`): The output of QVAC's forward computation graph for critic, \
                 including ``q_value``.
         ReturnKeys:
             - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
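The return contract documented in this hunk (a dict holding a ``q_value`` tensor with one entry per batch element) can be sketched for the continuous-action case as follows. `compute_critic_sketch` and the small `critic` network are hypothetical stand-ins, not the method being patched, and the hybrid inputs (`logit`, `action_args`) are deliberately left out.

```python
from typing import Dict

import torch
import torch.nn as nn


def compute_critic_sketch(critic: nn.Module, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Continuous case only: concatenate observation and action along the feature axis.
    x = torch.cat([inputs['obs'], inputs['action']], dim=-1)
    # The critic emits shape (B, 1); squeeze the last axis so q_value has shape (B,).
    return {'q_value': critic(x).squeeze(-1)}


obs_dim, action_dim, batch = 8, 2, 4
# Assumed toy critic; the layer sizes are illustrative only.
critic = nn.Sequential(nn.Linear(obs_dim + action_dim, 64), nn.SiLU(), nn.Linear(64, 1))
outputs = compute_critic_sketch(
    critic,
    {'obs': torch.randn(batch, obs_dim), 'action': torch.randn(batch, action_dim)},
)
```

Under these assumptions, `outputs['q_value']` is a one-dimensional tensor whose length equals the batch size, matching the ``ReturnKeys`` entry.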