Skip to content

Commit 807edfa

Browse files
authored
feat: add disco103 - meta learned update rule (#186)
1 parent fe9de0a commit 807edfa

11 files changed

Lines changed: 1157 additions & 41 deletions

File tree

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ dependencies = [
158158
"tqdm>=4.67.1",
159159
"wandb>=0.19.8",
160160
"playground>=0.0.5",
161-
"protobuf==3.20.3"
161+
"protobuf==3.20.3",
162162
]
163163

164164
[dependency-groups]
@@ -171,6 +171,11 @@ dev = [
171171
"testfixtures",
172172
]
173173

174+
[project.optional-dependencies]
175+
disco = [
176+
"disco_rl @ git+https://github.com/google-deepmind/disco_rl.git@main ; python_version >= '3.11'",
177+
]
178+
174179
[project.urls]
175180
"Homepage" = "https://github.com/EdanToledo/Stoix"
176181
"Bug Tracker" = "https://github.com/EdanToledo/Stoix/issues"
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
defaults:
2+
- logger: logger
3+
- arch: anakin
4+
- system: disco_rl/ff_disco103
5+
- network: specialised/disco_rl
6+
- env: gymnax/cartpole
7+
- _self_
8+
9+
hydra:
10+
searchpath:
11+
- file://stoix/configs
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# ---MLP DiscoRL Networks---
2+
agent_network:
3+
shared_torso:
4+
_target_: stoix.networks.torso.MLPTorso
5+
layer_sizes: [512, 512]
6+
use_layer_norm: False
7+
activation: relu
8+
9+
action_conditional_torso:
10+
_target_: stoix.networks.specialised.disco103.LSTMActionConditionedTorso
11+
lstm_size: 128
12+
activation: relu
13+
14+
logits_head:
15+
_target_: stoix.networks.base.chained_torsos
16+
_recursive_: false
17+
torso_cfgs:
18+
- _target_: stoix.networks.torso.MLPTorso
19+
layer_sizes: [128]
20+
use_layer_norm: False
21+
activation: relu
22+
- _target_: stoix.networks.heads.LinearHead
23+
24+
y_head:
25+
_target_: stoix.networks.base.chained_torsos
26+
_recursive_: false
27+
torso_cfgs:
28+
- _target_: stoix.networks.torso.MLPTorso
29+
layer_sizes: [128]
30+
use_layer_norm: False
31+
activation: relu
32+
- _target_: stoix.networks.heads.LinearHead
33+
34+
z_head:
35+
_target_: stoix.networks.base.chained_torsos
36+
_recursive_: false
37+
torso_cfgs:
38+
- _target_: stoix.networks.torso.MLPTorso
39+
layer_sizes: [128]
40+
use_layer_norm: False
41+
activation: relu
42+
- _target_: stoix.networks.heads.LinearHead
43+
44+
q_head:
45+
_target_: stoix.networks.base.chained_torsos
46+
_recursive_: false
47+
torso_cfgs:
48+
- _target_: stoix.networks.torso.MLPTorso
49+
layer_sizes: [128]
50+
use_layer_norm: False
51+
activation: relu
52+
- _target_: stoix.networks.heads.LinearHead
53+
54+
aux_pi_head:
55+
_target_: stoix.networks.base.chained_torsos
56+
_recursive_: false
57+
torso_cfgs:
58+
- _target_: stoix.networks.torso.MLPTorso
59+
layer_sizes: [128]
60+
use_layer_norm: False
61+
activation: relu
62+
- _target_: stoix.networks.heads.LinearHead
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# --- Defaults FF-Disco103 ---
2+
3+
system_name: ff_disco103 # Name of the system.
4+
5+
# --- RL hyperparameters ---
6+
rollout_length: 128 # Number of environment steps per training step.
7+
epochs: 4 # Number of epochs to train on the collected data.
8+
num_minibatches: 8 # Number of minibatches to split the data into.
9+
gamma: 0.997 # Discounting factor.
10+
lr: 3e-4 # Learning rate.
11+
max_abs_update: 1.0 # Maximum abs values for a weight update.
12+
reward_scale: 1.0 # Scaling factor for the reward.
13+
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
14+
15+
# DiscoRL HyperParams - These are passed to the agent_loss functions
16+
disco_hyperparams:
17+
pi_cost: 1.0 # Weight for the policy loss.
18+
y_cost: 1.0 # Weight for the 'y' auxiliary loss.
19+
z_cost: 1.0 # Weight for the 'z' auxiliary loss.
20+
aux_policy_cost: 1.0 # Weight for the auxiliary policy loss.
21+
value_cost: 0.2 # Weight for the value function loss.
22+
value_fn_td_lambda: 0.95 # Lambda for TD(lambda) updates of the value function.
23+
target_params_coeff: 0.9 # Polyak averaging coeff for target net
24+
25+
# DiscoRL UpdateRule Config - These values are passed to the DiscoUpdateRule constructor
26+
disco_rule:
27+
value_discount: ${system.gamma} # Discount factor for the value function.
28+
max_abs_value: 300.0 # Maximum absolute value for categorical value transforms.
29+
num_bins: 601 # Number of bins for the Q-network's categorical distribution.
30+
moving_average_decay: 0.99 # Decay rate for the moving average of statistics.
31+
# Config for the meta-network (meta_nets.LSTM)
32+
net:
33+
name: "lstm" # Name of the network architecture.
34+
prediction_size: 600 # Size of y and z auxiliary outputs
35+
hidden_size: 256 # Size of the hidden layers in the network.
36+
embedding_size: [16, 1] # Size of the embeddings.
37+
policy_channels: [16, 2] # Number of channels in the policy network layers.
38+
policy_target_channels: [16] # Number of channels in the policy target network layers.
39+
output_stddev: 0.3 # Standard deviation for the output distribution.
40+
aux_stddev: 0.3 # Standard deviation for the auxiliary output distributions.
41+
policy_target_stddev: 0.3 # Standard deviation for the policy target distribution.
42+
state_stddev: 1.0 # Standard deviation for the state normalization.
43+
# Config for the inner MetaLSTM
44+
meta_rnn_kwargs:
45+
hidden_size: 128 # Size of the hidden layers in the meta-network.
46+
embedding_size: [16] # Size of the embeddings in the meta-network.
47+
pred_embedding_size: [16, 1] # Size of the prediction embeddings in the meta-network.
48+
policy_channels: [16, 2] # Number of channels in the policy layers of the meta-network.

stoix/networks/base.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import inspect
23
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
34

45
import chex
@@ -221,14 +222,31 @@ def __call__(
221222
return critic_hidden_state, critic_output
222223

223224

224-
def chained_torsos(torso_cfgs: List[Dict[str, Any]], **kwargs: Any) -> nn.Module:
    """Create a network by chaining multiple torsos together using a list of configs.

    This makes use of hydra to instantiate the modules and the composite network
    to chain them together. Be careful when using kwargs: if two torsos accept
    the same argument name, the value will be passed to both torsos.

    Args:
        torso_cfgs: List of dictionaries containing the configuration for each torso.
            These configs should use the same format as the individual torso configs.
        **kwargs: Additional keyword arguments to pass to each torso during
            instantiation. Each torso only receives the arguments its
            constructor accepts.

    Returns:
        A CompositeNetwork applying the instantiated torsos in sequence.
    """
    torso_modules = []
    for torso_cfg in torso_cfgs:
        # Resolve the target class so we can inspect its constructor signature.
        target_class = hydra.utils.get_class(torso_cfg["_target_"])
        parameters = inspect.signature(target_class).parameters

        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()):
            # The constructor accepts **kwargs itself, so every extra argument
            # is valid — forward all of them rather than filtering by name.
            current_kwargs = dict(kwargs)
        else:
            # Only forward the arguments this constructor explicitly declares;
            # anything else would raise a TypeError at instantiation time.
            current_kwargs = {k: v for k, v in kwargs.items() if k in parameters}

        # Instantiate with the filtered kwargs.
        torso_modules.append(hydra.utils.instantiate(torso_cfg, **current_kwargs))

    return CompositeNetwork(torso_modules)

stoix/networks/heads.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,21 @@ def __call__(self, embedding: chex.Array) -> Tuple[distrax.EpsilonGreedy, chex.A
294294
class LinearHead(nn.Module):
    """A linear projection head with an optional leading output shape.

    Projects the embedding to ``prod(pre_shape + (output_dim,))`` units and,
    when ``pre_shape`` is given, reshapes the trailing axis into
    ``pre_shape + (output_dim,)``.
    """

    output_dim: int
    kernel_init: Initializer = orthogonal(0.01)
    pre_shape: Optional[Tuple[int, ...]] = None

    def setup(self) -> None:
        # Full trailing shape of the output; just (output_dim,) when no
        # pre_shape is configured.
        trailing_shape = (
            (self.output_dim,) if self.pre_shape is None else self.pre_shape + (self.output_dim,)
        )
        self.shape = trailing_shape
        self.output_size = int(np.prod(trailing_shape))

    @nn.compact
    def __call__(self, embedding: chex.Array) -> chex.Array:
        projected = nn.Dense(self.output_size, kernel_init=self.kernel_init)(embedding)
        if self.pre_shape is not None:
            # Restore the configured leading structure on the last axis.
            projected = projected.reshape(projected.shape[:-1] + self.shape)
        return projected
302312

303313

304314
class MultiDiscreteHead(nn.Module):
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
from typing import Tuple
2+
3+
import chex
4+
import flax.linen as nn
5+
import jax
6+
import jax.numpy as jnp
7+
from flax.linen.initializers import Initializer, orthogonal
8+
9+
from stoix.networks.utils import parse_activation_fn
10+
from stoix.systems.disco_rl.disco_rl_types import AgentOutput
11+
12+
13+
class LSTMActionConditionedTorso(nn.Module):
    """LSTM-based action-conditional torso inspired by Muesli/MuZero.

    This torso creates a root embedding from the observation, then performs
    an LSTM transition for all possible actions in parallel, producing
    action-conditional hidden states of shape [batch, num_actions, lstm_size].

    Attributes:
        num_actions: Number of discrete actions.
        lstm_size: Size of the LSTM hidden state.
        root_mlp_sizes: Sizes of MLP layers for the root embedding. If empty,
            only a single linear layer is used.
        activation: Activation function for the root MLP.
        kernel_init: Kernel initializer for linear layers.
    """

    num_actions: int
    lstm_size: int
    root_mlp_sizes: Tuple[int, ...] = ()
    activation: str = "relu"
    kernel_init: Initializer = orthogonal(1.0)

    @nn.compact
    def __call__(self, observation: chex.Array) -> chex.Array:
        """Forward pass.

        Args:
            observation: Input observation of shape [batch, ...].

        Returns:
            Action-conditional hidden states of shape [batch, num_actions, lstm_size].
        """
        batch_size = observation.shape[0]

        # 1. Create root embedding (LSTM carry) from observation
        root_carry = self._root_embedding(observation)  # ([batch, lstm_size], [batch, lstm_size])

        # 2. Perform LSTM transition for all actions
        action_hidden_states = self._model_transition_all_actions(
            root_carry, batch_size
        )  # [batch, num_actions, lstm_size]

        return action_hidden_states

    def _root_embedding(self, observation: chex.Array) -> Tuple[chex.Array, chex.Array]:
        """Constructs a root LSTM carry from the observation.

        Args:
            observation: Input observation of shape [batch, ...].

        Returns:
            LSTM carry as a ``(hidden, cell)`` tuple, each of shape
            [batch, lstm_size]. The hidden state is ``tanh(cell)``.
        """
        # Simply use the observation as input
        x = observation

        # Apply optional MLP layers
        if self.root_mlp_sizes:
            for size in self.root_mlp_sizes:
                x = nn.Dense(size, kernel_init=self.kernel_init)(x)
                x = parse_activation_fn(self.activation)(x)

        # Final linear layer to get cell state
        cell = nn.Dense(self.lstm_size, kernel_init=self.kernel_init, name="root_cell")(x)
        # Create hidden state as tanh(cell)
        hidden = jnp.tanh(cell)
        return (hidden, cell)

    def _model_transition_all_actions(
        self, root_carry: Tuple[chex.Array, chex.Array], batch_size: int
    ) -> chex.Array:
        """Performs LSTM transition for all actions in parallel.

        Args:
            root_carry: Root ``(hidden, cell)`` carry, each of shape [batch, lstm_size].
            batch_size: Batch size.

        Returns:
            LSTM outputs for all actions of shape [batch, num_actions, lstm_size].
        """
        # Create one-hot encodings for all actions
        # Shape: [num_actions, num_actions]
        one_hot_actions = jnp.eye(self.num_actions, dtype=root_carry[0].dtype)

        # Repeat for each batch element
        # Shape: [batch * num_actions, num_actions]
        batched_one_hot_actions = jnp.tile(one_hot_actions, [batch_size, 1])

        # Repeat the root carry for each action. Together with the tiled
        # one-hot actions above, this forms the batch x action cartesian
        # product: row i pairs batch i // num_actions with action i % num_actions.
        # This uses jax.tree.map to handle the (hidden, cell) tuple.
        initial_carry = jax.tree.map(
            lambda x: jnp.repeat(x, repeats=self.num_actions, axis=0), root_carry
        )

        # Apply LSTM
        lstm_cell = nn.LSTMCell(features=self.lstm_size, name="action_cond_lstm")
        _, lstm_output = lstm_cell(initial_carry, batched_one_hot_actions)

        # Reshape output from [batch * num_actions, lstm_size] to [batch, num_actions, lstm_size]
        action_hidden_states = lstm_output.reshape(batch_size, self.num_actions, self.lstm_size)

        return action_hidden_states
112+
113+
class DiscoAgentNetwork(nn.Module):
    """
    A network for the DiscoRL agent.

    This network has a shared torso and five separate heads, matching
    the architecture required by the DiscoUpdateRule:
    1. logits (Policy)
    2. q (Categorical Value)
    3. y (Auxiliary)
    4. z (Auxiliary)
    5. aux_pi (Auxiliary Policy)
    """

    shared_torso: nn.Module
    action_conditional_torso: nn.Module
    logits_head: nn.Module
    q_head: nn.Module
    y_head: nn.Module
    z_head: nn.Module
    aux_pi_head: nn.Module

    def __call__(self, obs: chex.Array) -> AgentOutput:
        """Compute all five agent outputs from a single observation."""
        # Shared representation used by every head.
        embedding = self.shared_torso(obs)

        # Heads conditioned directly on the shared embedding.
        logits = self.logits_head(embedding)
        y = self.y_head(embedding)

        # Action-conditional branch: expand the embedding per action via the
        # action-conditional torso, then apply the q, z and aux_pi heads to
        # the expanded representation.
        per_action_embedding = self.action_conditional_torso(embedding)
        q = self.q_head(per_action_embedding)
        z = self.z_head(per_action_embedding)
        aux_pi = self.aux_pi_head(per_action_embedding)

        return AgentOutput(logits=logits, q=q, y=y, z=z, aux_pi=aux_pi)

0 commit comments

Comments
 (0)