The torch-scatter dependency is a huge PITA.
We currently use it for reducing across sequences in a ragged batch to get inputs to the value head:
https://github.com/entity-neural-network/incubator/blob/85cd666f3401ca0d9eebfd0b6603e14de2311b4a/rogue_net/rogue_net/actor.py#L110-L112
This is not a very complex operation, so we should just replace it with a custom Triton op and yeet the torch-scatter dependency.
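
For reference, here is a rough sketch of what the replacement would need to compute. It assumes the linked lines do a mean reduction over a per-entity batch index, which should be checked against the actual code. It is plain PyTorch rather than Triton, so it could serve as a baseline to test a Triton kernel against, or possibly as a stopgap. The names `segment_mean`, `batch_index`, and `num_segments` are made up for illustration.

```python
import torch


def segment_mean(
    x: torch.Tensor, batch_index: torch.Tensor, num_segments: int
) -> torch.Tensor:
    """Mean-pool the rows of `x` that share the same entry in `batch_index`.

    x:            (total_entities, d_model) float tensor, all sequences concatenated
    batch_index:  (total_entities,) int64 tensor mapping each row to its sequence
    num_segments: number of sequences in the ragged batch (hypothetical parameter)
    """
    # Sum the rows belonging to each sequence.
    sums = torch.zeros(num_segments, x.shape[1], dtype=x.dtype, device=x.device)
    sums.index_add_(0, batch_index, x)
    # Count rows per sequence; clamp so empty sequences yield zeros, not NaN.
    counts = torch.bincount(batch_index, minlength=num_segments).clamp(min=1)
    return sums / counts.unsqueeze(1).to(x.dtype)


if __name__ == "__main__":
    # Two sequences: rows 0-1 belong to sequence 0, row 2 to sequence 1.
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    batch_index = torch.tensor([0, 0, 1])
    print(segment_mean(x, batch_index, num_segments=2))
```

Gradients flow through `index_add_` to `x`, so this should be usable in training. Whether it is fast enough to skip the Triton op entirely would need benchmarking.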