
Replace torch-scatter with custom ragged reduce operation in Triton #1

@cswinter


The torch-scatter dependency is a huge PITA.
We currently use it for reducing across sequences in a ragged batch to get inputs to the value head:

https://github.com/entity-neural-network/incubator/blob/85cd666f3401ca0d9eebfd0b6603e14de2311b4a/rogue_net/rogue_net/actor.py#L110-L112

This is not a very complex operation; we should just replace it with a custom Triton op and yeet the torch-scatter dependency.
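For reference, the reduction really is small. Below is a minimal sketch in plain PyTorch that uses only built-in ops (`torch.repeat_interleave`, `Tensor.index_add_`). It assumes a **mean** reduction and a per-sequence `lengths` tensor describing the ragged batch. Both are assumptions: the linked `actor.py` may use a different reduce or index layout. The function name `ragged_mean` is hypothetical. A custom Triton kernel would have to match this output, so the sketch could also serve as that kernel's correctness oracle.

```python
import torch


def ragged_mean(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean over each sequence of a ragged batch.

    Assumed layout (not taken from the linked code):
      x:       (total_rows, features), all sequences concatenated along dim 0
      lengths: (batch,), number of rows of x that belong to each sequence
    Returns (batch, features). Empty sequences yield zeros.
    """
    batch = lengths.shape[0]
    # Segment id of every row, e.g. lengths [2, 1] -> index [0, 0, 1].
    index = torch.repeat_interleave(
        torch.arange(batch, device=x.device), lengths.to(x.device)
    )
    # Scatter-add every row into its segment's accumulator.
    sums = torch.zeros(batch, x.shape[1], dtype=x.dtype, device=x.device)
    sums.index_add_(0, index, x)
    # Divide by segment size; clamp so empty segments stay at 0 instead of NaN.
    counts = lengths.to(device=x.device, dtype=x.dtype).clamp(min=1).unsqueeze(1)
    return sums / counts


x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(ragged_mean(x, torch.tensor([2, 0, 1])))  # rows 0-1 -> seq 0, none -> seq 1, row 2 -> seq 2
```

On the design: `index_add_` is a scatter-add, which on GPU typically relies on atomics. One natural layout for a Triton version is one program per (sequence, feature block) that loops over that sequence's rows using CSR-style offsets (the cumulative sum of `lengths`). That layout needs no atomics and gives a deterministic summation order.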
