
Improve memory efficiency of relative positional encoding #3

@cswinter


Our relative positional encoding implementation currently materializes the full sequence x sequence set of positional keys/values:
https://github.com/entity-neural-network/incubator/blob/eb62f6fe2c27c7f852c6ffbb40cc99422ded122f/rogue_net/rogue_net/relpos_encoding.py#L101-L102
This requires a lot of memory: e.g., for a batch size of 8192, sequence length of 32, and head dimension of 64, the two tensors take 2 x 8192 x 32 x 32 x 64 x 4B ≈ 4.3GB. We could fix this with two custom kernels: one that fuses the lookup of the relative positional keys with the dot product against the queries, and one that fuses the lookup of the relative positional values with the attention-weighted sum.

https://github.com/entity-neural-network/incubator/blob/eb62f6fe2c27c7f852c6ffbb40cc99422ded122f/rogue_net/rogue_net/transformer.py#L135-L142

https://github.com/entity-neural-network/incubator/blob/eb62f6fe2c27c7f852c6ffbb40cc99422ded122f/rogue_net/rogue_net/transformer.py#L151-L154
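
For reference, here is a rough sketch of how the two operations could be expressed in plain PyTorch without materializing the `[batch, seq, seq, head_dim]` tensors, assuming the relative positional keys/values come from embedding tables indexed by a per-pair relative-position bucket index (the names `key_table`, `value_table`, and `relpos_index` are placeholders, not the actual identifiers in `relpos_encoding.py`). The idea is to dot the queries against every bucket embedding first and then gather/scatter with the per-pair indices, which only pays off when the number of distinct buckets is small compared to seq x head_dim:

```python
import torch


def relpos_key_scores(q, key_table, relpos_index):
    """Relative-position contribution to the attention logits.

    q:            [batch, heads, seq, head_dim]
    key_table:    [n_buckets, head_dim]
    relpos_index: [batch, seq, seq] bucket id of each (i, j) pair
    returns:      [batch, heads, seq, seq]
    """
    # Dot every query with every bucket embedding: [batch, heads, seq, n_buckets].
    scores_per_bucket = torch.einsum("bhsd,nd->bhsn", q, key_table)
    # Pick out the score of the bucket actually used by each (i, j) pair.
    idx = relpos_index.unsqueeze(1).expand(-1, q.shape[1], -1, -1)
    return torch.gather(scores_per_bucket, dim=-1, index=idx)


def relpos_value_sum(probs, value_table, relpos_index):
    """Relative-position contribution to the attention output.

    probs:        [batch, heads, seq, seq] attention weights
    value_table:  [n_buckets, head_dim]
    relpos_index: [batch, seq, seq]
    returns:      [batch, heads, seq, head_dim]
    """
    b, h, s, _ = probs.shape
    n = value_table.shape[0]
    idx = relpos_index.unsqueeze(1).expand(-1, h, -1, -1)
    # Accumulate attention weight per bucket instead of per (i, j) pair:
    # [batch, heads, seq, n_buckets].
    bucket_weights = probs.new_zeros(b, h, s, n).scatter_add_(-1, idx, probs)
    return torch.einsum("bhsn,nd->bhsd", bucket_weights, value_table)
```

The intermediates here are `[batch, heads, seq, n_buckets]` rather than `[batch, seq, seq, head_dim]`, which is where the ~4.3GB in the example above goes. A fused Triton/CUDA kernel could avoid even these intermediates; the sketch is only meant as a possible correctness reference for such a kernel.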
