|
| 1 | +from typing import Tuple |
| 2 | + |
| 3 | +import chex |
| 4 | +import flax.linen as nn |
| 5 | +import jax |
| 6 | +import jax.numpy as jnp |
| 7 | +from flax.linen.initializers import Initializer, orthogonal |
| 8 | + |
| 9 | +from stoix.networks.utils import parse_activation_fn |
| 10 | +from stoix.systems.disco_rl.disco_rl_types import AgentOutput |
| 11 | + |
| 12 | + |
| 13 | +class LSTMActionConditionedTorso(nn.Module): |
| 14 | + """LSTM-based action-conditional torso inspired by Muesli/MuZero. |
| 15 | +
|
| 16 | + This torso creates a root embedding from the observation, then performs |
| 17 | + an LSTM transition for all possible actions in parallel, producing |
| 18 | + action-conditional hidden states of shape [batch, num_actions, hidden_dim]. |
| 19 | +
|
| 20 | + Attributes: |
| 21 | + num_actions: Number of discrete actions. |
| 22 | + lstm_size: Size of the LSTM hidden state. |
| 23 | + root_mlp_sizes: Sizes of MLP layers for root embedding. If None, uses a single linear layer. |
| 24 | + activation: Activation function for the root MLP. |
| 25 | + kernel_init: Kernel initializer for linear layers. |
| 26 | + """ |
| 27 | + |
| 28 | + num_actions: int |
| 29 | + lstm_size: int |
| 30 | + root_mlp_sizes: Tuple[int, ...] = () |
| 31 | + activation: str = "relu" |
| 32 | + kernel_init: Initializer = orthogonal(1.0) |
| 33 | + |
| 34 | + @nn.compact |
| 35 | + def __call__(self, observation: chex.Array) -> chex.Array: |
| 36 | + """Forward pass. |
| 37 | +
|
| 38 | + Args: |
| 39 | + observation: Input observation of shape [batch, ...]. |
| 40 | +
|
| 41 | + Returns: |
| 42 | + Action-conditional hidden states of shape [batch, num_actions, lstm_size]. |
| 43 | + """ |
| 44 | + batch_size = observation.shape[0] |
| 45 | + |
| 46 | + # 1. Create root embedding from observation |
| 47 | + root_embedding = self._root_embedding(observation) # [batch, lstm_size] |
| 48 | + |
| 49 | + # 2. Perform LSTM transition for all actions |
| 50 | + action_hidden_states = self._model_transition_all_actions( |
| 51 | + root_embedding, batch_size |
| 52 | + ) # [batch, num_actions, lstm_size] |
| 53 | + |
| 54 | + return action_hidden_states |
| 55 | + |
| 56 | + def _root_embedding(self, observation: chex.Array) -> chex.Array: |
| 57 | + """Constructs a root embedding from the observation. |
| 58 | +
|
| 59 | + Args: |
| 60 | + observation: Input observation of shape [batch, ...]. |
| 61 | +
|
| 62 | + Returns: |
| 63 | + Root embedding (LSTM cell state) of shape [batch, lstm_size]. |
| 64 | + """ |
| 65 | + # Simply use the observation as input |
| 66 | + x = observation |
| 67 | + |
| 68 | + # Apply optional MLP layers |
| 69 | + if self.root_mlp_sizes: |
| 70 | + for size in self.root_mlp_sizes: |
| 71 | + x = nn.Dense(size, kernel_init=self.kernel_init)(x) |
| 72 | + x = parse_activation_fn(self.activation)(x) |
| 73 | + |
| 74 | + # Final linear layer to get cell state |
| 75 | + cell = nn.Dense(self.lstm_size, kernel_init=self.kernel_init, name="root_cell")(x) |
| 76 | + # Create hidden state as tanh(cell) |
| 77 | + hidden = jnp.tanh(cell) |
| 78 | + return (hidden, cell) |
| 79 | + |
| 80 | + def _model_transition_all_actions(self, root_carry: chex.Array, batch_size: int) -> chex.Array: |
| 81 | + """Performs LSTM transition for all actions in parallel. |
| 82 | +
|
| 83 | + Args: |
| 84 | + root_carry: Root carry state of shape [batch, lstm_size]. |
| 85 | + batch_size: Batch size. |
| 86 | +
|
| 87 | + Returns: |
| 88 | + LSTM outputs for all actions of shape [batch, num_actions, lstm_size]. |
| 89 | + """ |
| 90 | + # Create one-hot encodings for all actions |
| 91 | + # Shape: [num_actions, num_actions] |
| 92 | + one_hot_actions = jnp.eye(self.num_actions, dtype=root_carry[0].dtype) |
| 93 | + |
| 94 | + # Repeat for each batch element |
| 95 | + # Shape: [batch * num_actions, num_actions] |
| 96 | + batched_one_hot_actions = jnp.tile(one_hot_actions, [batch_size, 1]) |
| 97 | + |
| 98 | + # Repeat the root carry for each action |
| 99 | + # This uses jax.tree.map to handle the (hidden, cell) tuple |
| 100 | + initial_carry = jax.tree.map( |
| 101 | + lambda x: jnp.repeat(x, repeats=self.num_actions, axis=0), root_carry |
| 102 | + ) |
| 103 | + |
| 104 | + # Apply LSTM |
| 105 | + lstm_cell = nn.LSTMCell(features=self.lstm_size, name="action_cond_lstm") |
| 106 | + _, lstm_output = lstm_cell(initial_carry, batched_one_hot_actions) |
| 107 | + |
| 108 | + # Reshape output from [batch * num_actions, lstm_size] to [batch, num_actions, lstm_size] |
| 109 | + action_hidden_states = lstm_output.reshape(batch_size, self.num_actions, self.lstm_size) |
| 110 | + |
| 111 | + return action_hidden_states |
| 112 | + |
| 113 | + |
| 114 | +class DiscoAgentNetwork(nn.Module): |
| 115 | + """ |
| 116 | + A network for the DiscoRL agent. |
| 117 | +
|
| 118 | + This network has a shared torso and five separate heads, matching |
| 119 | + the architecture required by the DiscoUpdateRule: |
| 120 | + 1. logits (Policy) |
| 121 | + 2. q (Categorical Value) |
| 122 | + 3. y (Auxiliary) |
| 123 | + 4. z (Auxiliary) |
| 124 | + 5. aux_pi (Auxiliary Policy) |
| 125 | + """ |
| 126 | + |
| 127 | + shared_torso: nn.Module |
| 128 | + action_conditional_torso: nn.Module |
| 129 | + logits_head: nn.Module |
| 130 | + q_head: nn.Module |
| 131 | + y_head: nn.Module |
| 132 | + z_head: nn.Module |
| 133 | + aux_pi_head: nn.Module |
| 134 | + |
| 135 | + def __call__(self, obs: chex.Array) -> AgentOutput: |
| 136 | + """Forward pass.""" |
| 137 | + # Run the shared torso |
| 138 | + torso_output = self.shared_torso(obs) |
| 139 | + |
| 140 | + # Run logits and y prediction heads on the torso output |
| 141 | + logits = self.logits_head(torso_output) |
| 142 | + y = self.y_head(torso_output) |
| 143 | + |
| 144 | + # We now run the action conditional heads. |
| 145 | + # We do this by running an action-conditional torso first, |
| 146 | + # then passing its output to the q, z, and aux_pi heads. |
| 147 | + action_conditional_torso_output = self.action_conditional_torso(torso_output) |
| 148 | + q = self.q_head(action_conditional_torso_output) |
| 149 | + z = self.z_head(action_conditional_torso_output) |
| 150 | + aux_pi = self.aux_pi_head(action_conditional_torso_output) |
| 151 | + |
| 152 | + return AgentOutput(logits=logits, q=q, y=y, z=z, aux_pi=aux_pi) |
0 commit comments