Our relative positional encoding implementation currently materializes the full sequence × sequence set of relative positional keys/values:
https://github.com/entity-neural-network/incubator/blob/eb62f6fe2c27c7f852c6ffbb40cc99422ded122f/rogue_net/rogue_net/relpos_encoding.py#L101-L102
This requires a lot of memory: e.g., for a batch size of 8192, sequence length of 32, and head dimension of 64, the two tensors take 2 × 8192 × 32 × 32 × 64 × 4 B ≈ 4.3 GB. We could fix this with two custom kernels: one that fuses the lookup of the relative positional keys with their dot product against the queries, and one that fuses the lookup of the relative positional values with the attention-weighted sum.
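
For reference, here is a minimal PyTorch sketch of the kind of fusion those kernels would perform. It is not the rogue_net code: the tensor names and shapes are made up, and it assumes a simple 1-D relative-offset table of size `2 * seq - 1`. The idea is to score against the small offset table first and then gather / scatter-add per position pair, so the `[batch, seq, seq, head_dim]` intermediates are never created:

```python
import torch

# Illustrative shapes only; the real values come from the training config.
batch, heads, seq, head_dim = 8, 4, 32, 64
num_rel = 2 * seq - 1  # distinct 1-D relative offsets (assumption for this sketch)

q = torch.randn(batch, heads, seq, head_dim)                    # queries
attn = torch.softmax(torch.randn(batch, heads, seq, seq), -1)   # attention weights
rel_keys = torch.randn(num_rel, head_dim)                       # relative key table
rel_values = torch.randn(num_rel, head_dim)                     # relative value table

# Offset index for every (query, key) position pair, values in [0, num_rel).
rel_index = (
    torch.arange(seq)[None, :] - torch.arange(seq)[:, None] + seq - 1
).expand(batch, heads, seq, seq).contiguous()

# Keys: score the queries against the small offset table ([batch, heads, seq, num_rel]),
# then gather the entry each (query, key) pair needs, instead of first gathering a
# [batch, seq, seq, head_dim] key tensor and dotting it with the queries afterwards.
scores_per_offset = torch.einsum("bhqd,rd->bhqr", q, rel_keys)
rel_scores = torch.gather(scores_per_offset, -1, rel_index)

# Values: sum the attention weights per offset, then multiply by the value table,
# instead of gathering [batch, seq, seq, head_dim] values and weighting them.
attn_per_offset = torch.zeros(batch, heads, seq, num_rel).scatter_add_(
    -1, rel_index, attn
)
rel_out = torch.einsum("bhqr,rd->bhqd", attn_per_offset, rel_values)
```

The proposed custom kernels would go further by also avoiding the `[batch, heads, seq, num_rel]` intermediates, but even this pure-PyTorch form shows the fusion: the per-pair key/value tensors (the ~4.3 GB ones) never have to exist.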