Skip to content

Commit

Permalink
documentation fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Alp Dener <[email protected]>
  • Loading branch information
denera committed Dec 5, 2024
1 parent 2acb92f commit b07bb2d
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,6 @@ def infer_sharding_from_operands(
)

# Modify operand specs:
# - FSDP axes are all-gathered
# - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded
# - LHS operand contracting dimension sharding is forced to match RHS contracting dimension
lhs_spec_new = [spec for spec in lhs_spec]
Expand Down Expand Up @@ -584,7 +583,6 @@ def partition(
)

# Modify operand specs:
# - FSDP axes are all-gathered
# - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded
# - LHS operand contracting dimension sharding is forced to match RHS contracting dimension
lhs_spec_new = [spec for spec in lhs_spec]
Expand Down

0 comments on commit b07bb2d

Please sign in to comment.