-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds an `EmbeddingFwdOp` with same functionality as `F.embedding`. 1. I am not using `take_along_axis`. `F.embedding` allows optional parameters like `max_norm, padding_idx` which would require further processing if implemented using `take_along_axis`. So I defaulted to creating a new node to guarantee performance parity. 2. Thunder uses `prims.EMBEDDING` if the optional parameters `padding_idx/max_norm` are specified, else it uses `prims.TAKE`. This prevents nvfuser from consuming embedding operator in the other cases. Hence, in Thunder, nvfuser will also directly execute `ltorch.embedding`. This will require a separate backward API to consume `ltorch.embedding_backward` and cannot reuse grad rules for `prims.EMBEDDING`. Hence, the `EmbeddingFwdOp` naming instead of `EmbeddingOp`. 3. I first plan to plumb the fwd only embedding support in Thunder while I draft the backward node which should be very similar. Thunder reviews may bring up another way of implementing this support.
- Loading branch information
Showing
16 changed files
with
511 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.