Code Appendix for the paper "Real-Time Recurrent Reinforcement Learning", accepted at AAAI 2025.
- Install Poetry: https://python-poetry.org/docs/#installation
- Install the dependencies:
  ```bash
  poetry install
  ```
- Launch a Poetry shell and start training:
  ```bash
  poetry shell
  python rtrrl.py
  ```
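If you do not want to activate a shell, the script can also be started directly through Poetry, as done for the help command further below:

```bash
poetry run python rtrrl.py
```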
You can log results using aim or wandb. By default, no logging framework is installed. After installing the respective package, you can enable logging by providing the --logging argument:

```bash
pip install aim
python rtrrl.py --logging aim
```
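The same pattern should also work for wandb; the flag value below is an assumption made by analogy with aim and should be checked against --help:

```bash
pip install wandb
python rtrrl.py --logging wandb
```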
A GPU can speed up training when using large batch sizes but will slow it down for smaller ones. Make sure to install the CUDA versions of jax and jaxlib:

```bash
pip install "jax[cuda]"
```
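To verify that the CUDA build is actually used, you can list the devices visible to JAX (this is plain JAX functionality, not specific to this repository):

```bash
python -c "import jax; print(jax.devices())"
```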
Hyperparameter | Default Value |
---|---|
Discount factor | 0.99 |
TD(λ) learning rate | 1e-5 |
RNN learning rate | 1e-5 |
Entropy rate | 1e-5 |
Actor trace scaling | 1.0 |
Lambda for actor eligibility trace | 0.99 |
Lambda for critic eligibility trace | 0.99 |
Lambda for RNN eligibility trace | 0.99 |
This is an incomplete table of configurables. To find out more, run:

```bash
poetry run python rtrrl.py --help
```
There is a preset for brax environments that can be used by providing the config path:

```bash
python rtrrl.py --config_path configs/brax.yml
```
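Presumably the preset can be combined with the other flags documented here, for example with logging enabled (this combination is an assumption, not something the preset requires):

```bash
python rtrrl.py --config_path configs/brax.yml --logging aim
```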
Name | Description | Default Value |
---|---|---|
debug | Enables debugging functionality. | False |
env_name | Environment ID as defined by gymnax. | 'CartPole-v1' |
obs_mask | Allows masking of the observation. Allowed values are None, 'even', 'odd', 'first_half', or a list of indices. | None |
env_init_args | Arguments passed to the environment constructor (e.g. size=16 for DeepSea-bsuite). | - |
env_params | Environment parameters passed to the step and reset methods (e.g. memory_length=32 for MemoryChain-bsuite). | - |
rnn_model | Determines which RNN model is used. Set to None for vanilla TD(λ). | 'CTRNN_simple' |
hidden_size | RNN hidden state size. | 16 |
seed | Random seed for jax. Set to None for a random integer. | None |
optimizer_params_td.opt_name | Optimizer used for TD(λ). | 'adam' |
optimizer_params_rnn.opt_name | Optimizer used for the RNN. | 'adam' |
episodes | Total number of training episodes. | 150_000 |
eval_every | Number of episodes between evaluations. | 100 |
eval_steps | Number of evaluation steps. | 10000 |
steps | Number of training steps per episode. | 10000 |
max_ep_length | Maximum number of steps per episode. Specific environments may supersede this. | 1000 |
patience | Early stopping is triggered after this number of evaluation episodes without improvement. | 20 |
batch_size | Number of parallel environments. | 1 |
eta | Can be used for infinite horizon tasks. If set, an average reward estimate is used. | None |
eta_pi | Scales the gradients of the action probability passed to the RNN. | 1 |
eta_f | Scales the gradients of the RNN. | 1 |
entropy_rate | Scales the gradient of the action entropy. | 1 |
var_scaling | If True, scales the gradients of the action probability by the scale of the action distribution. Only works for continuous actions. | False |
gradient_mode | Method for online gradient computation: 'RTRL', 'RFLO' or 'LocalMSE'. Ignored when LRU is used for rnn_model. | 'RFLO' |
trace_mode | Type of eligibility trace: 'accumulate' or 'dutch'. | 'accumulate' |
wiring | Wiring of the RNN. See models/jax/wirings_jax.py for available options. | 'fully_connected' |
dt | Determines the number of steps for forward Euler, e.g. 0.2 results in 5 steps. | 1 |
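As a sketch, the configurables above can presumably be overridden from the command line in the same way as --logging and --config_path; the exact flag names should be verified with --help. A hypothetical partially observable CartPole run with exact RTRL gradients and a fixed seed might look like:

```bash
python rtrrl.py --env_name CartPole-v1 --obs_mask first_half --gradient_mode RTRL --seed 0
```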