This repo is meant as a starting point for RL research, based on the PQN algorithm from the "Simplifying Deep Temporal Difference Learning" paper. It offers:
- Clean, flat implementation of PQN (the original code is very deeply nested and hard to work with)
- Extremely efficient, scalable parallel training using `shard_map` (tested with 64 H100 GPUs)
- Parallel logging of multiple seeds at once to Weights & Biases
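The repo's actual `shard_map` setup is more involved; a minimal data-parallel sketch of the idea (toy gradient function, axis name `devices` and all other names assumed, not taken from this codebase) looks like:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Build a 1-D device mesh over all available devices.
mesh = Mesh(jax.devices(), axis_names=("devices",))

def per_device_step(params, batch):
    # Toy quadratic loss on the local batch shard; the real repo runs a
    # full PQN update here instead.
    grads = jax.grad(lambda p: jnp.mean((batch @ p) ** 2))(params)
    # Average gradients across devices so every replica stays in sync.
    return jax.lax.pmean(grads, axis_name="devices")

# Params replicated on every device, batch sharded along its first axis.
step = jax.jit(shard_map(
    per_device_step,
    mesh=mesh,
    in_specs=(P(), P("devices")),
    out_specs=P(),
))
```

With one device this degenerates to a plain gradient step; on a multi-host mesh the batch axis is split across all GPUs and only the `pmean` communicates.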
As part of my work on Mila's IDT team, I use this code to help test new HPC clusters. Specifically, I take the opportunity of having an entire cluster free to see how far one can scale the PQN algorithm.
On the TamIA cluster, this code was used to train PQN on close to 200,000 Craftax environments in parallel, using 64 H100 GPUs (16 nodes × 4 H100 GPUs per node). Interestingly, PQN is unable to learn a good policy at that scale with the same hyperparameters.
If that's interesting to you, keep reading! :) I hope this can help you try out new ideas.
Here are some potentially interesting research directions or improvements:
- Try using the "Stop Regressing" cross-entropy loss instead of MSE for the value network. Maybe that loss also enables training with larger batch sizes for the same network capacity (in addition to training with larger networks)?
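As a starting point, here is a minimal sketch of one variant of that idea: two-hot classification targets over a fixed set of bins, with a cross-entropy loss. The bin range, bin count, and all function names are illustrative assumptions, not taken from the paper or this repo:

```python
import jax
import jax.numpy as jnp

# Illustrative choices: 51 bins spanning an assumed return range.
NUM_BINS = 51
bins = jnp.linspace(-10.0, 10.0, NUM_BINS)

def two_hot(target):
    """Distribute a scalar target onto the two nearest bins."""
    target = jnp.clip(target, bins[0], bins[-1])
    idx = jnp.clip(jnp.searchsorted(bins, target) - 1, 0, NUM_BINS - 2)
    lo, hi = bins[idx], bins[idx + 1]
    w_hi = (target - lo) / (hi - lo)
    return jnp.zeros(NUM_BINS).at[idx].set(1.0 - w_hi).at[idx + 1].set(w_hi)

def ce_value_loss(logits, targets):
    """Cross-entropy between predicted bin logits and two-hot TD targets."""
    probs = jax.vmap(two_hot)(targets)           # (batch, NUM_BINS)
    logp = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(probs * logp, axis=-1))
```

The value head would then output `NUM_BINS` logits per action instead of a scalar, with the scalar Q-value recovered as the expectation `softmax(logits) @ bins`.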
- Try scaling the learning rate by some factor of the effective batch size, or increasing the number of updates per batch, and see if that enables PQN to scale.
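Two common heuristics to try here are linear and square-root scaling of the learning rate with the effective batch size. The baseline values and function name below are hypothetical placeholders, not the repo's actual defaults:

```python
# Hypothetical baseline: a tuned LR at some reference effective batch
# size (num_envs * num_steps). Replace with the repo's real defaults.
BASE_LR = 2.5e-4
BASE_BATCH = 128 * 32

def scaled_lr(num_envs, num_steps, rule="linear"):
    """Scale the learning rate with the effective batch size."""
    ratio = (num_envs * num_steps) / BASE_BATCH
    if rule == "linear":
        # Linear scaling rule: LR grows proportionally with batch size.
        return BASE_LR * ratio
    if rule == "sqrt":
        # Gentler square-root scaling, often safer at extreme scales.
        return BASE_LR * ratio ** 0.5
    raise ValueError(f"unknown rule: {rule}")
```

At ~200,000 parallel environments the linear rule may well diverge, which is itself an interesting data point; the sqrt rule gives a more conservative comparison.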
- Add a network config group (and refactor the code a bit as needed) to make it easy to try out networks other than the RNN (e.g. some form of Transformer). Reuse networks and take inspiration from Rejax.
- Try disabling the "optimistic reset" wrapper. Perhaps this will increase sample diversity and enable training with larger batch sizes?
- Why do the PQN authors also use that wrapper on the test environment? Investigate.
- There is some undefined behavior that is ignored (and doesn't seem to cause issues?) in the sampling of which envs to reset in that wrapper. Investigate.
- The original PQN code passes the `memory_transitions` buffer of length `mem_window + num_steps` between iterations, of which only the last `mem_window` entries are actually used inside the iteration. Fix this bug and measure the relative performance improvement in steps/sec.
- Add support for other environments, in particular Atari via EnvPool's XLA adapter.
- Enable vmapping over hyperparameters other than just the random seed. Take inspiration from Rejax again here.
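The core trick is nesting `vmap` so each axis maps over a different argument. This sketch uses a stand-in `train` function (the name, signature, and fake return value are assumptions, not this repo's API):

```python
import jax
import jax.numpy as jnp

def train(seed, lr):
    """Stand-in for a real training loop: returns a fake 'final return'."""
    key = jax.random.PRNGKey(seed)
    return lr * jax.random.uniform(key)

seeds = jnp.arange(4)
lrs = jnp.array([1e-4, 2.5e-4, 1e-3])

# Inner vmap maps over learning rates, outer vmap over seeds, giving a
# (len(seeds), len(lrs)) grid of independent runs in a single call.
grid_train = jax.jit(
    jax.vmap(jax.vmap(train, in_axes=(None, 0)), in_axes=(0, None))
)
results = grid_train(seeds, lrs)   # shape (4, 3)
```

The same pattern extends to any hyperparameter that is a traced array (epsilon schedules, discount factors, etc.); config values that change shapes or control flow still have to stay static.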
- Implement checkpointing.
- Fix the existing tests and add more of them.
If any of this sounds interesting to you, feel free to open an issue or reach out via email; I'd be happy to help.